Add dependencies

This commit is contained in:
Christine Dodrill 2015-07-31 01:31:38 -07:00
parent bfae5cb265
commit df7052dd40
254 changed files with 42481 additions and 2 deletions

2
.gitignore vendored
View File

@ -21,7 +21,5 @@ _testmain.go
*.exe
/vendor/src
# Data, etc
/data

86
vendor/manifest vendored
View File

@ -1,11 +1,97 @@
{
"version": 0,
"dependencies": [
{
"importpath": "github.com/Xe/middleware",
"repository": "https://github.com/Xe/middleware",
"revision": "7d23200fbed9e7f3be4ac76b4f7f6bd19cc4aba0",
"branch": "master"
},
{
"importpath": "github.com/codegangsta/negroni",
"repository": "https://github.com/codegangsta/negroni",
"revision": "c7477ad8e330bef55bf1ebe300cf8aa67c492d1b",
"branch": "master"
},
{
"importpath": "github.com/disintegration/imaging",
"repository": "https://github.com/disintegration/imaging",
"revision": "3ab6ec550f20d497d2755ed3c48a3e45ad6b7eb9",
"branch": "master"
},
{
"importpath": "github.com/drone/routes",
"repository": "https://github.com/drone/routes",
"revision": "853bef2b231162bb7b09355720416d3af1510d88",
"branch": "master"
},
{
"importpath": "github.com/garyburd/redigo",
"repository": "https://github.com/garyburd/redigo",
"revision": "a47585eaae68b1d14b02940d2af1b9194f3caa9c",
"branch": "master"
},
{
"importpath": "github.com/gorilla/context",
"repository": "https://github.com/gorilla/context",
"revision": "215affda49addc4c8ef7e2534915df2c8c35c6cd",
"branch": "master"
},
{
"importpath": "github.com/gorilla/mux",
"repository": "https://github.com/gorilla/mux",
"revision": "f15e0c49460fd49eebe2bcc8486b05d1bef68d3a",
"branch": "master"
},
{
"importpath": "github.com/jinzhu/gorm",
"repository": "https://github.com/jinzhu/gorm",
"revision": "6a7dda9a32e187c044178aadb0a4510f053a73fa",
"branch": "master"
},
{
"importpath": "github.com/lib/pq",
"repository": "https://github.com/lib/pq",
"revision": "0dad96c0b94f8dee039aa40467f767467392a0af",
"branch": "master"
},
{
"importpath": "github.com/sebest/xff",
"repository": "https://github.com/sebest/xff",
"revision": "d90d345f39f4e84675192d6662c42f33a46ec830",
"branch": "master"
},
{
"importpath": "github.com/thoj/go-ircevent",
"repository": "https://github.com/thoj/go-ircevent",
"revision": "c47f9d8e3db1e137c31efbd755bd563d1bf29efc",
"branch": "master"
},
{
"importpath": "github.com/unrolled/render",
"repository": "https://github.com/unrolled/render",
"revision": "aa61028b1d32873eaa3e261a3ef0e892a153107b",
"branch": "v1"
},
{
"importpath": "github.com/yosssi/ace-proxy",
"repository": "https://github.com/yosssi/ace-proxy",
"revision": "ecd9b785e6023e00b1a451cdf584a30d4eff14c0",
"branch": "master"
},
{
"importpath": "golang.org/x/image/bmp",
"repository": "https://go.googlesource.com/image",
"revision": "5ec5e003b21ac1f06e175898413ada23a6797fc0",
"branch": "master",
"path": "/bmp"
},
{
"importpath": "golang.org/x/image/tiff",
"repository": "https://go.googlesource.com/image",
"revision": "5ec5e003b21ac1f06e175898413ada23a6797fc0",
"branch": "master",
"path": "/tiff"
}
]
}

View File

@ -0,0 +1,2 @@
# middleware
All of the useful middlewares I use

View File

@ -0,0 +1,13 @@
package middleware
import (
"github.com/Xe/middleware/xff"
"github.com/Xe/middleware/xrequestid"
"github.com/codegangsta/negroni"
)
// Inject adds x-request-id and x-forwarded-for support to an existing negroni instance.
func Inject(n *negroni.Negroni) {
n.Use(negroni.HandlerFunc(xff.XFF))
n.Use(xrequestid.New(26))
}

View File

@ -0,0 +1,43 @@
# X-Forwarded-For middleware fo Go [![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/sebest/xff)
Package `xff` is a `net/http` middleware/handler to parse [Forwarded HTTP Extension](http://tools.ietf.org/html/rfc7239) in Golang.
## Example usage
Install `xff`:
go get github.com/sebest/xff
Edit `server.go`:
```go
package main
import (
"net/http"
"github.com/sebest/xff"
)
func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
})
http.ListenAndServe(":8080", xff.Handler(handler))
}
```
Then run your server:
go run server.go
The server now runs on `localhost:8080`:
$ curl -D - -H 'X-Forwarded-For: 42.42.42.42' http://localhost:8080/
HTTP/1.1 200 OK
Date: Fri, 20 Feb 2015 20:03:02 GMT
Content-Length: 29
Content-Type: text/plain; charset=utf-8
hello from 42.42.42.42:52661

View File

@ -0,0 +1,23 @@
package main
import (
"net/http"
"github.com/codegangsta/negroni"
"github.com/gorilla/mux"
"github.com/sebest/xff"
)
func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
})
mux := mux.NewRouter()
mux.Handle("/", handler)
n := negroni.Classic()
n.Use(negroni.HandlerFunc(xff.XFF))
n.UseHandler(mux)
n.Run(":3000")
}

View File

@ -0,0 +1,15 @@
package main
import (
"net/http"
"github.com/sebest/xff"
)
func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
})
http.ListenAndServe(":3000", xff.Handler(handler))
}

View File

@ -0,0 +1,77 @@
package xff
import (
"net"
"net/http"
"strings"
)
var privateMasks = func() []net.IPNet {
masks := []net.IPNet{}
for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"} {
_, net, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
masks = append(masks, *net)
}
return masks
}()
// IsPublicIP returns true if the given IP can be routed on the Internet
func IsPublicIP(ip net.IP) bool {
if !ip.IsGlobalUnicast() {
return false
}
for _, mask := range privateMasks {
if mask.Contains(ip) {
return false
}
}
return true
}
// Parse parses the value of the X-Forwarded-For Header and returns the IP address.
func Parse(ipList string) string {
for _, ip := range strings.Split(ipList, ",") {
ip = strings.TrimSpace(ip)
if IP := net.ParseIP(ip); IP != nil && IsPublicIP(IP) {
return ip
}
}
return ""
}
// GetRemoteAddr parses the given request, resolves the X-Forwarded-For header
// and returns the resolved remote address.
func GetRemoteAddr(r *http.Request) string {
xff := r.Header.Get("X-Forwarded-For")
var ip string
if xff != "" {
ip = Parse(xff)
}
_, oport, err := net.SplitHostPort(r.RemoteAddr)
if err == nil && ip != "" {
return net.JoinHostPort(ip, oport)
}
return r.RemoteAddr
}
// Handler is a middleware to update RemoteAdd from X-Fowarded-* Headers.
func Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr = GetRemoteAddr(r)
h.ServeHTTP(w, r)
})
}
// HandlerFunc is a Martini compatible handler
func HandlerFunc(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr = GetRemoteAddr(r)
}
// XFF is a Negroni compatible interface
func XFF(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
r.RemoteAddr = GetRemoteAddr(r)
next(w, r)
}

View File

@ -0,0 +1,67 @@
package xff
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParse_none(t *testing.T) {
res := Parse("")
assert.Equal(t, "", res)
}
func TestParse_localhost(t *testing.T) {
res := Parse("127.0.0.1")
assert.Equal(t, "", res)
}
func TestParse_invalid(t *testing.T) {
res := Parse("invalid")
assert.Equal(t, "", res)
}
func TestParse_invalid_sioux(t *testing.T) {
res := Parse("123#1#2#3")
assert.Equal(t, "", res)
}
func TestParse_invalid_private_lookalike(t *testing.T) {
res := Parse("102.3.2.1")
assert.Equal(t, "102.3.2.1", res)
}
func TestParse_valid(t *testing.T) {
res := Parse("68.45.152.220")
assert.Equal(t, "68.45.152.220", res)
}
func TestParse_multi_first(t *testing.T) {
res := Parse("12.13.14.15, 68.45.152.220")
assert.Equal(t, "12.13.14.15", res)
}
func TestParse_multi_last(t *testing.T) {
res := Parse("192.168.110.162, 190.57.149.90")
assert.Equal(t, "190.57.149.90", res)
}
func TestParse_multi_with_invalid(t *testing.T) {
res := Parse("192.168.110.162, invalid, 190.57.149.90")
assert.Equal(t, "190.57.149.90", res)
}
func TestParse_multi_with_invalid2(t *testing.T) {
res := Parse("192.168.110.162, 190.57.149.90, invalid")
assert.Equal(t, "190.57.149.90", res)
}
func TestParse_multi_with_invalid_sioux(t *testing.T) {
res := Parse("192.168.110.162, 190.57.149.90, 123#1#2#3")
assert.Equal(t, "190.57.149.90", res)
}
func TestParse_ipv6_with_port(t *testing.T) {
res := Parse("2604:2000:71a9:bf00:f178:a500:9a2d:670d")
assert.Equal(t, "2604:2000:71a9:bf00:f178:a500:9a2d:670d", res)
}

View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2014 Andrea Franz (http://gravityblast.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,10 @@
GO_CMD=go
GOLINT_CMD=golint
GO_TEST=$(GO_CMD) test -v ./...
GO_VET=$(GO_CMD) vet ./...
GO_LINT=$(GOLINT_CMD) ./...
all:
$(GO_VET)
$(GO_LINT)
$(GO_TEST)

View File

@ -0,0 +1,5 @@
# xrequestid
> Package xrequestid implements an http middleware for Negroni that assigns a random id to each request. It's written in the Go programming language.
Docs at http://godoc.org/github.com/pilu/xrequestid

View File

@ -0,0 +1,76 @@
// Package xrequestid implements an http middleware for Negroni that assigns a random id to each request
//
// Example:
// package main
//
// import (
// "fmt"
// "net/http"
//
// "github.com/codegangsta/negroni"
// "github.com/pilu/xrequestid"
// )
//
// func main() {
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "X-Request-Id is `%s`", r.Header.Get("X-Request-Id"))
// })
//
// n := negroni.New()
// n.Use(xrequestid.New(16))
// n.UseHandler(mux)
// n.Run(":3000")
// }
package xrequestid
import (
"crypto/rand"
"encoding/hex"
"net/http"
)
// By default the middleware set the generated random string to this key in the request header
const DefaultHeaderKey = "X-Request-Id"
// GenerateFunc is the func used by the middleware to generates the random string.
type GenerateFunc func(int) (string, error)
// XRequestID is a middleware that adds a random ID to the request X-Request-Id header
type XRequestID struct {
// Size specifies the length of the random length. The length of the result string is twice of n.
Size int
// Generate is a GenerateFunc that generates the random string. The default one uses crypto/rand
Generate GenerateFunc
// HeaderKey is the header name where the middleware set the random string. By default it uses the DefaultHeaderKey constant value
HeaderKey string
}
// New returns a new XRequestID middleware instance. n specifies the length of the random length. The length of the result string is twice of n.
func New(n int) *XRequestID {
return &XRequestID{
Size: n,
Generate: generateID,
HeaderKey: DefaultHeaderKey,
}
}
func (m *XRequestID) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
id, err := m.Generate(m.Size)
if err == nil {
r.Header.Set(m.HeaderKey, id)
rw.Header().Set(m.HeaderKey, id)
}
next(rw, r)
}
func generateID(n int) (string, error) {
r := make([]byte, n)
_, err := rand.Read(r)
if err != nil {
return "", err
}
return hex.EncodeToString(r), nil
}

View File

@ -0,0 +1,20 @@
package xrequestid
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestXRequestID(t *testing.T) {
recorder := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
middleware := New(16)
middleware.Generate = func(n int) (string, error) { return "test-id", nil }
middleware.ServeHTTP(recorder, req, func(w http.ResponseWriter, r *http.Request) {})
if id := req.Header.Get("X-Request-ID"); id != "test-id" {
t.Fatalf("Expected X-Request-Id to be `test-id`, got `%v`", id)
}
}

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Jeremy Saenz
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,181 @@
# Negroni [![GoDoc](https://godoc.org/github.com/codegangsta/negroni?status.svg)](http://godoc.org/github.com/codegangsta/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
Negroni is an idiomatic approach to web middleware in Go. It is tiny, non-intrusive, and encourages use of `net/http` Handlers.
If you like the idea of [Martini](http://github.com/go-martini/martini), but you think it contains too much magic, then Negroni is a great fit.
Language Translations:
* [Português Brasileiro (pt_BR)](translations/README_pt_br.md)
## Getting Started
After installing Go and setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file. We'll call it `server.go`.
~~~ go
package main
import (
"github.com/codegangsta/negroni"
"net/http"
"fmt"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic()
n.UseHandler(mux)
n.Run(":3000")
}
~~~
Then install the Negroni package (**go 1.1** and greater is required):
~~~
go get github.com/codegangsta/negroni
~~~
Then run your server:
~~~
go run server.go
~~~
You will now have a Go net/http webserver running on `localhost:3000`.
## Need Help?
If you have a question or feature request, [go ask the mailing list](https://groups.google.com/forum/#!forum/negroni-users). The GitHub issues for Negroni will be used exclusively for bug reports and pull requests.
## Is Negroni a Framework?
Negroni is **not** a framework. It is a library that is designed to work directly with net/http.
## Routing?
Negroni is BYOR (Bring your own Router). The Go community already has a number of great http routers available, Negroni tries to play well with all of them by fully supporting `net/http`. For instance, integrating with [Gorilla Mux](http://github.com/gorilla/mux) looks like so:
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(Middleware1, Middleware2)
// Or use a middleware with the Use() function
n.Use(Middleware3)
// router goes last
n.UseHandler(router)
n.Run(":3000")
~~~
## `negroni.Classic()`
`negroni.Classic()` provides some default middleware that is useful for most applications:
* `negroni.Recovery` - Panic Recovery Middleware.
* `negroni.Logging` - Request/Response Logging Middleware.
* `negroni.Static` - Static File serving under the "public" directory.
This makes it really easy to get started with some useful features from Negroni.
## Handlers
Negroni provides a bidirectional middleware flow. This is done through the `negroni.Handler` interface:
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
~~~
If a middleware hasn't already written to the ResponseWriter, it should call the next `http.HandlerFunc` in the chain to yield to the next middleware handler. This can be used for great good:
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// do some stuff before
next(rw, r)
// do some stuff after
}
~~~
And you can map it to the handler chain with the `Use` function:
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
~~~
You can also map plain old `http.Handler`s:
~~~ go
n := negroni.New()
mux := http.NewServeMux()
// map your routes
n.UseHandler(mux)
n.Run(":3000")
~~~
## `Run()`
Negroni has a convenience function called `Run`. `Run` takes an addr string identical to [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## Route Specific Middleware
If you have a route group of routes that need specific middleware to be executed, you can simply create a new Negroni instance and use it as your route handler.
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// add admin routes here
// Create a new negroni for the admin middleware
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
~~~
## Third Party Middleware
Here is a current list of Negroni compatible middlware. Feel free to put up a PR linking your middleware if you have built one:
| Middleware | Author | Description |
| -----------|--------|-------------|
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | Secure authentication for REST API endpoints |
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Middleware that implements a few quick security wins |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data binding from HTTP requests into structs |
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Render JSON, XML and HTML templates |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP response compression |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 middleware |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session Management |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, users and permissions |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generate TinySVG, HTML and CSS on the fly |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | Middleware that assigns a random X-Request-Id header to each request |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC authentication middleware |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Store information about your web application (response time, etc.) |
## Examples
[Alexander Rødseth](https://github.com/xyproto) created [mooseware](https://github.com/xyproto/mooseware), a skeleton for writing a Negroni middleware handler.
## Live code reload?
[gin](https://github.com/codegangsta/gin) and [fresh](https://github.com/pilu/fresh) both live reload negroni apps.
## Essential Reading for Beginners of Go & Negroni
* [Using a Context to pass information from middleware to end handler](http://elithrar.github.io/article/map-string-interface/)
* [Understanding middleware](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## About
Negroni is obsessively designed by none other than the [Code Gangsta](http://codegangsta.io/)

View File

@ -0,0 +1,25 @@
// Package negroni is an idiomatic approach to web middleware in Go. It is tiny, non-intrusive, and encourages use of net/http Handlers.
//
// If you like the idea of Martini, but you think it contains too much magic, then Negroni is a great fit.
//
// For a full guide visit http://github.com/codegangsta/negroni
//
// package main
//
// import (
// "github.com/codegangsta/negroni"
// "net/http"
// "fmt"
// )
//
// func main() {
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
// fmt.Fprintf(w, "Welcome to the home page!")
// })
//
// n := negroni.Classic()
// n.UseHandler(mux)
// n.Run(":3000")
// }
package negroni

View File

@ -0,0 +1,29 @@
package negroni
import (
"log"
"net/http"
"os"
"time"
)
// Logger is a middleware handler that logs the request as it goes in and the response as it goes out.
type Logger struct {
// Logger inherits from log.Logger used to log messages with the Logger middleware
*log.Logger
}
// NewLogger returns a new Logger instance
func NewLogger() *Logger {
return &Logger{log.New(os.Stdout, "[negroni] ", 0)}
}
func (l *Logger) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
start := time.Now()
l.Printf("Started %s %s", r.Method, r.URL.Path)
next(rw, r)
res := rw.(ResponseWriter)
l.Printf("Completed %v %s in %v", res.Status(), http.StatusText(res.Status()), time.Since(start))
}

View File

@ -0,0 +1,33 @@
package negroni
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"testing"
)
func Test_Logger(t *testing.T) {
buff := bytes.NewBufferString("")
recorder := httptest.NewRecorder()
l := NewLogger()
l.Logger = log.New(buff, "[negroni] ", 0)
n := New()
// replace log for testing
n.Use(l)
n.UseHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNotFound)
}))
req, err := http.NewRequest("GET", "http://localhost:3000/foobar", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(recorder, req)
expect(t, recorder.Code, http.StatusNotFound)
refute(t, len(buff.String()), 0)
}

View File

@ -0,0 +1,129 @@
package negroni
import (
"log"
"net/http"
"os"
)
// Handler handler is an interface that objects can implement to be registered to serve as middleware
// in the Negroni middleware stack.
// ServeHTTP should yield to the next middleware in the chain by invoking the next http.HandlerFunc
// passed in.
//
// If the Handler writes to the ResponseWriter, the next http.HandlerFunc should not be invoked.
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
// HandlerFunc is an adapter to allow the use of ordinary functions as Negroni handlers.
// If f is a function with the appropriate signature, HandlerFunc(f) is a Handler object that calls f.
type HandlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
func (h HandlerFunc) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
h(rw, r, next)
}
type middleware struct {
handler Handler
next *middleware
}
func (m middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
m.handler.ServeHTTP(rw, r, m.next.ServeHTTP)
}
// Wrap converts a http.Handler into a negroni.Handler so it can be used as a Negroni
// middleware. The next http.HandlerFunc is automatically called after the Handler
// is executed.
func Wrap(handler http.Handler) Handler {
return HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
handler.ServeHTTP(rw, r)
next(rw, r)
})
}
// Negroni is a stack of Middleware Handlers that can be invoked as an http.Handler.
// Negroni middleware is evaluated in the order that they are added to the stack using
// the Use and UseHandler methods.
type Negroni struct {
middleware middleware
handlers []Handler
}
// New returns a new Negroni instance with no middleware preconfigured.
func New(handlers ...Handler) *Negroni {
return &Negroni{
handlers: handlers,
middleware: build(handlers),
}
}
// Classic returns a new Negroni instance with the default middleware already
// in the stack.
//
// Recovery - Panic Recovery Middleware
// Logger - Request/Response Logging
// Static - Static File Serving
func Classic() *Negroni {
return New(NewRecovery(), NewLogger(), NewStatic(http.Dir("public")))
}
func (n *Negroni) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
n.middleware.ServeHTTP(NewResponseWriter(rw), r)
}
// Use adds a Handler onto the middleware stack. Handlers are invoked in the order they are added to a Negroni.
func (n *Negroni) Use(handler Handler) {
n.handlers = append(n.handlers, handler)
n.middleware = build(n.handlers)
}
// UseFunc adds a Negroni-style handler function onto the middleware stack.
func (n *Negroni) UseFunc(handlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)) {
n.Use(HandlerFunc(handlerFunc))
}
// UseHandler adds a http.Handler onto the middleware stack. Handlers are invoked in the order they are added to a Negroni.
func (n *Negroni) UseHandler(handler http.Handler) {
n.Use(Wrap(handler))
}
// UseHandler adds a http.HandlerFunc-style handler function onto the middleware stack.
func (n *Negroni) UseHandlerFunc(handlerFunc func(rw http.ResponseWriter, r *http.Request)) {
n.UseHandler(http.HandlerFunc(handlerFunc))
}
// Run is a convenience function that runs the negroni stack as an HTTP
// server. The addr string takes the same format as http.ListenAndServe.
func (n *Negroni) Run(addr string) {
l := log.New(os.Stdout, "[negroni] ", 0)
l.Printf("listening on %s", addr)
l.Fatal(http.ListenAndServe(addr, n))
}
// Returns a list of all the handlers in the current Negroni middleware chain.
func (n *Negroni) Handlers() []Handler {
return n.handlers
}
func build(handlers []Handler) middleware {
var next middleware
if len(handlers) == 0 {
return voidMiddleware()
} else if len(handlers) > 1 {
next = build(handlers[1:])
} else {
next = voidMiddleware()
}
return middleware{handlers[0], &next}
}
func voidMiddleware() middleware {
return middleware{
HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {}),
&middleware{},
}
}

View File

@ -0,0 +1,75 @@
package negroni
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func refute(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func TestNegroniRun(t *testing.T) {
// just test that Run doesn't bomb
go New().Run(":3000")
}
func TestNegroniServeHTTP(t *testing.T) {
result := ""
response := httptest.NewRecorder()
n := New()
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
result += "foo"
next(rw, r)
result += "ban"
}))
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
result += "bar"
next(rw, r)
result += "baz"
}))
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
result += "bat"
rw.WriteHeader(http.StatusBadRequest)
}))
n.ServeHTTP(response, (*http.Request)(nil))
expect(t, result, "foobarbatbazban")
expect(t, response.Code, http.StatusBadRequest)
}
// Ensures that a Negroni middleware chain
// can correctly return all of its handlers.
func TestHandlers(t *testing.T) {
response := httptest.NewRecorder()
n := New()
handlers := n.Handlers()
expect(t, 0, len(handlers))
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
rw.WriteHeader(http.StatusOK)
}))
// Expects the length of handlers to be exactly 1
// after adding exactly one handler to the middleware chain
handlers = n.Handlers()
expect(t, 1, len(handlers))
// Ensures that the first handler that is in sequence behaves
// exactly the same as the one that was registered earlier
handlers[0].ServeHTTP(response, (*http.Request)(nil), nil)
expect(t, response.Code, http.StatusOK)
}

View File

@ -0,0 +1,46 @@
package negroni
import (
"fmt"
"log"
"net/http"
"os"
"runtime"
)
// Recovery is a Negroni middleware that recovers from any panics and writes a 500 if there was one.
type Recovery struct {
Logger *log.Logger
PrintStack bool
StackAll bool
StackSize int
}
// NewRecovery returns a new instance of Recovery
func NewRecovery() *Recovery {
return &Recovery{
Logger: log.New(os.Stdout, "[negroni] ", 0),
PrintStack: true,
StackAll: false,
StackSize: 1024 * 8,
}
}
func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
defer func() {
if err := recover(); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
stack := make([]byte, rec.StackSize)
stack = stack[:runtime.Stack(stack, rec.StackAll)]
f := "PANIC: %s\n%s"
rec.Logger.Printf(f, err, stack)
if rec.PrintStack {
fmt.Fprintf(rw, f, err, stack)
}
}
}()
next(rw, r)
}

View File

@ -0,0 +1,28 @@
package negroni
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"testing"
)
func TestRecovery(t *testing.T) {
buff := bytes.NewBufferString("")
recorder := httptest.NewRecorder()
rec := NewRecovery()
rec.Logger = log.New(buff, "[negroni] ", 0)
n := New()
// replace log for testing
n.Use(rec)
n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
panic("here is a panic!")
}))
n.ServeHTTP(recorder, (*http.Request)(nil))
expect(t, recorder.Code, http.StatusInternalServerError)
refute(t, recorder.Body.Len(), 0)
refute(t, len(buff.String()), 0)
}

View File

@ -0,0 +1,96 @@
package negroni
import (
"bufio"
"fmt"
"net"
"net/http"
)
// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about
// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter
// if the functionality calls for it.
type ResponseWriter interface {
http.ResponseWriter
http.Flusher
// Status returns the status code of the response or 0 if the response has not been written.
Status() int
// Written returns whether or not the ResponseWriter has been written.
Written() bool
// Size returns the size of the response body.
Size() int
// Before allows for a function to be called before the ResponseWriter has been written to. This is
// useful for setting headers or any other operations that must happen before a response has been written.
Before(func(ResponseWriter))
}
type beforeFunc func(ResponseWriter)
// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {
return &responseWriter{rw, 0, 0, nil}
}
type responseWriter struct {
http.ResponseWriter
status int
size int
beforeFuncs []beforeFunc
}
func (rw *responseWriter) WriteHeader(s int) {
rw.status = s
rw.callBefore()
rw.ResponseWriter.WriteHeader(s)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
if !rw.Written() {
// The status will be StatusOK if WriteHeader has not been called yet
rw.WriteHeader(http.StatusOK)
}
size, err := rw.ResponseWriter.Write(b)
rw.size += size
return size, err
}
func (rw *responseWriter) Status() int {
return rw.status
}
func (rw *responseWriter) Size() int {
return rw.size
}
func (rw *responseWriter) Written() bool {
return rw.status != 0
}
func (rw *responseWriter) Before(before func(ResponseWriter)) {
rw.beforeFuncs = append(rw.beforeFuncs, before)
}
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
func (rw *responseWriter) CloseNotify() <-chan bool {
return rw.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
func (rw *responseWriter) callBefore() {
for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
rw.beforeFuncs[i](rw)
}
}
func (rw *responseWriter) Flush() {
flusher, ok := rw.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}

View File

@ -0,0 +1,150 @@
package negroni
import (
"bufio"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
type closeNotifyingRecorder struct {
*httptest.ResponseRecorder
closed chan bool
}
func newCloseNotifyingRecorder() *closeNotifyingRecorder {
return &closeNotifyingRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}
func (c *closeNotifyingRecorder) close() {
c.closed <- true
}
func (c *closeNotifyingRecorder) CloseNotify() <-chan bool {
return c.closed
}
type hijackableResponse struct {
Hijacked bool
}
func newHijackableResponse() *hijackableResponse {
return &hijackableResponse{}
}
func (h *hijackableResponse) Header() http.Header { return nil }
func (h *hijackableResponse) Write(buf []byte) (int, error) { return 0, nil }
func (h *hijackableResponse) WriteHeader(code int) {}
func (h *hijackableResponse) Flush() {}
func (h *hijackableResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.Hijacked = true
return nil, nil, nil
}
func TestResponseWriterWritingString(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
rw.Write([]byte("Hello world"))
expect(t, rec.Code, rw.Status())
expect(t, rec.Body.String(), "Hello world")
expect(t, rw.Status(), http.StatusOK)
expect(t, rw.Size(), 11)
expect(t, rw.Written(), true)
}
func TestResponseWriterWritingStrings(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
rw.Write([]byte("Hello world"))
rw.Write([]byte("foo bar bat baz"))
expect(t, rec.Code, rw.Status())
expect(t, rec.Body.String(), "Hello worldfoo bar bat baz")
expect(t, rw.Status(), http.StatusOK)
expect(t, rw.Size(), 26)
}
func TestResponseWriterWritingHeader(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
rw.WriteHeader(http.StatusNotFound)
expect(t, rec.Code, rw.Status())
expect(t, rec.Body.String(), "")
expect(t, rw.Status(), http.StatusNotFound)
expect(t, rw.Size(), 0)
}
func TestResponseWriterBefore(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
result := ""
rw.Before(func(ResponseWriter) {
result += "foo"
})
rw.Before(func(ResponseWriter) {
result += "bar"
})
rw.WriteHeader(http.StatusNotFound)
expect(t, rec.Code, rw.Status())
expect(t, rec.Body.String(), "")
expect(t, rw.Status(), http.StatusNotFound)
expect(t, rw.Size(), 0)
expect(t, result, "barfoo")
}
func TestResponseWriterHijack(t *testing.T) {
hijackable := newHijackableResponse()
rw := NewResponseWriter(hijackable)
hijacker, ok := rw.(http.Hijacker)
expect(t, ok, true)
_, _, err := hijacker.Hijack()
if err != nil {
t.Error(err)
}
expect(t, hijackable.Hijacked, true)
}
func TestResponseWriteHijackNotOK(t *testing.T) {
hijackable := new(http.ResponseWriter)
rw := NewResponseWriter(*hijackable)
hijacker, ok := rw.(http.Hijacker)
expect(t, ok, true)
_, _, err := hijacker.Hijack()
refute(t, err, nil)
}
func TestResponseWriterCloseNotify(t *testing.T) {
rec := newCloseNotifyingRecorder()
rw := NewResponseWriter(rec)
closed := false
notifier := rw.(http.CloseNotifier).CloseNotify()
rec.close()
select {
case <-notifier:
closed = true
case <-time.After(time.Second):
}
expect(t, closed, true)
}
func TestResponseWriterFlusher(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
_, ok := rw.(http.Flusher)
expect(t, ok, true)
}

View File

@ -0,0 +1,84 @@
package negroni
import (
"net/http"
"path"
"strings"
)
// Static is a middleware handler that serves static files in the given directory/filesystem.
type Static struct {
// Dir is the directory to serve static files from
Dir http.FileSystem
// Prefix is the optional prefix used to serve the static directory content
Prefix string
// IndexFile defines which file to serve as index if it exists.
IndexFile string
}
// NewStatic returns a new instance of Static
func NewStatic(directory http.FileSystem) *Static {
return &Static{
Dir: directory,
Prefix: "",
IndexFile: "index.html",
}
}
func (s *Static) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if r.Method != "GET" && r.Method != "HEAD" {
next(rw, r)
return
}
file := r.URL.Path
// if we have a prefix, filter requests by stripping the prefix
if s.Prefix != "" {
if !strings.HasPrefix(file, s.Prefix) {
next(rw, r)
return
}
file = file[len(s.Prefix):]
if file != "" && file[0] != '/' {
next(rw, r)
return
}
}
f, err := s.Dir.Open(file)
if err != nil {
// discard the error?
next(rw, r)
return
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
next(rw, r)
return
}
// try to serve index file
if fi.IsDir() {
// redirect if missing trailing slash
if !strings.HasSuffix(r.URL.Path, "/") {
http.Redirect(rw, r, r.URL.Path+"/", http.StatusFound)
return
}
file = path.Join(file, s.IndexFile)
f, err = s.Dir.Open(file)
if err != nil {
next(rw, r)
return
}
defer f.Close()
fi, err = f.Stat()
if err != nil || fi.IsDir() {
next(rw, r)
return
}
}
http.ServeContent(rw, r, file, fi.ModTime(), f)
}

View File

@ -0,0 +1,113 @@
package negroni
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
func TestStatic(t *testing.T) {
response := httptest.NewRecorder()
response.Body = new(bytes.Buffer)
n := New()
n.Use(NewStatic(http.Dir(".")))
req, err := http.NewRequest("GET", "http://localhost:3000/negroni.go", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
expect(t, response.Code, http.StatusOK)
expect(t, response.Header().Get("Expires"), "")
if response.Body.Len() == 0 {
t.Errorf("Got empty body for GET request")
}
}
func TestStaticHead(t *testing.T) {
response := httptest.NewRecorder()
response.Body = new(bytes.Buffer)
n := New()
n.Use(NewStatic(http.Dir(".")))
n.UseHandler(http.NotFoundHandler())
req, err := http.NewRequest("HEAD", "http://localhost:3000/negroni.go", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
expect(t, response.Code, http.StatusOK)
if response.Body.Len() != 0 {
t.Errorf("Got non-empty body for HEAD request")
}
}
func TestStaticAsPost(t *testing.T) {
response := httptest.NewRecorder()
n := New()
n.Use(NewStatic(http.Dir(".")))
n.UseHandler(http.NotFoundHandler())
req, err := http.NewRequest("POST", "http://localhost:3000/negroni.go", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
expect(t, response.Code, http.StatusNotFound)
}
func TestStaticBadDir(t *testing.T) {
response := httptest.NewRecorder()
n := Classic()
n.UseHandler(http.NotFoundHandler())
req, err := http.NewRequest("GET", "http://localhost:3000/negroni.go", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
refute(t, response.Code, http.StatusOK)
}
func TestStaticOptionsServeIndex(t *testing.T) {
response := httptest.NewRecorder()
n := New()
s := NewStatic(http.Dir("."))
s.IndexFile = "negroni.go"
n.Use(s)
req, err := http.NewRequest("GET", "http://localhost:3000/", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
expect(t, response.Code, http.StatusOK)
}
func TestStaticOptionsPrefix(t *testing.T) {
response := httptest.NewRecorder()
n := New()
s := NewStatic(http.Dir("."))
s.Prefix = "/public"
n.Use(s)
// Check file content behaviour
req, err := http.NewRequest("GET", "http://localhost:3000/public/negroni.go", nil)
if err != nil {
t.Error(err)
}
n.ServeHTTP(response, req)
expect(t, response.Code, http.StatusOK)
}

View File

@ -0,0 +1,170 @@
# Negroni [![GoDoc](https://godoc.org/github.com/codegangsta/negroni?status.svg)](http://godoc.org/github.com/codegangsta/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
Negroni é uma abordagem idiomática para middleware web em Go. É pequeno, não intrusivo, e incentiva uso da biblioteca `net/http`.
Se gosta da idéia do [Martini](http://github.com/go-martini/martini), mas acha que contém muita mágica, então Negroni é ideal.
## Começando
Depois de instalar Go e definir seu [GOPATH](http://golang.org/doc/code.html#GOPATH), criar seu primeirto arquivo `.go`. Iremos chamá-lo `server.go`.
~~~ go
package main
import (
"github.com/codegangsta/negroni"
"net/http"
"fmt"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic()
n.UseHandler(mux)
n.Run(":3000")
}
~~~
Depois instale o pacote Negroni (**go 1.1** ou superior)
~~~
go get github.com/codegangsta/negroni
~~~
Depois execute seu servidor:
~~~
go run server.go
~~~
Agora terá um servidor web Go net/http rodando em `localhost:3000`.
## Precisa de Ajuda?
Se você tem uma pergunta ou pedido de recurso,[go ask the mailing list](https://groups.google.com/forum/#!forum/negroni-users). O Github issues para o Negroni será usado exclusivamente para Reportar bugs e pull requests.
## Negroni é um Framework?
Negroni **não** é a framework. É uma biblioteca que é desenhada para trabalhar diretamente com net/http.
## Roteamento?
Negroni é TSPR(Traga seu próprio Roteamento). A comunidade Go já tem um grande número de roteadores http disponíveis, Negroni tenta rodar bem com todos eles pelo suporte total `net/http`/ Por exemplo, a integração com [Gorilla Mux](http://github.com/gorilla/mux) se parece com isso:
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(Middleware1, Middleware2)
// Or use a middleware with the Use() function
n.Use(Middleware3)
// router goes last
n.UseHandler(router)
n.Run(":3000")
~~~
## `negroni.Classic()`
`negroni.Classic()` fornece alguns middlewares padrão que são úteis para maioria das aplicações:
* `negroni.Recovery` - Panic Recovery Middleware.
* `negroni.Logging` - Request/Response Logging Middleware.
* `negroni.Static` - Static File serving under the "public" directory.
Isso torna muito fácil começar com alguns recursos úteis do Negroni.
## Handlers
Negroni fornece um middleware de fluxo bidirecional. Isso é feito através da interface `negroni.Handler`:
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
~~~
Se um middleware não tenha escrito o ResponseWriter, ele deve chamar a próxima `http.HandlerFunc` na cadeia para produzir o próximo handler middleware. Isso pode ser usado muito bem:
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// do some stuff before
next(rw, r)
// do some stuff after
}
~~~
E pode mapear isso para a cadeia de handler com a função `Use`:
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
~~~
Você também pode mapear `http.Handler` antigos:
~~~ go
n := negroni.New()
mux := http.NewServeMux()
// map your routes
n.UseHandler(mux)
n.Run(":3000")
~~~
## `Run()`
Negroni tem uma função de conveniência chamada `Run`. `Run` pega um endereço de string idêntico para [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## Middleware para Rotas Específicas
Se você tem um grupo de rota com rotas que precisam ser executadas por um middleware específico, pode simplesmente criar uma nova instância de Negroni e usar no seu Manipulador de rota.
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// add admin routes here
// Criar um middleware negroni para admin
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
~~~
## Middleware de Terceiros
Aqui está uma lista atual de Middleware Compatíveis com Negroni. Sinta se livre para mandar um PR vinculando seu middleware se construiu um:
| Middleware | Autor | Descrição |
| -----------|--------|-------------|
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Implementa rapidamente itens de segurança.|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Handler para mapeamento/validação de um request a estrutura. |
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Pacote para renderizar JSON, XML, e templates HTML. |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | Handler para adicionar compreção gzip para as requisições |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | Handler que prove sistema de login OAuth 2.0 para aplicações Martini. Google Sign-in, Facebook Connect e Github login são suportados. |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Handler que provê o serviço de sessão. |
| [permissions](https://github.com/xyproto/permissions) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, usuários e permissões. |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Pacote para gerar TinySVG, HTML e CSS em tempo real. |
## Exemplos
[Alexander Rødseth](https://github.com/xyproto) criou [mooseware](https://github.com/xyproto/mooseware), uma estrutura para escrever um handler middleware Negroni.
## Servidor com autoreload?
[gin](https://github.com/codegangsta/gin) e [fresh](https://github.com/pilu/fresh) são aplicativos para autoreload do Negroni.
## Leitura Essencial para Iniciantes em Go & Negroni
* [Usando um contexto para passar informação de um middleware para o manipulador final](http://elithrar.github.io/article/map-string-interface/)
* [Entendendo middleware](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## Sobre
Negroni é obsessivamente desenhado por ninguém menos que [Code Gangsta](http://codegangsta.io/)

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2012-2014 Grigory Dryapak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,163 @@
# Imaging
Package imaging provides basic image manipulation functions (resize, rotate, flip, crop, etc.).
This package is based on the standard Go image package and works best along with it.
Image manipulation functions provided by the package take any image type
that implements `image.Image` interface as an input, and return a new image of
`*image.NRGBA` type (32bit RGBA colors, not premultiplied by alpha).
## Installation
Imaging requires Go version 1.2 or greater.
go get -u github.com/disintegration/imaging
## Documentation
http://godoc.org/github.com/disintegration/imaging
## Usage examples
A few usage examples can be found below. See the documentation for the full list of supported functions.
### Image resizing
```go
// resize srcImage to size = 128x128px using the Lanczos filter
dstImage128 := imaging.Resize(srcImage, 128, 128, imaging.Lanczos)
// resize srcImage to width = 800px preserving the aspect ratio
dstImage800 := imaging.Resize(srcImage, 800, 0, imaging.Lanczos)
// scale down srcImage to fit the 800x600px bounding box
dstImageFit := imaging.Fit(srcImage, 800, 600, imaging.Lanczos)
// resize and crop the srcImage to make a 100x100px thumbnail
dstImageThumb := imaging.Thumbnail(srcImage, 100, 100, imaging.Lanczos)
```
Imaging supports image resizing using various resampling filters. The most notable ones:
- `NearestNeighbor` - Fastest resampling filter, no antialiasing.
- `Box` - Simple and fast averaging filter appropriate for downscaling. When upscaling it's similar to NearestNeighbor.
- `Linear` - Bilinear filter, smooth and reasonably fast.
- `MitchellNetravali` - А smooth bicubic filter.
- `CatmullRom` - A sharp bicubic filter.
- `Gaussian` - Blurring filter that uses gaussian function, useful for noise removal.
- `Lanczos` - High-quality resampling filter for photographic images yielding sharp results, but it's slower than cubic filters.
The full list of supported filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali, CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine. Custom filters can be created using ResampleFilter struct.
**Resampling filters comparison**
Original image. Will be resized from 512x512px to 128x128px.
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_512.png)
Filter | Resize result
---|---
`imaging.NearestNeighbor` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_nearest.png)
`imaging.Box` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_box.png)
`imaging.Linear` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_linear.png)
`imaging.MitchellNetravali` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_mitchell.png)
`imaging.CatmullRom` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_catrom.png)
`imaging.Gaussian` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_gaussian.png)
`imaging.Lanczos` | ![dstImage](http://disintegration.github.io/imaging/out_resize_down_lanczos.png)
### Gaussian Blur
```go
dstImage := imaging.Blur(srcImage, 0.5)
```
Sigma parameter allows to control the strength of the blurring effect.
Original image | Sigma = 0.5 | Sigma = 1.5
---|---|---
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_128.png) | ![dstImage](http://disintegration.github.io/imaging/out_blur_0.5.png) | ![dstImage](http://disintegration.github.io/imaging/out_blur_1.5.png)
### Sharpening
```go
dstImage := imaging.Sharpen(srcImage, 0.5)
```
Uses gaussian function internally. Sigma parameter allows to control the strength of the sharpening effect.
Original image | Sigma = 0.5 | Sigma = 1.5
---|---|---
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_128.png) | ![dstImage](http://disintegration.github.io/imaging/out_sharpen_0.5.png) | ![dstImage](http://disintegration.github.io/imaging/out_sharpen_1.5.png)
### Gamma correction
```go
dstImage := imaging.AdjustGamma(srcImage, 0.75)
```
Original image | Gamma = 0.75 | Gamma = 1.25
---|---|---
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_128.png) | ![dstImage](http://disintegration.github.io/imaging/out_gamma_0.75.png) | ![dstImage](http://disintegration.github.io/imaging/out_gamma_1.25.png)
### Contrast adjustment
```go
dstImage := imaging.AdjustContrast(srcImage, 20)
```
Original image | Contrast = 20 | Contrast = -20
---|---|---
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_128.png) | ![dstImage](http://disintegration.github.io/imaging/out_contrast_p20.png) | ![dstImage](http://disintegration.github.io/imaging/out_contrast_m20.png)
### Brightness adjustment
```go
dstImage := imaging.AdjustBrightness(srcImage, 20)
```
Original image | Brightness = 20 | Brightness = -20
---|---|---
![srcImage](http://disintegration.github.io/imaging/in_lena_bw_128.png) | ![dstImage](http://disintegration.github.io/imaging/out_brightness_p20.png) | ![dstImage](http://disintegration.github.io/imaging/out_brightness_m20.png)
### Complete code example
Here is the code example that loads several images, makes thumbnails of them
and combines them together side-by-side.
```go
package main
import (
"image"
"image/color"
"runtime"
"github.com/disintegration/imaging"
)
func main() {
// use all CPU cores for maximum performance
runtime.GOMAXPROCS(runtime.NumCPU())
// input files
files := []string{"01.jpg", "02.jpg", "03.jpg"}
// load images and make 100x100 thumbnails of them
var thumbnails []image.Image
for _, file := range files {
img, err := imaging.Open(file)
if err != nil {
panic(err)
}
thumb := imaging.Thumbnail(img, 100, 100, imaging.CatmullRom)
thumbnails = append(thumbnails, thumb)
}
// create a new blank image
dst := imaging.New(100*len(thumbnails), 100, color.NRGBA{0, 0, 0, 0})
// paste thumbnails into the new image side by side
for i, thumb := range thumbnails {
dst = imaging.Paste(dst, thumb, image.Pt(i*100, 0))
}
// save the combined image to file
err := imaging.Save(dst, "dst.jpg")
if err != nil {
panic(err)
}
}
```

View File

@ -0,0 +1,200 @@
package imaging
import (
"image"
"image/color"
"math"
)
// AdjustFunc applies the fn function to each pixel of the img image and returns the adjusted image.
//
// Example:
//
// dstImage = imaging.AdjustFunc(
// srcImage,
// func(c color.NRGBA) color.NRGBA {
// // shift the red channel by 16
// r := int(c.R) + 16
// if r > 255 {
// r = 255
// }
// return color.NRGBA{uint8(r), c.G, c.B, c.A}
// }
// )
//
func AdjustFunc(img image.Image, fn func(c color.NRGBA) color.NRGBA) *image.NRGBA {
src := toNRGBA(img)
width := src.Bounds().Max.X
height := src.Bounds().Max.Y
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
parallel(height, func(partStart, partEnd int) {
for y := partStart; y < partEnd; y++ {
for x := 0; x < width; x++ {
i := y*src.Stride + x*4
j := y*dst.Stride + x*4
r := src.Pix[i+0]
g := src.Pix[i+1]
b := src.Pix[i+2]
a := src.Pix[i+3]
c := fn(color.NRGBA{r, g, b, a})
dst.Pix[j+0] = c.R
dst.Pix[j+1] = c.G
dst.Pix[j+2] = c.B
dst.Pix[j+3] = c.A
}
}
})
return dst
}
// AdjustGamma performs a gamma correction on the image and returns the adjusted image.
// Gamma parameter must be positive. Gamma = 1.0 gives the original image.
// Gamma less than 1.0 darkens the image and gamma greater than 1.0 lightens it.
//
// Example:
//
// dstImage = imaging.AdjustGamma(srcImage, 0.7)
//
func AdjustGamma(img image.Image, gamma float64) *image.NRGBA {
e := 1.0 / math.Max(gamma, 0.0001)
lut := make([]uint8, 256)
for i := 0; i < 256; i++ {
lut[i] = clamp(math.Pow(float64(i)/255.0, e) * 255.0)
}
fn := func(c color.NRGBA) color.NRGBA {
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
}
return AdjustFunc(img, fn)
}
func sigmoid(a, b, x float64) float64 {
return 1 / (1 + math.Exp(b*(a-x)))
}
// AdjustSigmoid changes the contrast of the image using a sigmoidal function and returns the adjusted image.
// It's a non-linear contrast change useful for photo adjustments as it preserves highlight and shadow detail.
// The midpoint parameter is the midpoint of contrast that must be between 0 and 1, typically 0.5.
// The factor parameter indicates how much to increase or decrease the contrast, typically in range (-10, 10).
// If the factor parameter is positive the image contrast is increased otherwise the contrast is decreased.
//
// Examples:
//
// dstImage = imaging.AdjustSigmoid(srcImage, 0.5, 3.0) // increase the contrast
// dstImage = imaging.AdjustSigmoid(srcImage, 0.5, -3.0) // decrease the contrast
//
func AdjustSigmoid(img image.Image, midpoint, factor float64) *image.NRGBA {
if factor == 0 {
return Clone(img)
}
lut := make([]uint8, 256)
a := math.Min(math.Max(midpoint, 0.0), 1.0)
b := math.Abs(factor)
sig0 := sigmoid(a, b, 0)
sig1 := sigmoid(a, b, 1)
e := 1.0e-6
if factor > 0 {
for i := 0; i < 256; i++ {
x := float64(i) / 255.0
sigX := sigmoid(a, b, x)
f := (sigX - sig0) / (sig1 - sig0)
lut[i] = clamp(f * 255.0)
}
} else {
for i := 0; i < 256; i++ {
x := float64(i) / 255.0
arg := math.Min(math.Max((sig1-sig0)*x+sig0, e), 1.0-e)
f := a - math.Log(1.0/arg-1.0)/b
lut[i] = clamp(f * 255.0)
}
}
fn := func(c color.NRGBA) color.NRGBA {
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
}
return AdjustFunc(img, fn)
}
// AdjustContrast changes the contrast of the image using the percentage parameter and returns the adjusted image.
// The percentage must be in range (-100, 100). The percentage = 0 gives the original image.
// The percentage = -100 gives solid grey image.
//
// Examples:
//
// dstImage = imaging.AdjustContrast(srcImage, -10) // decrease image contrast by 10%
// dstImage = imaging.AdjustContrast(srcImage, 20) // increase image contrast by 20%
//
func AdjustContrast(img image.Image, percentage float64) *image.NRGBA {
percentage = math.Min(math.Max(percentage, -100.0), 100.0)
lut := make([]uint8, 256)
v := (100.0 + percentage) / 100.0
for i := 0; i < 256; i++ {
if 0 <= v && v <= 1 {
lut[i] = clamp((0.5 + (float64(i)/255.0-0.5)*v) * 255.0)
} else if 1 < v && v < 2 {
lut[i] = clamp((0.5 + (float64(i)/255.0-0.5)*(1/(2.0-v))) * 255.0)
} else {
lut[i] = uint8(float64(i)/255.0+0.5) * 255
}
}
fn := func(c color.NRGBA) color.NRGBA {
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
}
return AdjustFunc(img, fn)
}
// AdjustBrightness changes the brightness of the image using the percentage parameter and returns the adjusted image.
// The percentage must be in range (-100, 100). The percentage = 0 gives the original image.
// The percentage = -100 gives solid black image. The percentage = 100 gives solid white image.
//
// Examples:
//
// dstImage = imaging.AdjustBrightness(srcImage, -15) // decrease image brightness by 15%
// dstImage = imaging.AdjustBrightness(srcImage, 10) // increase image brightness by 10%
//
func AdjustBrightness(img image.Image, percentage float64) *image.NRGBA {
percentage = math.Min(math.Max(percentage, -100.0), 100.0)
lut := make([]uint8, 256)
shift := 255.0 * percentage / 100.0
for i := 0; i < 256; i++ {
lut[i] = clamp(float64(i) + shift)
}
fn := func(c color.NRGBA) color.NRGBA {
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
}
return AdjustFunc(img, fn)
}
// Grayscale produces grayscale version of the image.
func Grayscale(img image.Image) *image.NRGBA {
fn := func(c color.NRGBA) color.NRGBA {
f := 0.299*float64(c.R) + 0.587*float64(c.G) + 0.114*float64(c.B)
y := uint8(f + 0.5)
return color.NRGBA{y, y, y, c.A}
}
return AdjustFunc(img, fn)
}
// Invert produces inverted (negated) version of the image.
func Invert(img image.Image) *image.NRGBA {
fn := func(c color.NRGBA) color.NRGBA {
return color.NRGBA{255 - c.R, 255 - c.G, 255 - c.B, c.A}
}
return AdjustFunc(img, fn)
}

View File

@ -0,0 +1,504 @@
package imaging
import (
"image"
"testing"
)
func TestGrayscale(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Grayscale 3x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x3d, 0x3d, 0x3d, 0x01, 0x78, 0x78, 0x78, 0x02, 0x17, 0x17, 0x17, 0x03,
0x1f, 0x1f, 0x1f, 0xff, 0x25, 0x25, 0x25, 0xff, 0x66, 0x66, 0x66, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := Grayscale(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestInvert(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Invert 3x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x33, 0xff, 0xff, 0x01, 0xff, 0x33, 0xff, 0x02, 0xff, 0xff, 0x33, 0x03,
0xee, 0xdd, 0xcc, 0xff, 0xcc, 0xdd, 0xee, 0xff, 0x55, 0xcc, 0x44, 0xff,
0xff, 0xff, 0xff, 0xff, 0xcc, 0xcc, 0xcc, 0xff, 0x00, 0x00, 0x00, 0xff,
},
},
},
}
for _, d := range td {
got := Invert(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestAdjustContrast(t *testing.T) {
td := []struct {
desc string
src image.Image
p float64
want *image.NRGBA
}{
{
"AdjustContrast 3x3 10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
10,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xd5, 0x00, 0x00, 0x01, 0x00, 0xd5, 0x00, 0x02, 0x00, 0x00, 0xd5, 0x03,
0x05, 0x18, 0x2b, 0xff, 0x2b, 0x18, 0x05, 0xff, 0xaf, 0x2b, 0xc2, 0xff,
0x00, 0x00, 0x00, 0xff, 0x2b, 0x2b, 0x2b, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustContrast 3x3 100",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
100,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0x01, 0x00, 0xff, 0x00, 0x02, 0x00, 0x00, 0xff, 0x03,
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0xff, 0xff,
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustContrast 3x3 -10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
-10,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xc4, 0x0d, 0x0d, 0x01, 0x0d, 0xc4, 0x0d, 0x02, 0x0d, 0x0d, 0xc4, 0x03,
0x1c, 0x2b, 0x3b, 0xff, 0x3b, 0x2b, 0x1c, 0xff, 0xa6, 0x3b, 0xb5, 0xff,
0x0d, 0x0d, 0x0d, 0xff, 0x3b, 0x3b, 0x3b, 0xff, 0xf2, 0xf2, 0xf2, 0xff,
},
},
},
{
"AdjustContrast 3x3 -100",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
-100,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x80, 0x80, 0x80, 0x01, 0x80, 0x80, 0x80, 0x02, 0x80, 0x80, 0x80, 0x03,
0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff,
0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff,
},
},
},
{
"AdjustContrast 3x3 0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := AdjustContrast(d.src, d.p)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestAdjustBrightness(t *testing.T) {
td := []struct {
desc string
src image.Image
p float64
want *image.NRGBA
}{
{
"AdjustBrightness 3x3 10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
10,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xe6, 0x1a, 0x1a, 0x01, 0x1a, 0xe6, 0x1a, 0x02, 0x1a, 0x1a, 0xe6, 0x03,
0x2b, 0x3c, 0x4d, 0xff, 0x4d, 0x3c, 0x2b, 0xff, 0xc4, 0x4d, 0xd5, 0xff,
0x1a, 0x1a, 0x1a, 0xff, 0x4d, 0x4d, 0x4d, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustBrightness 3x3 100",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
100,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xff, 0xff, 0xff, 0x01, 0xff, 0xff, 0xff, 0x02, 0xff, 0xff, 0xff, 0x03,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustBrightness 3x3 -10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
-10,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xb3, 0x00, 0x00, 0x01, 0x00, 0xb3, 0x00, 0x02, 0x00, 0x00, 0xb3, 0x03,
0x00, 0x09, 0x1a, 0xff, 0x1a, 0x09, 0x00, 0xff, 0x91, 0x1a, 0xa2, 0xff,
0x00, 0x00, 0x00, 0xff, 0x1a, 0x1a, 0x1a, 0xff, 0xe6, 0xe6, 0xe6, 0xff,
},
},
},
{
"AdjustBrightness 3x3 -100",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
-100,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03,
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff,
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff,
},
},
},
{
"AdjustBrightness 3x3 0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := AdjustBrightness(d.src, d.p)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestAdjustGamma(t *testing.T) {
td := []struct {
desc string
src image.Image
p float64
want *image.NRGBA
}{
{
"AdjustGamma 3x3 0.75",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0.75,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xbd, 0x00, 0x00, 0x01, 0x00, 0xbd, 0x00, 0x02, 0x00, 0x00, 0xbd, 0x03,
0x07, 0x11, 0x1e, 0xff, 0x1e, 0x11, 0x07, 0xff, 0x95, 0x1e, 0xa9, 0xff,
0x00, 0x00, 0x00, 0xff, 0x1e, 0x1e, 0x1e, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustGamma 3x3 1.5",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
1.5,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xdc, 0x00, 0x00, 0x01, 0x00, 0xdc, 0x00, 0x02, 0x00, 0x00, 0xdc, 0x03,
0x2a, 0x43, 0x57, 0xff, 0x57, 0x43, 0x2a, 0xff, 0xc3, 0x57, 0xcf, 0xff,
0x00, 0x00, 0x00, 0xff, 0x57, 0x57, 0x57, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustGamma 3x3 1.0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
1.0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := AdjustGamma(d.src, d.p)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestAdjustSigmoid(t *testing.T) {
td := []struct {
desc string
src image.Image
m float64
p float64
want *image.NRGBA
}{
{
"AdjustSigmoid 3x3 0.5 3.0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0.5,
3.0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xd4, 0x00, 0x00, 0x01, 0x00, 0xd4, 0x00, 0x02, 0x00, 0x00, 0xd4, 0x03,
0x0d, 0x1b, 0x2b, 0xff, 0x2b, 0x1b, 0x0d, 0xff, 0xb1, 0x2b, 0xc3, 0xff,
0x00, 0x00, 0x00, 0xff, 0x2b, 0x2b, 0x2b, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustSigmoid 3x3 0.5 -3.0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0.5,
-3.0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xc4, 0x00, 0x00, 0x01, 0x00, 0xc4, 0x00, 0x02, 0x00, 0x00, 0xc4, 0x03,
0x16, 0x2a, 0x3b, 0xff, 0x3b, 0x2a, 0x16, 0xff, 0xa4, 0x3b, 0xb3, 0xff,
0x00, 0x00, 0x00, 0xff, 0x3b, 0x3b, 0x3b, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
{
"AdjustSigmoid 3x3 0.5 0.0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
0.5,
0.0,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := AdjustSigmoid(d.src, d.m, d.p)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,187 @@
package imaging
import (
"image"
"math"
)
func gaussianBlurKernel(x, sigma float64) float64 {
return math.Exp(-(x*x)/(2*sigma*sigma)) / (sigma * math.Sqrt(2*math.Pi))
}
// Blur produces a blurred version of the image using a Gaussian function.
// Sigma parameter must be positive and indicates how much the image will be blurred.
//
// Usage example:
//
// dstImage := imaging.Blur(srcImage, 3.5)
//
func Blur(img image.Image, sigma float64) *image.NRGBA {
if sigma <= 0 {
// sigma parameter must be positive!
return Clone(img)
}
src := toNRGBA(img)
radius := int(math.Ceil(sigma * 3.0))
kernel := make([]float64, radius+1)
for i := 0; i <= radius; i++ {
kernel[i] = gaussianBlurKernel(float64(i), sigma)
}
var dst *image.NRGBA
dst = blurHorizontal(src, kernel)
dst = blurVertical(dst, kernel)
return dst
}
func blurHorizontal(src *image.NRGBA, kernel []float64) *image.NRGBA {
radius := len(kernel) - 1
width := src.Bounds().Max.X
height := src.Bounds().Max.Y
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
parallel(width, func(partStart, partEnd int) {
for x := partStart; x < partEnd; x++ {
start := x - radius
if start < 0 {
start = 0
}
end := x + radius
if end > width-1 {
end = width - 1
}
weightSum := 0.0
for ix := start; ix <= end; ix++ {
weightSum += kernel[absint(x-ix)]
}
for y := 0; y < height; y++ {
r, g, b, a := 0.0, 0.0, 0.0, 0.0
for ix := start; ix <= end; ix++ {
weight := kernel[absint(x-ix)]
i := y*src.Stride + ix*4
r += float64(src.Pix[i+0]) * weight
g += float64(src.Pix[i+1]) * weight
b += float64(src.Pix[i+2]) * weight
a += float64(src.Pix[i+3]) * weight
}
r = math.Min(math.Max(r/weightSum, 0.0), 255.0)
g = math.Min(math.Max(g/weightSum, 0.0), 255.0)
b = math.Min(math.Max(b/weightSum, 0.0), 255.0)
a = math.Min(math.Max(a/weightSum, 0.0), 255.0)
j := y*dst.Stride + x*4
dst.Pix[j+0] = uint8(r + 0.5)
dst.Pix[j+1] = uint8(g + 0.5)
dst.Pix[j+2] = uint8(b + 0.5)
dst.Pix[j+3] = uint8(a + 0.5)
}
}
})
return dst
}
func blurVertical(src *image.NRGBA, kernel []float64) *image.NRGBA {
radius := len(kernel) - 1
width := src.Bounds().Max.X
height := src.Bounds().Max.Y
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
parallel(height, func(partStart, partEnd int) {
for y := partStart; y < partEnd; y++ {
start := y - radius
if start < 0 {
start = 0
}
end := y + radius
if end > height-1 {
end = height - 1
}
weightSum := 0.0
for iy := start; iy <= end; iy++ {
weightSum += kernel[absint(y-iy)]
}
for x := 0; x < width; x++ {
r, g, b, a := 0.0, 0.0, 0.0, 0.0
for iy := start; iy <= end; iy++ {
weight := kernel[absint(y-iy)]
i := iy*src.Stride + x*4
r += float64(src.Pix[i+0]) * weight
g += float64(src.Pix[i+1]) * weight
b += float64(src.Pix[i+2]) * weight
a += float64(src.Pix[i+3]) * weight
}
r = math.Min(math.Max(r/weightSum, 0.0), 255.0)
g = math.Min(math.Max(g/weightSum, 0.0), 255.0)
b = math.Min(math.Max(b/weightSum, 0.0), 255.0)
a = math.Min(math.Max(a/weightSum, 0.0), 255.0)
j := y*dst.Stride + x*4
dst.Pix[j+0] = uint8(r + 0.5)
dst.Pix[j+1] = uint8(g + 0.5)
dst.Pix[j+2] = uint8(b + 0.5)
dst.Pix[j+3] = uint8(a + 0.5)
}
}
})
return dst
}
// Sharpen produces a sharpened version of the image.
// Sigma parameter must be positive and indicates how much the image will be sharpened.
//
// Usage example:
//
// dstImage := imaging.Sharpen(srcImage, 3.5)
//
func Sharpen(img image.Image, sigma float64) *image.NRGBA {
if sigma <= 0 {
// sigma parameter must be positive!
return Clone(img)
}
src := toNRGBA(img)
blurred := Blur(img, sigma)
width := src.Bounds().Max.X
height := src.Bounds().Max.Y
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
parallel(height, func(partStart, partEnd int) {
for y := partStart; y < partEnd; y++ {
for x := 0; x < width; x++ {
i := y*src.Stride + x*4
for j := 0; j < 4; j++ {
k := i + j
val := int(src.Pix[k]) + (int(src.Pix[k]) - int(blurred.Pix[k]))
if val < 0 {
val = 0
} else if val > 255 {
val = 255
}
dst.Pix[k] = uint8(val)
}
}
}
})
return dst
}

View File

@ -0,0 +1,128 @@
package imaging
import (
"image"
"testing"
)
func TestBlur(t *testing.T) {
td := []struct {
desc string
src image.Image
sigma float64
want *image.NRGBA
}{
{
"Blur 3x3 0.5",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x66, 0xaa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
0.5,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x01, 0x02, 0x04, 0x04, 0x0a, 0x10, 0x18, 0x18, 0x01, 0x02, 0x04, 0x04,
0x09, 0x10, 0x18, 0x18, 0x3f, 0x69, 0x9e, 0x9e, 0x09, 0x10, 0x18, 0x18,
0x01, 0x02, 0x04, 0x04, 0x0a, 0x10, 0x18, 0x18, 0x01, 0x02, 0x04, 0x04,
},
},
},
{
"Blur 3x3 10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x66, 0xaa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
10,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
},
},
},
}
for _, d := range td {
got := Blur(d.src, d.sigma)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestSharpen(t *testing.T) {
td := []struct {
desc string
src image.Image
sigma float64
want *image.NRGBA
}{
{
"Sharpen 3x3 0.5",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x77, 0x77, 0x77, 0x77, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
},
},
0.5,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x66, 0x66, 0x66, 0x66, 0x64, 0x64, 0x64, 0x64, 0x66, 0x66, 0x66, 0x66,
0x64, 0x64, 0x64, 0x64, 0x7e, 0x7e, 0x7e, 0x7e, 0x64, 0x64, 0x64, 0x64,
0x66, 0x66, 0x66, 0x66, 0x64, 0x64, 0x64, 0x64, 0x66, 0x66, 0x66, 0x66},
},
},
{
"Sharpen 3x3 10",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 2),
Stride: 3 * 4,
Pix: []uint8{
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x77, 0x77, 0x77, 0x77, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66},
},
100,
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 3),
Stride: 3 * 4,
Pix: []uint8{
0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64,
0x64, 0x64, 0x64, 0x64, 0x86, 0x86, 0x86, 0x86, 0x64, 0x64, 0x64, 0x64,
0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64,
},
},
},
}
for _, d := range td {
got := Sharpen(d.src, d.sigma)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,436 @@
/*
Package imaging provides basic image manipulation functions (resize, rotate, flip, crop, etc.).
This package is based on the standard Go image package and works best along with it.
Image manipulation functions provided by the package take any image type
that implements `image.Image` interface as an input, and return a new image of
`*image.NRGBA` type (32bit RGBA colors, not premultiplied by alpha).
Imaging package uses parallel goroutines for faster image processing.
To achieve maximum performance, make sure to allow Go to utilize all CPU cores:
runtime.GOMAXPROCS(runtime.NumCPU())
*/
package imaging
import (
"errors"
"image"
"image/color"
"image/gif"
"image/jpeg"
"image/png"
"io"
"os"
"path/filepath"
"strings"
"golang.org/x/image/bmp"
"golang.org/x/image/tiff"
)
type Format int
const (
JPEG Format = iota
PNG
GIF
TIFF
BMP
)
func (f Format) String() string {
switch f {
case JPEG:
return "JPEG"
case PNG:
return "PNG"
case GIF:
return "GIF"
case TIFF:
return "TIFF"
case BMP:
return "BMP"
default:
return "Unsupported"
}
}
var (
ErrUnsupportedFormat = errors.New("imaging: unsupported image format")
)
// Decode reads an image from r.
func Decode(r io.Reader) (image.Image, error) {
img, _, err := image.Decode(r)
if err != nil {
return nil, err
}
return toNRGBA(img), nil
}
// Open loads an image from file
func Open(filename string) (image.Image, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
img, err := Decode(file)
return img, err
}
// Encode writes the image img to w in the specified format (JPEG, PNG, GIF, TIFF or BMP).
func Encode(w io.Writer, img image.Image, format Format) error {
var err error
switch format {
case JPEG:
var rgba *image.RGBA
if nrgba, ok := img.(*image.NRGBA); ok {
if nrgba.Opaque() {
rgba = &image.RGBA{
Pix: nrgba.Pix,
Stride: nrgba.Stride,
Rect: nrgba.Rect,
}
}
}
if rgba != nil {
err = jpeg.Encode(w, rgba, &jpeg.Options{Quality: 95})
} else {
err = jpeg.Encode(w, img, &jpeg.Options{Quality: 95})
}
case PNG:
err = png.Encode(w, img)
case GIF:
err = gif.Encode(w, img, &gif.Options{NumColors: 256})
case TIFF:
err = tiff.Encode(w, img, &tiff.Options{Compression: tiff.Deflate, Predictor: true})
case BMP:
err = bmp.Encode(w, img)
default:
err = ErrUnsupportedFormat
}
return err
}
// Save saves the image to file with the specified filename.
// The format is determined from the filename extension: "jpg" (or "jpeg"), "png", "gif", "tif" (or "tiff") and "bmp" are supported.
func Save(img image.Image, filename string) (err error) {
formats := map[string]Format{
".jpg": JPEG,
".jpeg": JPEG,
".png": PNG,
".tif": TIFF,
".tiff": TIFF,
".bmp": BMP,
".gif": GIF,
}
ext := strings.ToLower(filepath.Ext(filename))
f, ok := formats[ext]
if !ok {
return ErrUnsupportedFormat
}
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
return Encode(file, img, f)
}
// New creates a new image with the specified width and height, and fills it with the specified color.
func New(width, height int, fillColor color.Color) *image.NRGBA {
if width <= 0 || height <= 0 {
return &image.NRGBA{}
}
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
c := color.NRGBAModel.Convert(fillColor).(color.NRGBA)
if c.R == 0 && c.G == 0 && c.B == 0 && c.A == 0 {
return dst
}
cs := []uint8{c.R, c.G, c.B, c.A}
// fill the first row
for x := 0; x < width; x++ {
copy(dst.Pix[x*4:(x+1)*4], cs)
}
// copy the first row to other rows
for y := 1; y < height; y++ {
copy(dst.Pix[y*dst.Stride:y*dst.Stride+width*4], dst.Pix[0:width*4])
}
return dst
}
// Clone returns a copy of the given image.
func Clone(img image.Image) *image.NRGBA {
srcBounds := img.Bounds()
srcMinX := srcBounds.Min.X
srcMinY := srcBounds.Min.Y
dstBounds := srcBounds.Sub(srcBounds.Min)
dstW := dstBounds.Dx()
dstH := dstBounds.Dy()
dst := image.NewNRGBA(dstBounds)
switch src := img.(type) {
case *image.NRGBA:
rowSize := srcBounds.Dx() * 4
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
copy(dst.Pix[di:di+rowSize], src.Pix[si:si+rowSize])
}
})
case *image.NRGBA64:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
dst.Pix[di+0] = src.Pix[si+0]
dst.Pix[di+1] = src.Pix[si+2]
dst.Pix[di+2] = src.Pix[si+4]
dst.Pix[di+3] = src.Pix[si+6]
di += 4
si += 8
}
}
})
case *image.RGBA:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
a := src.Pix[si+3]
dst.Pix[di+3] = a
switch a {
case 0:
dst.Pix[di+0] = 0
dst.Pix[di+1] = 0
dst.Pix[di+2] = 0
case 0xff:
dst.Pix[di+0] = src.Pix[si+0]
dst.Pix[di+1] = src.Pix[si+1]
dst.Pix[di+2] = src.Pix[si+2]
default:
dst.Pix[di+0] = uint8(uint16(src.Pix[si+0]) * 0xff / uint16(a))
dst.Pix[di+1] = uint8(uint16(src.Pix[si+1]) * 0xff / uint16(a))
dst.Pix[di+2] = uint8(uint16(src.Pix[si+2]) * 0xff / uint16(a))
}
di += 4
si += 4
}
}
})
case *image.RGBA64:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
a := src.Pix[si+6]
dst.Pix[di+3] = a
switch a {
case 0:
dst.Pix[di+0] = 0
dst.Pix[di+1] = 0
dst.Pix[di+2] = 0
case 0xff:
dst.Pix[di+0] = src.Pix[si+0]
dst.Pix[di+1] = src.Pix[si+2]
dst.Pix[di+2] = src.Pix[si+4]
default:
dst.Pix[di+0] = uint8(uint16(src.Pix[si+0]) * 0xff / uint16(a))
dst.Pix[di+1] = uint8(uint16(src.Pix[si+2]) * 0xff / uint16(a))
dst.Pix[di+2] = uint8(uint16(src.Pix[si+4]) * 0xff / uint16(a))
}
di += 4
si += 8
}
}
})
case *image.Gray:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
c := src.Pix[si]
dst.Pix[di+0] = c
dst.Pix[di+1] = c
dst.Pix[di+2] = c
dst.Pix[di+3] = 0xff
di += 4
si += 1
}
}
})
case *image.Gray16:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
c := src.Pix[si]
dst.Pix[di+0] = c
dst.Pix[di+1] = c
dst.Pix[di+2] = c
dst.Pix[di+3] = 0xff
di += 4
si += 2
}
}
})
case *image.YCbCr:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
switch src.SubsampleRatio {
case image.YCbCrSubsampleRatio422:
siy0 := dstY * src.YStride
sic0 := dstY * src.CStride
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
siy := siy0 + dstX
sic := sic0 + ((srcMinX+dstX)/2 - srcMinX/2)
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
dst.Pix[di+0] = r
dst.Pix[di+1] = g
dst.Pix[di+2] = b
dst.Pix[di+3] = 0xff
di += 4
}
case image.YCbCrSubsampleRatio420:
siy0 := dstY * src.YStride
sic0 := ((srcMinY+dstY)/2 - srcMinY/2) * src.CStride
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
siy := siy0 + dstX
sic := sic0 + ((srcMinX+dstX)/2 - srcMinX/2)
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
dst.Pix[di+0] = r
dst.Pix[di+1] = g
dst.Pix[di+2] = b
dst.Pix[di+3] = 0xff
di += 4
}
case image.YCbCrSubsampleRatio440:
siy0 := dstY * src.YStride
sic0 := ((srcMinY+dstY)/2 - srcMinY/2) * src.CStride
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
siy := siy0 + dstX
sic := sic0 + dstX
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
dst.Pix[di+0] = r
dst.Pix[di+1] = g
dst.Pix[di+2] = b
dst.Pix[di+3] = 0xff
di += 4
}
default:
siy0 := dstY * src.YStride
sic0 := dstY * src.CStride
for dstX := 0; dstX < dstW; dstX++ {
siy := siy0 + dstX
sic := sic0 + dstX
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
dst.Pix[di+0] = r
dst.Pix[di+1] = g
dst.Pix[di+2] = b
dst.Pix[di+3] = 0xff
di += 4
}
}
}
})
case *image.Paletted:
plen := len(src.Palette)
pnew := make([]color.NRGBA, plen)
for i := 0; i < plen; i++ {
pnew[i] = color.NRGBAModel.Convert(src.Palette[i]).(color.NRGBA)
}
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
si := src.PixOffset(srcMinX, srcMinY+dstY)
for dstX := 0; dstX < dstW; dstX++ {
c := pnew[src.Pix[si]]
dst.Pix[di+0] = c.R
dst.Pix[di+1] = c.G
dst.Pix[di+2] = c.B
dst.Pix[di+3] = c.A
di += 4
si += 1
}
}
})
default:
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
di := dst.PixOffset(0, dstY)
for dstX := 0; dstX < dstW; dstX++ {
c := color.NRGBAModel.Convert(img.At(srcMinX+dstX, srcMinY+dstY)).(color.NRGBA)
dst.Pix[di+0] = c.R
dst.Pix[di+1] = c.G
dst.Pix[di+2] = c.B
dst.Pix[di+3] = c.A
di += 4
}
}
})
}
return dst
}
// This function used internally to convert any image type to NRGBA if needed.
func toNRGBA(img image.Image) *image.NRGBA {
srcBounds := img.Bounds()
if srcBounds.Min.X == 0 && srcBounds.Min.Y == 0 {
if src0, ok := img.(*image.NRGBA); ok {
return src0
}
}
return Clone(img)
}

View File

@ -0,0 +1,361 @@
package imaging
import (
"bytes"
"image"
"image/color"
"testing"
)
func compareNRGBA(img1, img2 *image.NRGBA, delta int) bool {
if !img1.Rect.Eq(img2.Rect) {
return false
}
if len(img1.Pix) != len(img2.Pix) {
return false
}
for i := 0; i < len(img1.Pix); i++ {
if absint(int(img1.Pix[i])-int(img2.Pix[i])) > delta {
return false
}
}
return true
}
func TestEncodeDecode(t *testing.T) {
imgWithAlpha := image.NewNRGBA(image.Rect(0, 0, 3, 3))
imgWithAlpha.Pix = []uint8{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
244, 245, 246, 247, 248, 249, 250, 252, 252, 253, 254, 255,
}
imgWithoutAlpha := image.NewNRGBA(image.Rect(0, 0, 3, 3))
imgWithoutAlpha.Pix = []uint8{
0, 1, 2, 255, 4, 5, 6, 255, 8, 9, 10, 255,
127, 128, 129, 255, 131, 132, 133, 255, 135, 136, 137, 255,
244, 245, 246, 255, 248, 249, 250, 255, 252, 253, 254, 255,
}
for _, format := range []Format{JPEG, PNG, GIF, BMP, TIFF} {
img := imgWithoutAlpha
if format == PNG {
img = imgWithAlpha
}
buf := &bytes.Buffer{}
err := Encode(buf, img, format)
if err != nil {
t.Errorf("fail encoding format %s", format)
continue
}
img2, err := Decode(buf)
if err != nil {
t.Errorf("fail decoding format %s", format)
continue
}
img2cloned := Clone(img2)
delta := 0
if format == JPEG {
delta = 3
} else if format == GIF {
delta = 16
}
if !compareNRGBA(img, img2cloned, delta) {
t.Errorf("test [DecodeEncode %s] failed: %#v %#v", format, img, img2cloned)
continue
}
}
buf := &bytes.Buffer{}
err := Encode(buf, imgWithAlpha, Format(100))
if err != ErrUnsupportedFormat {
t.Errorf("expected ErrUnsupportedFormat")
}
}
func TestNew(t *testing.T) {
td := []struct {
desc string
w, h int
c color.Color
dstBounds image.Rectangle
dstPix []uint8
}{
{
"New 1x1 black",
1, 1,
color.NRGBA{0, 0, 0, 0},
image.Rect(0, 0, 1, 1),
[]uint8{0x00, 0x00, 0x00, 0x00},
},
{
"New 1x2 red",
1, 2,
color.NRGBA{255, 0, 0, 255},
image.Rect(0, 0, 1, 2),
[]uint8{0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff},
},
{
"New 2x1 white",
2, 1,
color.NRGBA{255, 255, 255, 255},
image.Rect(0, 0, 2, 1),
[]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
for _, d := range td {
got := New(d.w, d.h, d.c)
want := image.NewNRGBA(d.dstBounds)
want.Pix = d.dstPix
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestClone(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Clone NRGBA",
&image.NRGBA{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
},
{
"Clone NRGBA64",
&image.NRGBA64{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 8,
Pix: []uint8{
0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33,
0xcc, 0xcc, 0xdd, 0xdd, 0xee, 0xee, 0xff, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
},
{
"Clone RGBA",
&image.RGBA{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x55, 0xaa, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
},
{
"Clone RGBA64",
&image.RGBA64{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 8,
Pix: []uint8{
0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33,
0xcc, 0xcc, 0xdd, 0xdd, 0xee, 0xee, 0xff, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x00, 0x55, 0xaa, 0x33, 0xcc, 0xdd, 0xee, 0xff},
},
},
{
"Clone Gray",
&image.Gray{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 1,
Pix: []uint8{0x11, 0xee},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x11, 0x11, 0x11, 0xff, 0xee, 0xee, 0xee, 0xff},
},
},
{
"Clone Gray16",
&image.Gray16{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 2,
Pix: []uint8{0x11, 0x11, 0xee, 0xee},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0x11, 0x11, 0x11, 0xff, 0xee, 0xee, 0xee, 0xff},
},
},
{
"Clone Alpha",
&image.Alpha{
Rect: image.Rect(-1, -1, 0, 1),
Stride: 1 * 1,
Pix: []uint8{0x11, 0xee},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 2),
Stride: 1 * 4,
Pix: []uint8{0xff, 0xff, 0xff, 0x11, 0xff, 0xff, 0xff, 0xee},
},
},
{
"Clone YCbCr",
&image.YCbCr{
Rect: image.Rect(-1, -1, 5, 0),
SubsampleRatio: image.YCbCrSubsampleRatio444,
YStride: 6,
CStride: 6,
Y: []uint8{0x00, 0xff, 0x7f, 0x26, 0x4b, 0x0e},
Cb: []uint8{0x80, 0x80, 0x80, 0x6b, 0x56, 0xc0},
Cr: []uint8{0x80, 0x80, 0x80, 0xc0, 0x4b, 0x76},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 6, 1),
Stride: 6 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0xff,
0xff, 0xff, 0xff, 0xff,
0x7f, 0x7f, 0x7f, 0xff,
0x7f, 0x00, 0x00, 0xff,
0x00, 0x7f, 0x00, 0xff,
0x00, 0x00, 0x7f, 0xff,
},
},
},
{
"Clone YCbCr 444",
&image.YCbCr{
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
Cb: []uint8{0x55, 0xd4, 0xff, 0x8e, 0x2c, 0x01, 0x6b, 0xaa, 0xc0, 0x95, 0x56, 0x40, 0x80, 0x80, 0x80, 0x80},
Cr: []uint8{0xff, 0xeb, 0x6b, 0x36, 0x15, 0x95, 0xc0, 0xb5, 0x76, 0x41, 0x4b, 0x8c, 0x80, 0x80, 0x80, 0x80},
YStride: 4,
CStride: 4,
SubsampleRatio: image.YCbCrSubsampleRatio444,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
&image.NRGBA{
Pix: []uint8{0xff, 0x0, 0x0, 0xff, 0xff, 0x0, 0xff, 0xff, 0x0, 0x0, 0xff, 0xff, 0x49, 0xe1, 0xca, 0xff, 0x0, 0xff, 0x0, 0xff, 0xff, 0xff, 0x0, 0xff, 0x7f, 0x0, 0x0, 0xff, 0x7f, 0x0, 0x7f, 0xff, 0x0, 0x0, 0x7f, 0xff, 0x0, 0x7f, 0x7f, 0xff, 0x0, 0x7f, 0x0, 0xff, 0x82, 0x7f, 0x0, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
Stride: 16,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
},
{
"Clone YCbCr 440",
&image.YCbCr{
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
Cb: []uint8{0x2c, 0x01, 0x6b, 0xaa, 0x80, 0x80, 0x80, 0x80},
Cr: []uint8{0x15, 0x95, 0xc0, 0xb5, 0x80, 0x80, 0x80, 0x80},
YStride: 4,
CStride: 4,
SubsampleRatio: image.YCbCrSubsampleRatio440,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
&image.NRGBA{
Pix: []uint8{0x0, 0xb5, 0x0, 0xff, 0x86, 0x86, 0x0, 0xff, 0x77, 0x0, 0x0, 0xff, 0xfb, 0x7d, 0xfb, 0xff, 0x0, 0xff, 0x1, 0xff, 0xff, 0xff, 0x1, 0xff, 0x80, 0x0, 0x1, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0xe, 0xe, 0xe, 0xff, 0x59, 0x59, 0x59, 0xff, 0x4b, 0x4b, 0x4b, 0xff, 0x71, 0x71, 0x71, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
Stride: 16,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
},
{
"Clone YCbCr 422",
&image.YCbCr{
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
Cb: []uint8{0xd4, 0x8e, 0x01, 0xaa, 0x95, 0x40, 0x80, 0x80},
Cr: []uint8{0xeb, 0x36, 0x95, 0xb5, 0x41, 0x8c, 0x80, 0x80},
YStride: 4,
CStride: 2,
SubsampleRatio: image.YCbCrSubsampleRatio422,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
&image.NRGBA{
Pix: []uint8{0xe2, 0x0, 0xe1, 0xff, 0xff, 0x0, 0xfe, 0xff, 0x0, 0x4d, 0x36, 0xff, 0x49, 0xe1, 0xca, 0xff, 0xb3, 0xb3, 0x0, 0xff, 0xff, 0xff, 0x1, 0xff, 0x70, 0x0, 0x70, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0x0, 0x34, 0x33, 0xff, 0x1, 0x7f, 0x7e, 0xff, 0x5c, 0x58, 0x0, 0xff, 0x82, 0x7e, 0x0, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
Stride: 16,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
},
{
"Clone YCbCr 420",
&image.YCbCr{
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
Cb: []uint8{0x01, 0xaa, 0x80, 0x80},
Cr: []uint8{0x95, 0xb5, 0x80, 0x80},
YStride: 4, CStride: 2,
SubsampleRatio: image.YCbCrSubsampleRatio420,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
&image.NRGBA{
Pix: []uint8{0x69, 0x69, 0x0, 0xff, 0x86, 0x86, 0x0, 0xff, 0x67, 0x0, 0x67, 0xff, 0xfb, 0x7d, 0xfb, 0xff, 0xb3, 0xb3, 0x0, 0xff, 0xff, 0xff, 0x1, 0xff, 0x70, 0x0, 0x70, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0xe, 0xe, 0xe, 0xff, 0x59, 0x59, 0x59, 0xff, 0x4b, 0x4b, 0x4b, 0xff, 0x71, 0x71, 0x71, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
Stride: 16,
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
},
},
{
"Clone Paletted",
&image.Paletted{
Rect: image.Rect(-1, -1, 5, 0),
Stride: 6 * 1,
Palette: color.Palette{
color.NRGBA{R: 0x00, G: 0x00, B: 0x00, A: 0xff},
color.NRGBA{R: 0xff, G: 0xff, B: 0xff, A: 0xff},
color.NRGBA{R: 0x7f, G: 0x7f, B: 0x7f, A: 0xff},
color.NRGBA{R: 0x7f, G: 0x00, B: 0x00, A: 0xff},
color.NRGBA{R: 0x00, G: 0x7f, B: 0x00, A: 0xff},
color.NRGBA{R: 0x00, G: 0x00, B: 0x7f, A: 0xff},
},
Pix: []uint8{0x0, 0x1, 0x2, 0x3, 0x4, 0x5},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 6, 1),
Stride: 6 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0xff,
0xff, 0xff, 0xff, 0xff,
0x7f, 0x7f, 0x7f, 0xff,
0x7f, 0x00, 0x00, 0xff,
0x00, 0x7f, 0x00, 0xff,
0x00, 0x00, 0x7f, 0xff,
},
},
},
}
for _, d := range td {
got := Clone(d.src)
want := d.want
delta := 0
if _, ok := d.src.(*image.YCbCr); ok {
delta = 1
}
if !compareNRGBA(got, want, delta) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,564 @@
package imaging
import (
"image"
"math"
)
type iwpair struct {
i int
w int32
}
type pweights struct {
iwpairs []iwpair
wsum int32
}
func precomputeWeights(dstSize, srcSize int, filter ResampleFilter) []pweights {
du := float64(srcSize) / float64(dstSize)
scale := du
if scale < 1.0 {
scale = 1.0
}
ru := math.Ceil(scale * filter.Support)
out := make([]pweights, dstSize)
for v := 0; v < dstSize; v++ {
fu := (float64(v)+0.5)*du - 0.5
startu := int(math.Ceil(fu - ru))
if startu < 0 {
startu = 0
}
endu := int(math.Floor(fu + ru))
if endu > srcSize-1 {
endu = srcSize - 1
}
wsum := int32(0)
for u := startu; u <= endu; u++ {
w := int32(0xff * filter.Kernel((float64(u)-fu)/scale))
if w != 0 {
wsum += w
out[v].iwpairs = append(out[v].iwpairs, iwpair{u, w})
}
}
out[v].wsum = wsum
}
return out
}
// Resize resizes the image to the specified width and height using the specified resampling
// filter and returns the transformed image. If one of width or height is 0, the image aspect
// ratio is preserved.
//
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
//
// Usage example:
//
// dstImage := imaging.Resize(srcImage, 800, 600, imaging.Lanczos)
//
func Resize(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
dstW, dstH := width, height
if dstW < 0 || dstH < 0 {
return &image.NRGBA{}
}
if dstW == 0 && dstH == 0 {
return &image.NRGBA{}
}
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
if srcW <= 0 || srcH <= 0 {
return &image.NRGBA{}
}
// if new width or height is 0 then preserve aspect ratio, minimum 1px
if dstW == 0 {
tmpW := float64(dstH) * float64(srcW) / float64(srcH)
dstW = int(math.Max(1.0, math.Floor(tmpW+0.5)))
}
if dstH == 0 {
tmpH := float64(dstW) * float64(srcH) / float64(srcW)
dstH = int(math.Max(1.0, math.Floor(tmpH+0.5)))
}
var dst *image.NRGBA
if filter.Support <= 0.0 {
// nearest-neighbor special case
dst = resizeNearest(src, dstW, dstH)
} else {
// two-pass resize
if srcW != dstW {
dst = resizeHorizontal(src, dstW, filter)
} else {
dst = src
}
if srcH != dstH {
dst = resizeVertical(dst, dstH, filter)
}
}
return dst
}
func resizeHorizontal(src *image.NRGBA, width int, filter ResampleFilter) *image.NRGBA {
srcBounds := src.Bounds()
srcW := srcBounds.Max.X
srcH := srcBounds.Max.Y
dstW := width
dstH := srcH
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
weights := precomputeWeights(dstW, srcW, filter)
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
var c [4]int32
for _, iw := range weights[dstX].iwpairs {
i := dstY*src.Stride + iw.i*4
c[0] += int32(src.Pix[i+0]) * iw.w
c[1] += int32(src.Pix[i+1]) * iw.w
c[2] += int32(src.Pix[i+2]) * iw.w
c[3] += int32(src.Pix[i+3]) * iw.w
}
j := dstY*dst.Stride + dstX*4
sum := weights[dstX].wsum
dst.Pix[j+0] = clampint32(int32(float32(c[0])/float32(sum) + 0.5))
dst.Pix[j+1] = clampint32(int32(float32(c[1])/float32(sum) + 0.5))
dst.Pix[j+2] = clampint32(int32(float32(c[2])/float32(sum) + 0.5))
dst.Pix[j+3] = clampint32(int32(float32(c[3])/float32(sum) + 0.5))
}
}
})
return dst
}
func resizeVertical(src *image.NRGBA, height int, filter ResampleFilter) *image.NRGBA {
srcBounds := src.Bounds()
srcW := srcBounds.Max.X
srcH := srcBounds.Max.Y
dstW := srcW
dstH := height
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
weights := precomputeWeights(dstH, srcH, filter)
parallel(dstW, func(partStart, partEnd int) {
for dstX := partStart; dstX < partEnd; dstX++ {
for dstY := 0; dstY < dstH; dstY++ {
var c [4]int32
for _, iw := range weights[dstY].iwpairs {
i := iw.i*src.Stride + dstX*4
c[0] += int32(src.Pix[i+0]) * iw.w
c[1] += int32(src.Pix[i+1]) * iw.w
c[2] += int32(src.Pix[i+2]) * iw.w
c[3] += int32(src.Pix[i+3]) * iw.w
}
j := dstY*dst.Stride + dstX*4
sum := weights[dstY].wsum
dst.Pix[j+0] = clampint32(int32(float32(c[0])/float32(sum) + 0.5))
dst.Pix[j+1] = clampint32(int32(float32(c[1])/float32(sum) + 0.5))
dst.Pix[j+2] = clampint32(int32(float32(c[2])/float32(sum) + 0.5))
dst.Pix[j+3] = clampint32(int32(float32(c[3])/float32(sum) + 0.5))
}
}
})
return dst
}
// fast nearest-neighbor resize, no filtering
func resizeNearest(src *image.NRGBA, width, height int) *image.NRGBA {
dstW, dstH := width, height
srcBounds := src.Bounds()
srcW := srcBounds.Max.X
srcH := srcBounds.Max.Y
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
dx := float64(srcW) / float64(dstW)
dy := float64(srcH) / float64(dstH)
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
fy := (float64(dstY)+0.5)*dy - 0.5
for dstX := 0; dstX < dstW; dstX++ {
fx := (float64(dstX)+0.5)*dx - 0.5
srcX := int(math.Min(math.Max(math.Floor(fx+0.5), 0.0), float64(srcW)))
srcY := int(math.Min(math.Max(math.Floor(fy+0.5), 0.0), float64(srcH)))
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// Fit scales down the image using the specified resample filter to fit the specified
// maximum width and height and returns the transformed image.
//
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
//
// Usage example:
//
// dstImage := imaging.Fit(srcImage, 800, 600, imaging.Lanczos)
//
func Fit(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
maxW, maxH := width, height
if maxW <= 0 || maxH <= 0 {
return &image.NRGBA{}
}
srcBounds := img.Bounds()
srcW := srcBounds.Dx()
srcH := srcBounds.Dy()
if srcW <= 0 || srcH <= 0 {
return &image.NRGBA{}
}
if srcW <= maxW && srcH <= maxH {
return Clone(img)
}
srcAspectRatio := float64(srcW) / float64(srcH)
maxAspectRatio := float64(maxW) / float64(maxH)
var newW, newH int
if srcAspectRatio > maxAspectRatio {
newW = maxW
newH = int(float64(newW) / srcAspectRatio)
} else {
newH = maxH
newW = int(float64(newH) * srcAspectRatio)
}
return Resize(img, newW, newH, filter)
}
// Thumbnail scales the image up or down using the specified resample filter, crops it
// to the specified width and hight and returns the transformed image.
//
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
//
// Usage example:
//
// dstImage := imaging.Thumbnail(srcImage, 100, 100, imaging.Lanczos)
//
func Thumbnail(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
thumbW, thumbH := width, height
if thumbW <= 0 || thumbH <= 0 {
return &image.NRGBA{}
}
srcBounds := img.Bounds()
srcW := srcBounds.Dx()
srcH := srcBounds.Dy()
if srcW <= 0 || srcH <= 0 {
return &image.NRGBA{}
}
srcAspectRatio := float64(srcW) / float64(srcH)
thumbAspectRatio := float64(thumbW) / float64(thumbH)
var tmp image.Image
if srcAspectRatio > thumbAspectRatio {
tmp = Resize(img, 0, thumbH, filter)
} else {
tmp = Resize(img, thumbW, 0, filter)
}
return CropCenter(tmp, thumbW, thumbH)
}
// Resample filter struct. It can be used to make custom filters.
//
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
//
// General filter recommendations:
//
// - Lanczos
// Probably the best resampling filter for photographic images yielding sharp results,
// but it's slower than cubic filters (see below).
//
// - CatmullRom
// A sharp cubic filter. It's a good filter for both upscaling and downscaling if sharp results are needed.
//
// - MitchellNetravali
// A high quality cubic filter that produces smoother results with less ringing than CatmullRom.
//
// - BSpline
// A good filter if a very smooth output is needed.
//
// - Linear
// Bilinear interpolation filter, produces reasonably good, smooth output. It's faster than cubic filters.
//
// - Box
// Simple and fast resampling filter appropriate for downscaling.
// When upscaling it's similar to NearestNeighbor.
//
// - NearestNeighbor
// Fastest resample filter, no antialiasing at all. Rarely used.
//
type ResampleFilter struct {
Support float64
Kernel func(float64) float64
}
// Nearest-neighbor filter, no anti-aliasing.
var NearestNeighbor ResampleFilter
// Box filter (averaging pixels).
var Box ResampleFilter
// Linear filter.
var Linear ResampleFilter
// Hermite cubic spline filter (BC-spline; B=0; C=0).
var Hermite ResampleFilter
// Mitchell-Netravali cubic filter (BC-spline; B=1/3; C=1/3).
var MitchellNetravali ResampleFilter
// Catmull-Rom - sharp cubic filter (BC-spline; B=0; C=0.5).
var CatmullRom ResampleFilter
// Cubic B-spline - smooth cubic filter (BC-spline; B=1; C=0).
var BSpline ResampleFilter
// Gaussian Blurring Filter.
var Gaussian ResampleFilter
// Bartlett-windowed sinc filter (3 lobes).
var Bartlett ResampleFilter
// Lanczos filter (3 lobes).
var Lanczos ResampleFilter
// Hann-windowed sinc filter (3 lobes).
var Hann ResampleFilter
// Hamming-windowed sinc filter (3 lobes).
var Hamming ResampleFilter
// Blackman-windowed sinc filter (3 lobes).
var Blackman ResampleFilter
// Welch-windowed sinc filter (parabolic window, 3 lobes).
var Welch ResampleFilter
// Cosine-windowed sinc filter (3 lobes).
var Cosine ResampleFilter
func bcspline(x, b, c float64) float64 {
x = math.Abs(x)
if x < 1.0 {
return ((12-9*b-6*c)*x*x*x + (-18+12*b+6*c)*x*x + (6 - 2*b)) / 6
}
if x < 2.0 {
return ((-b-6*c)*x*x*x + (6*b+30*c)*x*x + (-12*b-48*c)*x + (8*b + 24*c)) / 6
}
return 0
}
func sinc(x float64) float64 {
if x == 0 {
return 1
}
return math.Sin(math.Pi*x) / (math.Pi * x)
}
func init() {
NearestNeighbor = ResampleFilter{
Support: 0.0, // special case - not applying the filter
}
Box = ResampleFilter{
Support: 0.5,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x <= 0.5 {
return 1.0
}
return 0
},
}
Linear = ResampleFilter{
Support: 1.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 1.0 {
return 1.0 - x
}
return 0
},
}
Hermite = ResampleFilter{
Support: 1.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 1.0 {
return bcspline(x, 0.0, 0.0)
}
return 0
},
}
MitchellNetravali = ResampleFilter{
Support: 2.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 2.0 {
return bcspline(x, 1.0/3.0, 1.0/3.0)
}
return 0
},
}
CatmullRom = ResampleFilter{
Support: 2.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 2.0 {
return bcspline(x, 0.0, 0.5)
}
return 0
},
}
BSpline = ResampleFilter{
Support: 2.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 2.0 {
return bcspline(x, 1.0, 0.0)
}
return 0
},
}
Gaussian = ResampleFilter{
Support: 2.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 2.0 {
return math.Exp(-2 * x * x)
}
return 0
},
}
Bartlett = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * (3.0 - x) / 3.0
}
return 0
},
}
Lanczos = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * sinc(x/3.0)
}
return 0
},
}
Hann = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * (0.5 + 0.5*math.Cos(math.Pi*x/3.0))
}
return 0
},
}
Hamming = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * (0.54 + 0.46*math.Cos(math.Pi*x/3.0))
}
return 0
},
}
Blackman = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * (0.42 - 0.5*math.Cos(math.Pi*x/3.0+math.Pi) + 0.08*math.Cos(2.0*math.Pi*x/3.0))
}
return 0
},
}
Welch = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * (1.0 - (x * x / 9.0))
}
return 0
},
}
Cosine = ResampleFilter{
Support: 3.0,
Kernel: func(x float64) float64 {
x = math.Abs(x)
if x < 3.0 {
return sinc(x) * math.Cos((math.Pi/2.0)*(x/3.0))
}
return 0
},
}
}

View File

@ -0,0 +1,281 @@
package imaging
import (
"image"
"testing"
)
func TestResize(t *testing.T) {
td := []struct {
desc string
src image.Image
w, h int
f ResampleFilter
want *image.NRGBA
}{
{
"Resize 2x2 1x1 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
1, 1,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
},
},
{
"Resize 2x2 2x2 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
2, 2,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
},
{
"Resize 3x1 1x1 nearest",
&image.NRGBA{
Rect: image.Rect(-1, -1, 2, 0),
Stride: 3 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
1, 1,
NearestNeighbor,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x00, 0xff, 0x00, 0xff},
},
},
{
"Resize 2x2 0x4 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
0, 4,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 4, 4),
Stride: 4 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
},
{
"Resize 2x2 4x0 linear",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
4, 0,
Linear,
&image.NRGBA{
Rect: image.Rect(0, 0, 4, 4),
Stride: 4 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0xbf, 0x00, 0x00, 0xbf, 0xff, 0x00, 0x00, 0xff,
0x00, 0x40, 0x00, 0x40, 0x30, 0x30, 0x10, 0x70, 0x8f, 0x10, 0x30, 0xcf, 0xbf, 0x00, 0x40, 0xff,
0x00, 0xbf, 0x00, 0xbf, 0x10, 0x8f, 0x30, 0xcf, 0x30, 0x30, 0x8f, 0xef, 0x40, 0x00, 0xbf, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0xbf, 0x40, 0xff, 0x00, 0x40, 0xbf, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := Resize(d.src, d.w, d.h, d.f)
want := d.want
if !compareNRGBA(got, want, 1) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestFit(t *testing.T) {
td := []struct {
desc string
src image.Image
w, h int
f ResampleFilter
want *image.NRGBA
}{
{
"Fit 2x2 1x10 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
1, 10,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
},
},
{
"Fit 2x2 10x1 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
10, 1,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
},
},
{
"Fit 2x2 10x10 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
10, 10,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
},
},
},
}
for _, d := range td {
got := Fit(d.src, d.w, d.h, d.f)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestThumbnail(t *testing.T) {
td := []struct {
desc string
src image.Image
w, h int
f ResampleFilter
want *image.NRGBA
}{
{
"Thumbnail 6x2 1x1 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 5, 1),
Stride: 6 * 4,
Pix: []uint8{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
1, 1,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
},
},
{
"Thumbnail 2x6 1x1 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 5),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
},
},
1, 1,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 1, 1),
Stride: 1 * 4,
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
},
},
{
"Thumbnail 1x3 2x2 box",
&image.NRGBA{
Rect: image.Rect(-1, -1, 0, 2),
Stride: 1 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0x00,
0xff, 0x00, 0x00, 0xff,
0xff, 0xff, 0xff, 0xff,
},
},
2, 2,
Box,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 2),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
},
},
},
}
for _, d := range td {
got := Thumbnail(d.src, d.w, d.h, d.f)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,139 @@
package imaging
import (
"image"
"math"
)
// Crop cuts out a rectangular region with the specified bounds
// from the image and returns the cropped image.
func Crop(img image.Image, rect image.Rectangle) *image.NRGBA {
src := toNRGBA(img)
srcRect := rect.Sub(img.Bounds().Min)
sub := src.SubImage(srcRect)
return Clone(sub) // New image Bounds().Min point will be (0, 0)
}
// CropCenter cuts out a rectangular region with the specified size
// from the center of the image and returns the cropped image.
func CropCenter(img image.Image, width, height int) *image.NRGBA {
cropW, cropH := width, height
srcBounds := img.Bounds()
srcW := srcBounds.Dx()
srcH := srcBounds.Dy()
srcMinX := srcBounds.Min.X
srcMinY := srcBounds.Min.Y
centerX := srcMinX + srcW/2
centerY := srcMinY + srcH/2
x0 := centerX - cropW/2
y0 := centerY - cropH/2
x1 := x0 + cropW
y1 := y0 + cropH
return Crop(img, image.Rect(x0, y0, x1, y1))
}
// Paste pastes the img image to the background image at the specified position and returns the combined image.
func Paste(background, img image.Image, pos image.Point) *image.NRGBA {
src := toNRGBA(img)
dst := Clone(background) // cloned image bounds start at (0, 0)
startPt := pos.Sub(background.Bounds().Min) // so we should translate start point
endPt := startPt.Add(src.Bounds().Size())
pasteBounds := image.Rectangle{startPt, endPt}
if dst.Bounds().Overlaps(pasteBounds) {
intersectBounds := dst.Bounds().Intersect(pasteBounds)
rowSize := intersectBounds.Dx() * 4
numRows := intersectBounds.Dy()
srcStartX := intersectBounds.Min.X - pasteBounds.Min.X
srcStartY := intersectBounds.Min.Y - pasteBounds.Min.Y
i0 := dst.PixOffset(intersectBounds.Min.X, intersectBounds.Min.Y)
j0 := src.PixOffset(srcStartX, srcStartY)
di := dst.Stride
dj := src.Stride
for row := 0; row < numRows; row++ {
copy(dst.Pix[i0:i0+rowSize], src.Pix[j0:j0+rowSize])
i0 += di
j0 += dj
}
}
return dst
}
// PasteCenter pastes the img image to the center of the background image and returns the combined image.
func PasteCenter(background, img image.Image) *image.NRGBA {
bgBounds := background.Bounds()
bgW := bgBounds.Dx()
bgH := bgBounds.Dy()
bgMinX := bgBounds.Min.X
bgMinY := bgBounds.Min.Y
centerX := bgMinX + bgW/2
centerY := bgMinY + bgH/2
x0 := centerX - img.Bounds().Dx()/2
y0 := centerY - img.Bounds().Dy()/2
return Paste(background, img, image.Pt(x0, y0))
}
// Overlay draws the img image over the background image at given position
// and returns the combined image. Opacity parameter is the opacity of the img
// image layer, used to compose the images, it must be from 0.0 to 1.0.
//
// Usage examples:
//
// // draw the sprite over the background at position (50, 50)
// dstImage := imaging.Overlay(backgroundImage, spriteImage, image.Pt(50, 50), 1.0)
//
// // blend two opaque images of the same size
// dstImage := imaging.Overlay(imageOne, imageTwo, image.Pt(0, 0), 0.5)
//
func Overlay(background, img image.Image, pos image.Point, opacity float64) *image.NRGBA {
opacity = math.Min(math.Max(opacity, 0.0), 1.0) // check: 0.0 <= opacity <= 1.0
src := toNRGBA(img)
dst := Clone(background) // cloned image bounds start at (0, 0)
startPt := pos.Sub(background.Bounds().Min) // so we should translate start point
endPt := startPt.Add(src.Bounds().Size())
pasteBounds := image.Rectangle{startPt, endPt}
if dst.Bounds().Overlaps(pasteBounds) {
intersectBounds := dst.Bounds().Intersect(pasteBounds)
for y := intersectBounds.Min.Y; y < intersectBounds.Max.Y; y++ {
for x := intersectBounds.Min.X; x < intersectBounds.Max.X; x++ {
i := y*dst.Stride + x*4
srcX := x - pasteBounds.Min.X
srcY := y - pasteBounds.Min.Y
j := srcY*src.Stride + srcX*4
a1 := float64(dst.Pix[i+3])
a2 := float64(src.Pix[j+3])
coef2 := opacity * a2 / 255.0
coef1 := (1 - coef2) * a1 / 255.0
coefSum := coef1 + coef2
coef1 /= coefSum
coef2 /= coefSum
dst.Pix[i+0] = uint8(float64(dst.Pix[i+0])*coef1 + float64(src.Pix[j+0])*coef2)
dst.Pix[i+1] = uint8(float64(dst.Pix[i+1])*coef1 + float64(src.Pix[j+1])*coef2)
dst.Pix[i+2] = uint8(float64(dst.Pix[i+2])*coef1 + float64(src.Pix[j+2])*coef2)
dst.Pix[i+3] = uint8(math.Min(a1+a2*opacity*(255.0-a1)/255.0, 255.0))
}
}
}
return dst
}

View File

@ -0,0 +1,250 @@
package imaging
import (
"image"
"testing"
)
func TestCrop(t *testing.T) {
td := []struct {
desc string
src image.Image
r image.Rectangle
want *image.NRGBA
}{
{
"Crop 2x3 2x1",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
image.Rect(-1, 0, 1, 1),
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 1),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
},
},
},
}
for _, d := range td {
got := Crop(d.src, d.r)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestCropCenter(t *testing.T) {
td := []struct {
desc string
src image.Image
w, h int
want *image.NRGBA
}{
{
"CropCenter 2x3 2x1",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
2, 1,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 1),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
},
},
},
}
for _, d := range td {
got := CropCenter(d.src, d.w, d.h)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestPaste(t *testing.T) {
td := []struct {
desc string
src1 image.Image
src2 image.Image
p image.Point
want *image.NRGBA
}{
{
"Paste 2x3 2x1",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(1, 1, 3, 2),
Stride: 2 * 4,
Pix: []uint8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
},
},
image.Pt(-1, 0),
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
},
}
for _, d := range td {
got := Paste(d.src1, d.src2, d.p)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestPasteCenter(t *testing.T) {
td := []struct {
desc string
src1 image.Image
src2 image.Image
want *image.NRGBA
}{
{
"PasteCenter 2x3 2x1",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(1, 1, 3, 2),
Stride: 2 * 4,
Pix: []uint8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
},
}
for _, d := range td {
got := PasteCenter(d.src1, d.src2)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestOverlay(t *testing.T) {
td := []struct {
desc string
src1 image.Image
src2 image.Image
p image.Point
a float64
want *image.NRGBA
}{
{
"Overlay 2x3 2x1 1.0",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0x60, 0x00, 0x90, 0xff, 0xff, 0x00, 0x99, 0x7f,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(1, 1, 3, 2),
Stride: 2 * 4,
Pix: []uint8{
0x20, 0x40, 0x80, 0x7f, 0xaa, 0xbb, 0xcc, 0xff,
},
},
image.Pt(-1, 0),
1.0,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0x40, 0x1f, 0x88, 0xff, 0xaa, 0xbb, 0xcc, 0xff,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
},
{
"Overlay 2x2 2x2 0.5",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff,
0x00, 0x00, 0xff, 0xff, 0x20, 0x20, 0x20, 0x00,
},
},
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 1),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0xff, 0xff, 0x00, 0xff, 0x20, 0x20, 0x20, 0xff,
},
},
image.Pt(-1, -1),
0.5,
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 2),
Stride: 2 * 4,
Pix: []uint8{
0xff, 0x7f, 0x7f, 0xff, 0x00, 0xff, 0x00, 0xff,
0x7f, 0x7f, 0x7f, 0xff, 0x20, 0x20, 0x20, 0x7f,
},
},
},
}
for _, d := range td {
got := Overlay(d.src1, d.src2, d.p, d.a)
want := d.want
if !compareNRGBA(got, want, 1) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,201 @@
package imaging
import (
"image"
)
// Rotate90 rotates the image 90 degrees counterclockwise and returns the transformed image.
func Rotate90(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcH
dstH := srcW
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstH - dstY - 1
srcY := dstX
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// Rotate180 rotates the image 180 degrees counterclockwise and returns the transformed image.
func Rotate180(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcW
dstH := srcH
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstW - dstX - 1
srcY := dstH - dstY - 1
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// Rotate270 rotates the image 270 degrees counterclockwise and returns the transformed image.
func Rotate270(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcH
dstH := srcW
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstY
srcY := dstW - dstX - 1
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// FlipH flips the image horizontally (from left to right) and returns the transformed image.
func FlipH(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcW
dstH := srcH
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstW - dstX - 1
srcY := dstY
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// FlipV flips the image vertically (from top to bottom) and returns the transformed image.
func FlipV(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcW
dstH := srcH
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstX
srcY := dstH - dstY - 1
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// Transpose flips the image horizontally and rotates 90 degrees counter-clockwise.
func Transpose(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcH
dstH := srcW
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstY
srcY := dstX
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}
// Transverse flips the image vertically and rotates 90 degrees counter-clockwise.
func Transverse(img image.Image) *image.NRGBA {
src := toNRGBA(img)
srcW := src.Bounds().Max.X
srcH := src.Bounds().Max.Y
dstW := srcH
dstH := srcW
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
parallel(dstH, func(partStart, partEnd int) {
for dstY := partStart; dstY < partEnd; dstY++ {
for dstX := 0; dstX < dstW; dstX++ {
srcX := dstH - dstY - 1
srcY := dstW - dstX - 1
srcOff := srcY*src.Stride + srcX*4
dstOff := dstY*dst.Stride + dstX*4
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
}
}
})
return dst
}

View File

@ -0,0 +1,261 @@
package imaging
import (
"image"
"testing"
)
func TestRotate90(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Rotate90 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 2),
Stride: 3 * 4,
Pix: []uint8{
0xcc, 0xdd, 0xee, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff,
0x00, 0x11, 0x22, 0x33, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00,
},
},
},
}
for _, d := range td {
got := Rotate90(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestRotate180(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Rotate180 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0x00,
0x00, 0xff, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00,
0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
},
},
},
}
for _, d := range td {
got := Rotate180(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestRotate270(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Rotate270 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 2),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33,
0x00, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xcc, 0xdd, 0xee, 0xff,
},
},
},
}
for _, d := range td {
got := Rotate270(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestFlipV(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"FlipV 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
},
},
},
}
for _, d := range td {
got := FlipV(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestFlipH(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"FlipH 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 2, 3),
Stride: 2 * 4,
Pix: []uint8{
0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
0x00, 0xff, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0x00,
},
},
},
}
for _, d := range td {
got := FlipH(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestTranspose(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Transpose 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 2),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00,
0xcc, 0xdd, 0xee, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
},
}
for _, d := range td {
got := Transpose(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}
func TestTransverse(t *testing.T) {
td := []struct {
desc string
src image.Image
want *image.NRGBA
}{
{
"Transverse 2x3",
&image.NRGBA{
Rect: image.Rect(-1, -1, 1, 2),
Stride: 2 * 4,
Pix: []uint8{
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
},
},
&image.NRGBA{
Rect: image.Rect(0, 0, 3, 2),
Stride: 3 * 4,
Pix: []uint8{
0x00, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xcc, 0xdd, 0xee, 0xff,
0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33,
},
},
},
}
for _, d := range td {
got := Transverse(d.src)
want := d.want
if !compareNRGBA(got, want, 0) {
t.Errorf("test [%s] failed: %#v", d.desc, got)
}
}
}

View File

@ -0,0 +1,77 @@
package imaging
import (
"math"
"runtime"
"sync"
"sync/atomic"
)
var parallelizationEnabled = true
// if GOMAXPROCS = 1: no goroutines used
// if GOMAXPROCS > 1: spawn N=GOMAXPROCS workers in separate goroutines
func parallel(dataSize int, fn func(partStart, partEnd int)) {
numGoroutines := 1
partSize := dataSize
if parallelizationEnabled {
numProcs := runtime.GOMAXPROCS(0)
if numProcs > 1 {
numGoroutines = numProcs
partSize = dataSize / (numGoroutines * 10)
if partSize < 1 {
partSize = 1
}
}
}
if numGoroutines == 1 {
fn(0, dataSize)
} else {
var wg sync.WaitGroup
wg.Add(numGoroutines)
idx := uint64(0)
for p := 0; p < numGoroutines; p++ {
go func() {
defer wg.Done()
for {
partStart := int(atomic.AddUint64(&idx, uint64(partSize))) - partSize
if partStart >= dataSize {
break
}
partEnd := partStart + partSize
if partEnd > dataSize {
partEnd = dataSize
}
fn(partStart, partEnd)
}
}()
}
wg.Wait()
}
}
func absint(i int) int {
if i < 0 {
return -i
}
return i
}
// clamp & round float64 to uint8 (0..255)
func clamp(v float64) uint8 {
return uint8(math.Min(math.Max(v, 0.0), 255.0) + 0.5)
}
// clamp int32 to uint8 (0..255)
func clampint32(v int32) uint8 {
if v < 0 {
return 0
} else if v > 255 {
return 255
}
return uint8(v)
}

View File

@ -0,0 +1,61 @@
package imaging
import (
"runtime"
"testing"
)
func testParallelN(enabled bool, n, procs int) bool {
data := make([]bool, n)
before := runtime.GOMAXPROCS(0)
runtime.GOMAXPROCS(procs)
parallel(n, func(start, end int) {
for i := start; i < end; i++ {
data[i] = true
}
})
for i := 0; i < n; i++ {
if data[i] != true {
return false
}
}
runtime.GOMAXPROCS(before)
return true
}
func TestParallel(t *testing.T) {
for _, e := range []bool{true, false} {
for _, n := range []int{1, 10, 100, 1000} {
for _, p := range []int{1, 2, 4, 8, 16, 100} {
if testParallelN(e, n, p) != true {
t.Errorf("test [parallel %v %d %d] failed", e, n, p)
}
}
}
}
}
func TestClamp(t *testing.T) {
td := []struct {
f float64
u uint8
}{
{0, 0},
{255, 255},
{128, 128},
{0.49, 0},
{0.50, 1},
{254.9, 255},
{254.0, 254},
{256, 255},
{2500, 255},
{-10, 0},
{127.6, 128},
}
for _, d := range td {
if clamp(d.f) != d.u {
t.Errorf("test [clamp %v %v] failed: %v", d.f, d.u, clamp(d.f))
}
}
}

View File

@ -0,0 +1,19 @@
Copyright (c) 2012 Brad Rydzewski
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,100 @@
# routes.go
a simple http routing API for the Go programming language
go get github.com/drone/routes
for more information see:
http://gopkgdoc.appspot.com/pkg/github.com/bradrydzewski/routes
[![](https://drone.io/drone/routes/status.png)](https://drone.io/drone/routes/latest)
## Getting Started
package main
import (
"fmt"
"github.com/drone/routes"
"net/http"
)
func Whoami(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()
lastName := params.Get(":last")
firstName := params.Get(":first")
fmt.Fprintf(w, "you are %s %s", firstName, lastName)
}
func main() {
mux := routes.New()
mux.Get("/:last/:first", Whoami)
http.Handle("/", mux)
http.ListenAndServe(":8088", nil)
}
### Route Examples
You can create routes for all http methods:
mux.Get("/:param", handler)
mux.Put("/:param", handler)
mux.Post("/:param", handler)
mux.Patch("/:param", handler)
mux.Del("/:param", handler)
You can specify custom regular expressions for routes:
mux.Get("/files/:param(.+)", handler)
You can also create routes for static files:
pwd, _ := os.Getwd()
mux.Static("/static", pwd)
this will serve any files in `/static`, including files in subdirectories. For example `/static/logo.gif` or `/static/style/main.css`.
## Filters / Middleware
You can apply filters to routes, which is useful for enforcing security,
redirects, etc.
You can, for example, filter all request to enforce some type of security:
var FilterUser = func(w http.ResponseWriter, r *http.Request) {
if r.URL.User == nil || r.URL.User.Username() != "admin" {
http.Error(w, "", http.StatusUnauthorized)
}
}
r.Filter(FilterUser)
You can also apply filters only when certain REST URL Parameters exist:
r.Get("/:id", handler)
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
...
})
## Helper Functions
You can use helper functions for serializing to Json and Xml. I found myself constantly writing code to serialize, set content type, content length, etc. Feel free to use these functions to eliminate redundant code in your app.
Helper function for serving Json, sets content type to `application/json`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeJson(w, &mystruct)
}
Helper function for serving Xml, sets content type to `application/xml`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeXml(w, &mystruct)
}
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeFormatted(w, r, &mystruct)
}

View File

@ -0,0 +1,78 @@
package bench
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/drone/routes"
gorilla "code.google.com/p/gorilla/mux"
"github.com/bmizerany/pat"
)
func HandlerOk(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
w.WriteHeader(http.StatusOK)
}
// Benchmark_Routes runs a benchmark against our custom Mux using the
// default settings.
func Benchmark_Routes(b *testing.B) {
handler := routes.New()
handler.Get("/person/:last/:first", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_Web runs a benchmark against the pat.go Mux using the
// default settings.
func Benchmark_Pat(b *testing.B) {
m := pat.New()
m.Get("/person/:last/:first", http.HandlerFunc(HandlerOk))
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
m.ServeHTTP(w, r)
}
}
// Benchmark_Gorilla runs a benchmark against the Gorilla Mux using
// the default settings.
func Benchmark_GorillaHandler(b *testing.B) {
handler := gorilla.NewRouter()
handler.HandleFunc("/person/{last}/{first}", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
// We use this to determine performance impact of our library, when compared
// to the out-of-the-box Mux provided by Go.
func Benchmark_ServeMux(b *testing.B) {
mux := http.NewServeMux()
mux.HandleFunc("/", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
}
}

View File

@ -0,0 +1,38 @@
/*
Package routes a simple http routing API for the Go programming language,
compatible with the standard http.ListenAndServe function.
Create a new route multiplexer:
mux := routes.New()
Define a simple route with a given method (ie Get, Put, Post ...), path and
http.HandleFunc.
mux.Get("/foo", fooHandler)
Define a route with restful parameters in the path:
mux.Get("/:foo/:bar", func(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()
foo := params.Get(":foo")
bar := params.Get(":bar")
fmt.Fprintf(w, "%s %s", foo, bar)
})
The parameters are parsed from the URL, and appended to the Request URL's
query parameters.
More control over the route's parameter matching is possible by providing
a custom regular expression:
mux.Get("/files/:file(.+)", handler)
To start the web server, use the standard http.ListenAndServe
function, and provide the route multiplexer:
http.Handle("/", mux)
http.ListenAndServe(":8000", nil)
*/
package routes

View File

@ -0,0 +1,107 @@
# routes.go
a simple http routing API for the Go programming language
go get github.com/drone/routes
for more information see:
http://gopkgdoc.appspot.com/pkg/github.com/drone/routes
[![](https://drone.io/drone/routes/status.png)](https://drone.io/drone/routes/latest)
## Getting Started
package main
import (
"fmt"
"github.com/drone/routes"
"net/http"
)
func foobar (w http.ResponseWriter, r *http.Request) {
c := routes.NewContext(r)
foo := c.Params.Get(":foo")
bar := c.Params.Get(":bar")
fmt.Fprintf(w, "%s %s", foo, bar)
}
func main() {
r := routes.NewRouter()
r.Get("/:bar/:foo", foobar)
http.Handle("/", r)
http.ListenAndServe(":8088", nil)
}
### Route Examples
You can create routes for all http methods:
r.Get("/:param", handler)
r.Put("/:param", handler)
r.Post("/:param", handler)
r.Patch("/:param", handler)
r.Del("/:param", handler)
You can specify custom regular expressions for routes:
r.Get("/files/:param(.+)", handler)
You can also create routes for static files:
pwd, _ := os.Getwd()
r.Static("/static", pwd)
this will serve any files in `/static`, including files in subdirectories. For
example `/static/logo.gif` or `/static/style/main.css`.
## Filters / Middleware
You can implement route filters to do things like enforce security, set session
variables, etc
You can, for example, filter all request to enforce some type of security:
r.Filter(func(rw http.ResponseWriter, r *http.Request) {
if r.URL.User != "admin" {
http.Error(w, "", http.StatusForbidden)
}
})
You can also apply filters only when certain REST URL Parameters exist:
r.Get("/:id", handler)
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
c := routes.NewContext(r)
id := c.Params.Get("id")
// verify the user has access to the specified resource id
user := r.URL.User.Username()
if HasAccess(user, id) == false {
http.Error(w, "", http.StatusForbidden)
}
})
## Helper Functions
You can use helper functions for serializing to Json and Xml. I found myself
constantly writing code to serialize, set content type, content length, etc.
Feel free to use these functions to eliminate redundant code in your app.
Helper function for serving Json, sets content type to `application/json`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeJson(w, &mystruct)
}
Helper function for serving Xml, sets content type to `application/xml`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeXml(w, &mystruct)
}
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeFormatted(w, r, &mystruct)
}

View File

@ -0,0 +1,132 @@
package context
import (
"io"
"net/http"
)
// Context stores data for the duration of the http.Request
type Context struct {
// named parameters that are passed in via RESTful URL Parameters
Params Params
// named attributes that persist for the lifetime of the request
Values Values
// reference to the parent http.Request
req *http.Request
}
// Retruns the Context associated with the http.Request.
func Get(r *http.Request) *Context {
// get the context bound to the http.Request
if v, ok := r.Body.(*wrapper); ok {
return v.context
}
// create a new context
c := Context{ }
c.Params = make(Params)
c.Values = make(Values)
c.req = r
// wrap the request and bind the context
wrapper := wrap(r)
wrapper.context = &c
return &c
}
// Retruns the parent http.Request to which the context is bound.
func (c *Context) Request() *http.Request {
return c.req
}
// wrapper decorates an http.Request's Body (io.ReadCloser) so that we can
// bind a Context to the Request. This is obviously a hack that i'd rather
// avoid, however, it is for the greater good ...
//
// NOTE: If this turns out to be a really stupid approach we can use this
// approach from the go mailing list: http://goo.gl/Vw13f which I
// avoided because I didn't want a global lock
type wrapper struct {
body io.ReadCloser // the original message body
context *Context
}
func wrap(r *http.Request) *wrapper {
w := wrapper{ body: r.Body }
r.Body = &w
return &w
}
func (w *wrapper) Read(p []byte) (n int, err error) {
return w.body.Read(p)
}
func (w *wrapper) Close() error {
return w.body.Close()
}
// Parameter Map ---------------------------------------------------------------
// Params maps a string key to a list of values.
type Params map[string]string
// Get gets the first value associated with the given key. If there are
// no values associated with the key, Get returns the empty string.
func (p Params) Get(key string) string {
if p == nil {
return ""
}
return p[key]
}
// Set sets the key to value. It replaces any existing values.
func (p Params) Set(key, value string) {
p[key] = value
}
// Del deletes the values associated with key.
func (p Params) Del(key string) {
delete(p, key)
}
// Value Map -------------------------------------------------------------------
// Values maps a string key to a list of values.
type Values map[interface{}]interface{}
// Get gets the value associated with the given key. If there are
// no values associated with the key, Get returns nil.
func (v Values) Get(key interface{}) interface{} {
if v == nil {
return nil
}
return v[key]
}
// GetStr gets the value associated with the given key in string format.
// If there are no values associated with the key, Get returns an
// empty string.
func (v Values) GetStr(key interface{}) interface{} {
if v == nil { return "" }
val := v.Get(key)
if val == nil { return "" }
str, ok := val.(string)
if !ok { return "" }
return str
}
// Set sets the key to value. It replaces any existing values.
func (v Values) Set(key, value interface{}) {
v[key] = value
}
// Del deletes the values associated with key.
func (v Values) Del(key interface{}) {
delete(v, key)
}

View File

@ -0,0 +1,19 @@
Copyright (c) 2011 Dmitry Chestnykh
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,99 @@
Package authcookie
=====================
import "github.com/dchest/authcookie"
Package authcookie implements creation and verification of signed
authentication cookies.
Cookie is a Base64 encoded (using URLEncoding, from RFC 4648) string, which
consists of concatenation of expiration time, login, and signature:
expiration time || login || signature
where expiration time is the number of seconds since Unix epoch UTC
indicating when this cookie must expire (4 bytes, big-endian, uint32), login
is a byte string of arbitrary length (at least 1 byte, not null-terminated),
and signature is 32 bytes of HMAC-SHA256(expiration_time || login, k), where
k = HMAC-SHA256(expiration_time || login, secret key).
Example:
secret := []byte("my secret key")
// Generate cookie valid for 24 hours for user "bender"
cookie := authcookie.NewSinceNow("bender", 24 * time.Hour, secret)
// cookie is now:
// Tajh02JlbmRlcskYMxowgwPj5QZ94jaxhDoh3n0Yp4hgGtUpeO0YbMTY
// send it to user's browser..
// To authenticate a user later, receive cookie and:
login := authcookie.Login(cookie, secret)
if login != "" {
// access for login granted
} else {
// access denied
}
Note that login and expiration time are not encrypted, they are only signed
and Base64 encoded.
Variables
---------
var (
ErrMalformedCookie = errors.New("malformed cookie")
ErrWrongSignature = errors.New("wrong cookie signature")
)
var MinLength = base64.URLEncoding.EncodedLen(decodedMinLength)
MinLength is the minimum allowed length of cookie string.
It is useful for avoiding DoS attacks with too long cookies: before passing
a cookie to Parse or Login functions, check that it has length less than the
[maximum login length allowed in your application] + MinLength.
Functions
---------
### func Login
func Login(cookie string, secret []byte) string
Login returns a valid login extracted from the given cookie and verified
using the given secret key. If verification fails or the cookie expired,
the function returns an empty string.
### func New
func New(login string, expires time.Time, secret []byte) string
New returns a signed authentication cookie for the given login,
expiration time, and secret key.
If the login is empty, the function returns an empty string.
### func NewSinceNow
func NewSinceNow(login string, dur time.Duration, secret []byte) string
NewSinceNow returns a signed authetication cookie for the given login,
duration time since current time, and secret key.
### func Parse
func Parse(cookie string, secret []byte) (login string, expires time.Time, err error)
Parse verifies the given cookie with the secret key and returns login and
expiration time extracted from the cookie. If the cookie fails verification
or is not well-formed, the function returns an error.
Callers must:
1. Check for the returned error and deny access if it's present.
2. Check the returned expiration time and deny access if it's in the past.

View File

@ -0,0 +1,154 @@
// Package authcookie implements creation and verification of signed
// authentication cookies.
//
// Cookie is a Base64 encoded (using URLEncoding, from RFC 4648) string, which
// consists of concatenation of expiration time, login, and signature:
//
// expiration time || login || signature
//
// where expiration time is the number of seconds since Unix epoch UTC
// indicating when this cookie must expire (4 bytes, big-endian, uint32), login
// is a byte string of arbitrary length (at least 1 byte, not null-terminated),
// and signature is 32 bytes of HMAC-SHA256(expiration_time || login, k), where
// k = HMAC-SHA256(expiration_time || login, secret key).
//
// Example:
//
// secret := []byte("my secret key")
//
// // Generate cookie valid for 24 hours for user "bender"
// cookie := authcookie.NewSinceNow("bender", 24 * time.Hour, secret)
//
// // cookie is now:
// // Tajh02JlbmRlcskYMxowgwPj5QZ94jaxhDoh3n0Yp4hgGtUpeO0YbMTY
// // send it to user's browser..
//
// // To authenticate a user later, receive cookie and:
// login := authcookie.Login(cookie, secret)
// if login != "" {
// // access for login granted
// } else {
// // access denied
// }
//
// Note that login and expiration time are not encrypted, they are only signed
// and Base64 encoded.
//
// For safety, the maximum length of base64-decoded cookie is limited to 1024
// bytes.
package authcookie
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"errors"
"time"
)
const (
decodedMinLength = 4 /*expiration*/ + 1 /*login*/ + 32 /*signature*/
decodedMaxLength = 1024 /* maximum decoded length, for safety */
)
// MinLength is the minimum allowed length of cookie string.
//
// It is useful for avoiding DoS attacks with too long cookies: before passing
// a cookie to Parse or Login functions, check that it has length less than the
// [maximum login length allowed in your application] + MinLength.
var MinLength = base64.URLEncoding.EncodedLen(decodedMinLength)
func getSignature(b []byte, secret []byte) []byte {
keym := hmac.New(sha256.New, secret)
keym.Write(b)
m := hmac.New(sha256.New, keym.Sum(nil))
m.Write(b)
return m.Sum(nil)
}
var (
ErrMalformedCookie = errors.New("malformed cookie")
ErrWrongSignature = errors.New("wrong cookie signature")
)
// New returns a signed authentication cookie for the given login,
// expiration time, and secret key.
// If the login is empty, the function returns an empty string.
func New(login string, expires time.Time, secret []byte) string {
if login == "" {
return ""
}
llen := len(login)
b := make([]byte, llen+4+32)
// Put expiration time.
binary.BigEndian.PutUint32(b, uint32(expires.Unix()))
// Put login.
copy(b[4:], []byte(login))
// Calculate and put signature.
sig := getSignature([]byte(b[:4+llen]), secret)
copy(b[4+llen:], sig)
// Base64-encode.
return base64.URLEncoding.EncodeToString(b)
}
// NewSinceNow returns a signed authetication cookie for the given login,
// duration since current time, and secret key.
func NewSinceNow(login string, dur time.Duration, secret []byte) string {
return New(login, time.Now().Add(dur), secret)
}
// Parse verifies the given cookie with the secret key and returns login and
// expiration time extracted from the cookie. If the cookie fails verification
// or is not well-formed, the function returns an error.
//
// Callers must:
//
// 1. Check for the returned error and deny access if it's present.
//
// 2. Check the returned expiration time and deny access if it's in the past.
//
func Parse(cookie string, secret []byte) (login string, expires time.Time, err error) {
blen := base64.URLEncoding.DecodedLen(len(cookie))
// Avoid allocation if cookie is too short or too long.
if blen < decodedMinLength || blen > decodedMaxLength {
err = ErrMalformedCookie
return
}
b, err := base64.URLEncoding.DecodeString(cookie)
if err != nil {
return
}
// Decoded length may be different from max length, which
// we allocated, so check it, and set new length for b.
blen = len(b)
if blen < decodedMinLength {
err = ErrMalformedCookie
return
}
b = b[:blen]
sig := b[blen-32:]
data := b[:blen-32]
realSig := getSignature(data, secret)
if subtle.ConstantTimeCompare(realSig, sig) != 1 {
err = ErrWrongSignature
return
}
expires = time.Unix(int64(binary.BigEndian.Uint32(data[:4])), 0)
login = string(data[4:])
return
}
// Login returns a valid login extracted from the given cookie and verified
// using the given secret key. If verification fails or the cookie expired,
// the function returns an empty string.
func Login(cookie string, secret []byte) string {
l, exp, err := Parse(cookie, secret)
if err != nil || exp.Before(time.Now()) {
return ""
}
return l
}

View File

@ -0,0 +1,78 @@
package authcookie
import (
"testing"
"time"
)
func TestNew(t *testing.T) {
secret := []byte("secret key")
good := "AAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U="
c := New("hello world", time.Unix(42, 0), secret)
if c != good {
t.Errorf("expected %q, got %q", good, c)
}
// Test empty login
c = New("", time.Unix(42, 0), secret)
if c != "" {
t.Errorf(`allowed empty login: got %q, expected ""`, c)
}
}
func TestParse(t *testing.T) {
// good
sec := time.Now()
login := "bender"
key := []byte("another secret key")
c := New(login, sec, key)
l, e, err := Parse(c, key)
if err != nil {
t.Errorf("error parsing valid cookie: %s", err)
}
if l != login {
t.Errorf("login: expected %q, got %q", login, l)
}
// NOTE: nanos are discarded internally since only 4 bytes of timestamp are used
// so we can only compare seconds here
if e.Unix() != sec.Unix() {
t.Errorf("expiration: expected %v, got %v", sec, e)
}
// bad
key = []byte("secret key")
bad := []string{
"",
"AAAAKvgQ2I_RGePVk9oAu55q-Valnf__Fx_hlTM-dLwYxXOf",
"badcookie",
"AAAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
"zAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
"AAAAAKmhlbGxvIHdvcmxk9p6kiQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
}
for _, v := range bad {
_, _, err := Parse(v, key)
if err == nil {
t.Errorf("bad cookie didn't return error: %q", v)
}
}
}
func TestLogin(t *testing.T) {
login := "~~~!|zoidberg|!~~~"
key := []byte("(:€")
exp := time.Now().Add(time.Second * 120)
c := New(login, exp, key)
l := Login(c, key)
if l != login {
t.Errorf("login: expected %q, got %q", login, l)
}
c = "no" + c
l = Login(c, key)
if l != "" {
t.Errorf("login expected empty string, got %q", l)
}
exp = time.Now().Add(-(time.Second * 30))
c = New(login, exp, key)
l = Login(c, key)
if l != "" {
t.Errorf("returned login from expired cookie")
}
}

View File

@ -0,0 +1,53 @@
package cookie
import (
"net/http"
"time"
"github.com/drone/routes/exp/cookie/authcookie"
)
// Sign signs and timestamps a cookie so it cannot be forged.
func Sign(cookie *http.Cookie, secret string, expires time.Time) {
val := SignStr(cookie.Value, secret, expires)
cookie.Value = val
}
// SignStr signs and timestamps a string so it cannot be forged.
//
// Normally used via Sign, but provided as a separate method for
// non-cookie uses. To decode a value not stored as a cookie use the
// DecodeStr function.
func SignStr(value, secret string, expires time.Time) string {
return authcookie.New(value, expires, []byte(secret))
}
// DecodeStr returns the given signed cookie value if it validates,
// else returns an empty string.
func Decode(cookie *http.Cookie, secret string) string {
return DecodeStr(cookie.Value, secret)
}
// DecodeStr returns the given signed value if it validates,
// else returns an empty string.
func DecodeStr(value, secret string) string {
return authcookie.Login(value, []byte(secret))
}
// Clear deletes the cookie with the given name.
func Clear(w http.ResponseWriter, r *http.Request, name string) {
cookie := http.Cookie{
Name: name,
Value: "deleted",
Path: "/",
Domain: r.URL.Host,
MaxAge: -1,
}
http.SetCookie(w, &cookie)
}

View File

@ -0,0 +1,241 @@
package router
import (
"bufio"
"net"
"net/http"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/drone/routes/exp/context"
)
const (
DELETE = "DELETE"
GET = "GET"
HEAD = "HEAD"
OPTIONS = "OPTIONS"
PATCH = "PATCH"
POST = "POST"
PUT = "PUT"
)
type route struct {
method string
regex *regexp.Regexp
params map[int]string
handler http.HandlerFunc
}
type Router struct {
sync.RWMutex
routes []*route
filters []http.HandlerFunc
params map[string]interface{}
}
func New() *Router {
r := Router{}
r.params = make(map[string]interface{})
return &r
}
// Get adds a new Route for GET requests.
func (r *Router) Get(pattern string, handler http.HandlerFunc) {
r.AddRoute(GET, pattern, handler)
}
// Put adds a new Route for PUT requests.
func (r *Router) Put(pattern string, handler http.HandlerFunc) {
r.AddRoute(PUT, pattern, handler)
}
// Del adds a new Route for DELETE requests.
func (r *Router) Del(pattern string, handler http.HandlerFunc) {
r.AddRoute(DELETE, pattern, handler)
}
// Patch adds a new Route for PATCH requests.
func (r *Router) Patch(pattern string, handler http.HandlerFunc) {
r.AddRoute(PATCH, pattern, handler)
}
// Post adds a new Route for POST requests.
func (r *Router) Post(pattern string, handler http.HandlerFunc) {
r.AddRoute(POST, pattern, handler)
}
// Adds a new Route for Static http requests. Serves
// static files from the specified directory
func (r *Router) Static(pattern string, dir string) {
//append a regex to the param to match everything
// that comes after the prefix
pattern = pattern + "(.+)"
r.Get(pattern, func(w http.ResponseWriter, req *http.Request) {
path := filepath.Clean(req.URL.Path)
path = filepath.Join(dir, path)
http.ServeFile(w, req, path)
})
}
// Adds a new Route to the Handler
func (r *Router) AddRoute(method string, pattern string, handler http.HandlerFunc) {
r.Lock()
defer r.Unlock()
//split the url into sections
parts := strings.Split(pattern, "/")
//find params that start with ":"
//replace with regular expressions
j := 0
params := make(map[int]string)
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "([^/]+)"
//a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
}
params[j] = part[1:]
parts[i] = expr
j++
}
}
//recreate the url pattern, with parameters replaced
//by regular expressions. then compile the regex
pattern = strings.Join(parts, "/")
regex := regexp.MustCompile(pattern)
route := &route{
method : method,
regex : regex,
handler : handler,
params : params,
}
//append to the list of Routes
r.routes = append(r.routes, route)
}
// Filter adds the middleware filter.
func (r *Router) Filter(filter http.HandlerFunc) {
r.Lock()
r.filters = append(r.filters, filter)
r.Unlock()
}
// FilterParam adds the middleware filter iff the URL parameter exists.
func (r *Router) FilterParam(param string, filter http.HandlerFunc) {
r.Filter(func(w http.ResponseWriter, req *http.Request) {
c := context.Get(req)
if len(c.Params.Get(param)) > 0 { filter(w, req) }
})
}
// FilterPath adds the middleware filter iff the path matches the request.
func (r *Router) FilterPath(path string, filter http.HandlerFunc) {
pattern := path
pattern = strings.Replace(pattern, "*", "(.+)", -1)
pattern = strings.Replace(pattern, "**", "([^/]+)", -1)
regex := regexp.MustCompile(pattern)
r.Filter(func(w http.ResponseWriter, req *http.Request) {
if regex.MatchString(req.URL.Path) { filter(w, req) }
})
}
// Required by http.Handler interface. This method is invoked by the
// http server and will handle all page routing
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.RLock()
defer r.RUnlock()
//wrap the response writer in our custom interface
w := &responseWriter{writer: rw, Router: r}
//find a matching Route
for _, route := range r.routes {
//if the methods don't match, skip this handler
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
if req.Method != route.method {
continue
}
//check if Route pattern matches url
if !route.regex.MatchString(req.URL.Path) {
continue
}
//get submatches (params)
matches := route.regex.FindStringSubmatch(req.URL.Path)
//double check that the Route matches the URL pattern.
if len(matches[0]) != len(req.URL.Path) {
continue
}
//create the http.Requests context
c := context.Get(req)
//add url parameters to the context
for i, match := range matches[1:] {
c.Params.Set(route.params[i], match)
}
//execute middleware filters
for _, filter := range r.filters {
filter(w, req)
if w.started { return }
}
//invoke the request handler
route.handler(w, req)
return
}
//if no matches to url, throw a not found exception
if w.started == false {
http.NotFound(w, req)
}
}
// responseWriter is a wrapper for the http.ResponseWriter to track if
// response was written to, and to store a reference to the router.
type responseWriter struct {
Router *Router
writer http.ResponseWriter
started bool
status int
}
// Header returns the header map that will be sent by WriteHeader.
func (w *responseWriter) Header() http.Header {
return w.writer.Header()
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true
func (w *responseWriter) Write(p []byte) (int, error) {
w.started = true
return w.writer.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true
func (w *responseWriter) WriteHeader(code int) {
w.status = code
w.started = true
w.writer.WriteHeader(code)
}
// The Hijacker interface is implemented by ResponseWriters that allow an
// HTTP handler to take over the connection.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.writer.(http.Hijacker).Hijack()
}

View File

@ -0,0 +1,227 @@
package router
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/drone/routes/exp/context"
)
func HandlerOk(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
w.WriteHeader(http.StatusOK)
}
func HandlerSetVar(w http.ResponseWriter, r *http.Request) {
c := context.Get(r)
c.Values.Set("password", "z1on")
}
func HandlerErr(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
// TestRouteOk tests that the route is correctly handled, and the URL parameters
// are added to the Context.
func TestRouteOk(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := New()
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
c := context.Get(r)
lastNameParam := c.Params.Get("last")
firstNameParam := c.Params.Get("first")
if lastNameParam != "anderson" {
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
}
if firstNameParam != "thomas" {
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
}
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
}
// TestFilter tests that a route is filtered prior to handling
func TestRouteFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := New()
mux.Filter(HandlerSetVar)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
c := context.Get(r)
password := c.Values.Get("password")
if password != "z1on" {
t.Errorf("session variable set to [%s]; want [%s]", password, "z1on")
}
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
}
// TestFilterHalt tests that a route is filtered prior to handling, and then
// halts execution (by writing to the response).
func TestRouteFilterHalt(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := New()
mux.Filter(HandlerErr)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
}
if w.Body.String() == "hello world" {
t.Errorf("Body set to [%s]; want empty", w.Body.String())
}
}
// TestRouterFilterParam tests the Parameter filter, and ensures the
// filter is only executed when the specified Parameter exists.
func TestRouterFilterParam(t *testing.T) {
// in the first test scenario, the Parameter filter should not
// be triggered because the "codename" variab does not exist
r, _ := http.NewRequest("GET", "/neo", nil)
w := httptest.NewRecorder()
mux := New()
mux.Filter(HandlerSetVar)
mux.FilterParam("codename", HandlerErr)
mux.Get("/:nickname", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
// in this second scenario, the Parameter filter SHOULD fire, and should
// halt the request
w = httptest.NewRecorder()
mux = New()
mux.Filter(HandlerSetVar)
mux.FilterParam("codename", HandlerErr)
mux.Get("/:codename", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() == "hello world" {
t.Errorf("Body set to [%s]; want empty", w.Body.String())
}
if w.Code != 400 {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
}
}
// TestRouterFilterPath tests the Path filter, and ensures the filter
// is only executed when the Request Path matches the filter Path.
func TestRouterFilterPath(t *testing.T) {
// in the first test scenario, the Path filter should not fire
// because it does not take the "first name" section of the URL
// into account, and should therefore not match
r, _ := http.NewRequest("GET", "/person/anderson/thomas", nil)
w := httptest.NewRecorder()
mux := New()
mux.FilterPath("/person/*/anderson", HandlerErr)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
// in this second scenario, the Parameter filter SHOULD fire because
// we are filtering on all "last names", and the pattern should match
// the first section of the URL (person) and the last section of the
// url (:first)
w = httptest.NewRecorder()
mux = New()
mux.FilterPath("/person/*/thomas", HandlerErr)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() == "hello world" {
t.Errorf("Body set to [%s]; want empty", w.Body.String())
}
if w.Code != 400 {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
}
}
// TestNotFound tests that a 404 code is returned in the
// response if no route matches the request url.
func TestNotFound(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
mux := New()
mux.ServeHTTP(w, r)
if w.Code != http.StatusNotFound {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusNotFound)
}
}
// Benchmark_Routes runs a benchmark against our custom Mux using the
// default settings.
func Benchmark_Routes(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := New()
mux.Get("/person/:last/:first", HandlerOk)
for i := 0; i < b.N; i++ {
mux.ServeHTTP(w, r)
}
}
// Benchmark_Routes_x30 runs a benchmark against our custom Mux using the
// default settings, but with 30 routes
func Benchmark_Routes_x30(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := New()
for i:=0;i<30;i++ {
mux.Get(fmt.Sprintf("/%v/:last/:first",i), HandlerOk)
}
// and we'll make the matching URL the LAST in the list
mux.Get("/person/:last/:first", HandlerOk)
for i := 0; i < b.N; i++ {
mux.ServeHTTP(w, r)
}
}
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
// We use this to determine performance impact of our library, when compared
// to the out-of-the-box Mux provided by Go.
func Benchmark_ServeMux(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := http.NewServeMux()
mux.HandleFunc("/", HandlerOk)
for i := 0; i < b.N; i++ {
r.URL.Query().Get("learn")
mux.ServeHTTP(w, r)
}
}

View File

@ -0,0 +1,107 @@
# routes.go
a simple http routing API for the Go programming language
go get github.com/drone/routes
for more information see:
http://gopkgdoc.appspot.com/pkg/github.com/drone/routes
[![](https://drone.io/drone/routes/status.png)](https://drone.io/drone/routes/latest)
## Getting Started
package main
import (
"fmt"
"github.com/drone/routes"
"net/http"
)
func foobar (w http.ResponseWriter, r *http.Request) {
c := routes.NewContext(r)
foo := c.Params.Get(":foo")
bar := c.Params.Get(":bar")
fmt.Fprintf(w, "%s %s", foo, bar)
}
func main() {
r := routes.NewRouter()
r.Get("/:bar/:foo", foobar)
http.Handle("/", r)
http.ListenAndServe(":8088", nil)
}
### Route Examples
You can create routes for all http methods:
r.Get("/:param", handler)
r.Put("/:param", handler)
r.Post("/:param", handler)
r.Patch("/:param", handler)
r.Del("/:param", handler)
You can specify custom regular expressions for routes:
r.Get("/files/:param(.+)", handler)
You can also create routes for static files:
pwd, _ := os.Getwd()
r.Static("/static", pwd)
this will serve any files in `/static`, including files in subdirectories. For
example `/static/logo.gif` or `/static/style/main.css`.
## Filters / Middleware
You can implement route filters to do things like enforce security, set session
variables, etc
You can, for example, filter all request to enforce some type of security:
r.Filter(func(rw http.ResponseWriter, r *http.Request) {
if r.URL.User != "admin" {
http.Error(w, "", http.StatusForbidden)
}
})
You can also apply filters only when certain REST URL Parameters exist:
r.Get("/:id", handler)
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
c := routes.NewContext(r)
id := c.Params.Get("id")
// verify the user has access to the specified resource id
user := r.URL.User.Username()
if HasAccess(user, id) == false {
http.Error(w, "", http.StatusForbidden)
}
})
## Helper Functions
You can use helper functions for serializing to Json and Xml. I found myself
constantly writing code to serialize, set content type, content length, etc.
Feel free to use these functions to eliminate redundant code in your app.
Helper function for serving Json, sets content type to `application/json`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeJson(w, &mystruct)
}
Helper function for serving Xml, sets content type to `application/xml`:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeXml(w, &mystruct)
}
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
func handler(w http.ResponseWriter, r *http.Request) {
mystruct := { ... }
routes.ServeFormatted(w, r, &mystruct)
}

View File

@ -0,0 +1,74 @@
package bench
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/drone/routes/exp/routes"
gorilla "code.google.com/p/gorilla/mux"
"github.com/bmizerany/pat"
)
func HandlerOk(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
w.WriteHeader(http.StatusOK)
}
// Benchmark_Routes runs a benchmark against our custom Mux using the
// default settings.
func Benchmark_Routes(b *testing.B) {
handler := routes.NewRouter()
handler.Get("/person/:last/:first", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_Web runs a benchmark against the pat.go Mux using the
// default settings.
func Benchmark_Pat(b *testing.B) {
m := pat.New()
m.Get("/person/:last/:first", http.HandlerFunc(HandlerOk))
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
m.ServeHTTP(w, r)
}
}
// Benchmark_Gorilla runs a benchmark against the Gorilla Mux using
// the default settings.
func Benchmark_GorillaHandler(b *testing.B) {
handler := gorilla.NewRouter()
handler.HandleFunc("/person/{last}/{first}", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
// We use this to determine performance impact of our library, when compared
// to the out-of-the-box Mux provided by Go.
func Benchmark_ServeMux(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := http.NewServeMux()
mux.HandleFunc("/", HandlerOk)
for i := 0; i < b.N; i++ {
mux.ServeHTTP(w, r)
}
}

View File

@ -0,0 +1,132 @@
package routes
import (
"io"
"net/http"
)
// Context stores data for the duration of the http.Request
type Context struct {
// named parameters that are passed in via RESTful URL Parameters
Params Params
// named attributes that persist for the lifetime of the request
Values Values
// reference to the parent http.Request
req *http.Request
}
// Retruns the Context associated with the http.Request.
func NewContext(r *http.Request) *Context {
// get the context bound to the http.Request
if v, ok := r.Body.(*wrapper); ok {
return v.context
}
// create a new context
c := Context{ }
c.Params = make(Params)
c.Values = make(Values)
c.req = r
// wrap the request and bind the context
wrapper := wrap(r)
wrapper.context = &c
return &c
}
// Retruns the parent http.Request to which the context is bound.
func (c *Context) Request() *http.Request {
return c.req
}
// wrapper decorates an http.Request's Body (io.ReadCloser) so that we can
// bind a Context to the Request. This is obviously a hack that i'd rather
// avoid, however, it is for the greater good ...
//
// NOTE: If this turns out to be a really stupid approach we can use this
// approach from the go mailing list: http://goo.gl/Vw13f which I
// avoided because I didn't want a global lock
type wrapper struct {
body io.ReadCloser // the original message body
context *Context
}
func wrap(r *http.Request) *wrapper {
w := wrapper{ body: r.Body }
r.Body = &w
return &w
}
func (w *wrapper) Read(p []byte) (n int, err error) {
return w.body.Read(p)
}
func (w *wrapper) Close() error {
return w.body.Close()
}
// Parameter Map ---------------------------------------------------------------
// Params maps a string key to a list of values.
type Params map[string]string
// Get gets the first value associated with the given key. If there are
// no values associated with the key, Get returns the empty string.
func (p Params) Get(key string) string {
if p == nil {
return ""
}
return p[key]
}
// Set sets the key to value. It replaces any existing values.
func (p Params) Set(key, value string) {
p[key] = value
}
// Del deletes the values associated with key.
func (p Params) Del(key string) {
delete(p, key)
}
// Value Map -------------------------------------------------------------------
// Values maps a string key to a list of values.
type Values map[interface{}]interface{}
// Get gets the value associated with the given key. If there are
// no values associated with the key, Get returns nil.
func (v Values) Get(key interface{}) interface{} {
if v == nil {
return nil
}
return v[key]
}
// GetStr gets the value associated with the given key in string format.
// If there are no values associated with the key, Get returns an
// empty string.
func (v Values) GetStr(key interface{}) interface{} {
if v == nil { return "" }
val := v.Get(key)
if val == nil { return "" }
str, ok := val.(string)
if !ok { return "" }
return str
}
// Set sets the key to value. It replaces any existing values.
func (v Values) Set(key, value interface{}) {
v[key] = value
}
// Del deletes the values associated with key.
func (v Values) Del(key interface{}) {
delete(v, key)
}

View File

@ -0,0 +1,37 @@
/*
Package routes a simple http routing API for the Go programming language,
compatible with the standard http.ListenAndServe function.
Create a new route multiplexer:
r := routes.NewRouter()
Define a simple route with a given method (ie Get, Put, Post ...), path and
http.HandleFunc.
r.Get("/foo", fooHandler)
Define a route with restful parameters in the path:
r.Get("/:foo/:bar", func(rw http.ResponseWriter, req *http.Request) {
c := routes.NewContext(req)
foo := c.Params.Get("foo")
bar := c.Params.Get("bar")
fmt.Fprintf(rw, "%s %s", foo, bar)
})
The parameters are parsed from the URL, and stored in the Request Context.
More control over the route's parameter matching is possible by providing
a custom regular expression:
r.Get("/files/:file(.+)", handler)
To start the web server, use the standard http.ListenAndServe
function, and provide the route multiplexer:
http.Handle("/", r)
http.ListenAndServe(":8000", nil)
*/
package routes

View File

@ -0,0 +1,96 @@
package routes
import (
"bytes"
"encoding/json"
"encoding/xml"
"io/ioutil"
"net/http"
"strconv"
)
// Helper Functions to Read from the http.Request Body -------------------------
// ReadJson parses the JSON-encoded data in the http.Request object and
// stores the result in the value pointed to by v.
func ReadJson(r *http.Request, v interface{}) error {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
return json.Unmarshal(body, v)
}
// ReadXml parses the XML-encoded data in the http.Request object and
// stores the result in the value pointed to by v.
func ReadXml(r *http.Request, v interface{}) error {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
return xml.Unmarshal(body, v)
}
// Helper Functions to Write to the http.ReponseWriter -------------------------
// ServeJson writes the JSON representation of resource v to the
// http.ResponseWriter.
func ServeJson(w http.ResponseWriter, v interface{}) {
content, err := json.MarshalIndent(v, "", " ")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
w.Header().Set("Content-Type", "application/json")
w.Write(content)
}
// ServeXml writes the XML representation of resource v to the
// http.ResponseWriter.
func ServeXml(w http.ResponseWriter, v interface{}) {
content, err := xml.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
w.Write(content)
}
// ServeTemplate applies the named template to the specified data map and
// writes the output to the http.ResponseWriter.
func ServeTemplate(w http.ResponseWriter, name string, data map[string]interface{}) {
// cast the writer to the resposneWriter, get the router
r := w.(*responseWriter).Router
r.RLock()
defer r.RUnlock()
if data == nil {
data = map[string]interface{}{}
}
// append global params to the template
for k, v := range r.params {
data[k] = v
}
var buf bytes.Buffer
if err := r.views.ExecuteTemplate(&buf, name, data); err != nil {
panic(err)
return
}
// set the content length, type, etc
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(buf.Bytes())
}
// Error will terminate the http Request with the specified error code.
func Error(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}

View File

@ -0,0 +1,237 @@
package routes
import (
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"text/template"
)
const (
DELETE = "DELETE"
GET = "GET"
HEAD = "HEAD"
OPTIONS = "OPTIONS"
PATCH = "PATCH"
POST = "POST"
PUT = "PUT"
)
type route struct {
method string
regex *regexp.Regexp
params map[int]string
handler http.HandlerFunc
}
type Router struct {
sync.RWMutex
routes []*route
filters []http.HandlerFunc
views *template.Template
params map[string]interface{}
}
func NewRouter() *Router {
r := Router{}
r.params = make(map[string]interface{})
return &r
}
// Get adds a new Route for GET requests.
func (r *Router) Get(pattern string, handler http.HandlerFunc) {
r.AddRoute(GET, pattern, handler)
}
// Put adds a new Route for PUT requests.
func (r *Router) Put(pattern string, handler http.HandlerFunc) {
r.AddRoute(PUT, pattern, handler)
}
// Del adds a new Route for DELETE requests.
func (r *Router) Del(pattern string, handler http.HandlerFunc) {
r.AddRoute(DELETE, pattern, handler)
}
// Patch adds a new Route for PATCH requests.
func (r *Router) Patch(pattern string, handler http.HandlerFunc) {
r.AddRoute(PATCH, pattern, handler)
}
// Post adds a new Route for POST requests.
func (r *Router) Post(pattern string, handler http.HandlerFunc) {
r.AddRoute(POST, pattern, handler)
}
// Adds a new Route for Static http requests. Serves
// static files from the specified directory
func (r *Router) Static(pattern string, dir string) {
//append a regex to the param to match everything
// that comes after the prefix
pattern = pattern + "(.+)"
r.Get(pattern, func(w http.ResponseWriter, req *http.Request) {
path := filepath.Clean(req.URL.Path)
path = filepath.Join(dir, path)
http.ServeFile(w, req, path)
})
}
// Adds a new Route to the Handler
func (r *Router) AddRoute(method string, pattern string, handler http.HandlerFunc) {
r.Lock()
defer r.Unlock()
//split the url into sections
parts := strings.Split(pattern, "/")
//find params that start with ":"
//replace with regular expressions
j := 0
params := make(map[int]string)
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "([^/]+)"
//a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
}
params[j] = part[1:]
parts[i] = expr
j++
}
}
//recreate the url pattern, with parameters replaced
//by regular expressions. then compile the regex
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
panic(regexErr)
}
route := &route{
method : method,
regex : regex,
handler : handler,
params : params,
}
//append to the list of Routes
r.routes = append(r.routes, route)
}
// Filter adds the middleware filter.
func (r *Router) Filter(filter http.HandlerFunc) {
r.Lock()
r.filters = append(r.filters, filter)
r.Unlock()
}
// FilterParam adds the middleware filter iff the URL parameter exists.
func (r *Router) FilterParam(param string, filter http.HandlerFunc) {
r.Filter(func(w http.ResponseWriter, req *http.Request) {
c := NewContext(req)
if len(c.Params.Get(param)) > 0 { filter(w, req) }
})
}
// Set stores the specified key / value pair.
func (r *Router) Set(name string, value interface{}) {
r.Lock()
r.params[name] = value
r.Unlock()
}
// SetEnv stores the specified environment variable as a key / value pair. If
// the environment variable is not set the default value will be used
func (r *Router) SetEnv(name, value string) {
r.Lock()
defer r.Unlock()
env := os.Getenv(name)
if len(env) == 0 { env = value }
r.Set(name, env)
}
// Required by http.Handler interface. This method is invoked by the
// http server and will handle all page routing
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.RLock()
defer r.RUnlock()
//wrap the response writer in our custom interface
w := &responseWriter{writer: rw, Router: r}
//find a matching Route
for _, route := range r.routes {
//if the methods don't match, skip this handler
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
if req.Method != route.method {
continue
}
//check if Route pattern matches url
if !route.regex.MatchString(req.URL.Path) {
continue
}
//get submatches (params)
matches := route.regex.FindStringSubmatch(req.URL.Path)
//double check that the Route matches the URL pattern.
if len(matches[0]) != len(req.URL.Path) {
continue
}
//create the http.Requests context
c := NewContext(req)
//add url parameters to the context
for i, match := range matches[1:] {
c.Params.Set(route.params[i], match)
}
//execute middleware filters
for _, filter := range r.filters {
filter(w, req)
if w.started { return }
}
//invoke the request handler
route.handler(w, req)
return
}
//if no matches to url, throw a not found exception
if w.started == false {
http.NotFound(w, req)
}
}
// Template uses the provided template definitions.
func (r *Router) Template(t *template.Template) {
r.Lock()
defer r.Unlock()
r.views = template.Must(t.Clone())
}
// TemplateFiles parses the template definitions from the named files.
func (r *Router) TemplateFiles(filenames ...string) {
r.Lock()
defer r.Unlock()
r.views = template.Must(template.ParseFiles(filenames...))
}
// TemplateGlob parses the template definitions from the files identified
// by the pattern, which must match at least one file.
func (r *Router) TemplateGlob(pattern string) {
r.Lock()
defer r.Unlock()
r.views = template.Must(template.ParseGlob(pattern))
}

View File

@ -0,0 +1,189 @@
package routes
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func HandlerOk(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
w.WriteHeader(http.StatusOK)
}
func HandlerSetVar(w http.ResponseWriter, r *http.Request) {
c := NewContext(r)
c.Values.Set("password", "z1on")
}
func HandlerErr(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
// TestRouteOk tests that the route is correctly handled, and the URL parameters
// are added to the Context.
func TestRouteOk(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
c := NewContext(r)
lastNameParam := c.Params.Get("last")
firstNameParam := c.Params.Get("first")
if lastNameParam != "anderson" {
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
}
if firstNameParam != "thomas" {
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
}
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
}
// TestFilter tests that a route is filtered prior to handling
func TestRouteFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.Filter(HandlerSetVar)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
c := NewContext(r)
password := c.Values.Get("password")
if password != "z1on" {
t.Errorf("session variable set to [%s]; want [%s]", password, "z1on")
}
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
}
// TestFilterHalt tests that a route is filtered prior to handling, and then
// halts execution (by writing to the response).
func TestRouteFilterHalt(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.Filter(HandlerErr)
mux.Get("/person/:last/:first", HandlerOk)
mux.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
}
if w.Body.String() == "hello world" {
t.Errorf("Body set to [%s]; want empty", w.Body.String())
}
}
// TestParam tests the Parameter filter, and ensures the filter is only
// executed when the specified Parameter exists.
func TestParam(t *testing.T) {
// in the first test scenario, the Parameter filter should not
// be triggered because the "codename" variab does not exist
r, _ := http.NewRequest("GET", "/neo", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.Filter(HandlerSetVar)
mux.FilterParam("codename", HandlerErr)
mux.Get("/:nickname", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() != "hello world" {
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
}
// in this second scenario, the Parameter filter SHOULD fire, and should
// halt the request
w = httptest.NewRecorder()
mux = NewRouter()
mux.Filter(HandlerSetVar)
mux.FilterParam("codename", HandlerErr)
mux.Get("/:codename", HandlerOk)
mux.ServeHTTP(w, r)
if w.Body.String() == "hello world" {
t.Errorf("Body set to [%s]; want empty", w.Body.String())
}
if w.Code != 400 {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
}
}
/*
// TestTemplate tests template rendering
func TestTemplate(t *testing.T) {
w := httptest.NewRecorder()
tmpl, _ := template.New("template.html").Parse("<html><head><title>{{ .Title }}</title><body>{{ .Name }}</body></html>")
mux := NewRouter()
mux.Template(tmpl)
mux.Set("Title", "Matrix")
mux.ExecuteTemplate(w, "template.html", map[string]interface{}{ "Name" : "Morpheus" })
if w.Body.String() != "<html><head><title>Matrix</title><body>Morpheus</body></html>" {
t.Errorf("template not rendered correctly [%s]", w.Body.String())
}
}
*/
// TestNotFound tests that a 404 code is returned in the
// response if no route matches the request url.
func TestNotFound(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.ServeHTTP(w, r)
if w.Code != http.StatusNotFound {
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusNotFound)
}
}
// Benchmark_Routes runs a benchmark against our custom Mux using the
// default settings.
func Benchmark_Routes(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := NewRouter()
mux.Get("/person/:last/:first", HandlerOk)
for i := 0; i < b.N; i++ {
mux.ServeHTTP(w, r)
}
}
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
// We use this to determine performance impact of our library, when compared
// to the out-of-the-box Mux provided by Go.
func Benchmark_ServeMux(b *testing.B) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
mux := http.NewServeMux()
mux.HandleFunc("/", HandlerOk)
for i := 0; i < b.N; i++ {
r.URL.Query().Get("learn")
mux.ServeHTTP(w, r)
}
}

View File

@ -0,0 +1,42 @@
package routes
import (
"bufio"
"net"
"net/http"
)
// ResponseWriter is a wrapper for the http.ResponseWriter to track if
// response was written to.
type responseWriter struct {
Router *Router
writer http.ResponseWriter
started bool
status int
}
// Header returns the header map that will be sent by WriteHeader.
func (w *responseWriter) Header() http.Header {
return w.writer.Header()
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true
func (w *responseWriter) Write(p []byte) (int, error) {
w.started = true
return w.writer.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true
func (w *responseWriter) WriteHeader(code int) {
w.status = code
w.started = true
w.writer.WriteHeader(code)
}
// The Hijacker interface is implemented by ResponseWriters that allow an
// HTTP handler to take over the connection.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.writer.(http.Hijacker).Hijack()
}

View File

@ -0,0 +1,84 @@
package user
import (
"net/url"
"github.com/drone/routes/exp/context"
)
// Key used to store the user in the session
const userKey = "_user"
// User represents a user of the application.
type User struct {
Id string // the unique permanent ID of the user.
Name string // the human-readable ID of the user.
Email string
Photo string
FederatedIdentity string
FederatedProvider string
// additional, custom Attributes
Attrs map[string]string
}
// Decode will create a user from a URL Query string.
func Decode(v string) *User {
values, err := url.ParseQuery(v)
if err != nil {
return nil
}
attrs := map[string]string{}
for key, _ := range values {
attrs[key]=values.Get(key)
}
return &User {
Id : values.Get("id"),
Name : values.Get("name"),
Email : values.Get("email"),
Photo : values.Get("photo"),
Attrs : attrs,
}
}
// Encode will encode a user as a URL query string.
func (u *User) Encode() string {
values := url.Values{}
// add custom attributes
if u.Attrs != nil {
for key, val := range u.Attrs {
values.Set(key, val)
}
}
values.Set("id", u.Id)
values.Set("name", u.Name)
values.Set("email", u.Email)
values.Set("photo", u.Photo)
return values.Encode()
}
// Current returns the currently logged-in user, or nil if the user is not
// signed in.
func Current(c *context.Context) *User {
v := c.Values.Get(userKey)
if v == nil {
return nil
}
u, ok := v.(*User)
if !ok {
return nil
}
return u
}
// Set sets the currently logged-in user. This is typically used by middleware
// that handles user authentication.
func Set(c *context.Context, u *User) {
c.Values.Set(userKey, u)
}

View File

@ -0,0 +1,317 @@
package routes
import (
"encoding/json"
"encoding/xml"
"io/ioutil"
"net/http"
"net/url"
"path/filepath"
"regexp"
"strconv"
"strings"
)
const (
CONNECT = "CONNECT"
DELETE = "DELETE"
GET = "GET"
HEAD = "HEAD"
OPTIONS = "OPTIONS"
PATCH = "PATCH"
POST = "POST"
PUT = "PUT"
TRACE = "TRACE"
)
//commonly used mime-types
const (
applicationJson = "application/json"
applicationXml = "application/xml"
textXml = "text/xml"
)
type route struct {
method string
regex *regexp.Regexp
params map[int]string
handler http.HandlerFunc
}
type RouteMux struct {
routes []*route
filters []http.HandlerFunc
}
func New() *RouteMux {
return &RouteMux{}
}
// Get adds a new Route for GET requests.
func (m *RouteMux) Get(pattern string, handler http.HandlerFunc) {
m.AddRoute(GET, pattern, handler)
}
// Put adds a new Route for PUT requests.
func (m *RouteMux) Put(pattern string, handler http.HandlerFunc) {
m.AddRoute(PUT, pattern, handler)
}
// Del adds a new Route for DELETE requests.
func (m *RouteMux) Del(pattern string, handler http.HandlerFunc) {
m.AddRoute(DELETE, pattern, handler)
}
// Patch adds a new Route for PATCH requests.
func (m *RouteMux) Patch(pattern string, handler http.HandlerFunc) {
m.AddRoute(PATCH, pattern, handler)
}
// Post adds a new Route for POST requests.
func (m *RouteMux) Post(pattern string, handler http.HandlerFunc) {
m.AddRoute(POST, pattern, handler)
}
// Adds a new Route for Static http requests. Serves
// static files from the specified directory
func (m *RouteMux) Static(pattern string, dir string) {
//append a regex to the param to match everything
// that comes after the prefix
pattern = pattern + "(.+)"
m.AddRoute(GET, pattern, func(w http.ResponseWriter, r *http.Request) {
path := filepath.Clean(r.URL.Path)
path = filepath.Join(dir, path)
http.ServeFile(w, r, path)
})
}
// Adds a new Route to the Handler
func (m *RouteMux) AddRoute(method string, pattern string, handler http.HandlerFunc) {
//split the url into sections
parts := strings.Split(pattern, "/")
//find params that start with ":"
//replace with regular expressions
j := 0
params := make(map[int]string)
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "([^/]+)"
//a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
}
params[j] = part
parts[i] = expr
j++
}
}
//recreate the url pattern, with parameters replaced
//by regular expressions. then compile the regex
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
return
}
//now create the Route
route := &route{}
route.method = method
route.regex = regex
route.handler = handler
route.params = params
//and finally append to the list of Routes
m.routes = append(m.routes, route)
}
// Filter adds the middleware filter.
func (m *RouteMux) Filter(filter http.HandlerFunc) {
m.filters = append(m.filters, filter)
}
// FilterParam adds the middleware filter iff the REST URL parameter exists.
func (m *RouteMux) FilterParam(param string, filter http.HandlerFunc) {
if !strings.HasPrefix(param,":") {
param = ":"+param
}
m.Filter(func(w http.ResponseWriter, r *http.Request) {
p := r.URL.Query().Get(param)
if len(p) > 0 { filter(w, r) }
})
}
// Required by http.Handler interface. This method is invoked by the
// http server and will handle all page routing
func (m *RouteMux) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
requestPath := r.URL.Path
//wrap the response writer, in our custom interface
w := &responseWriter{writer: rw}
//find a matching Route
for _, route := range m.routes {
//if the methods don't match, skip this handler
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
if r.Method != route.method {
continue
}
//check if Route pattern matches url
if !route.regex.MatchString(requestPath) {
continue
}
//get submatches (params)
matches := route.regex.FindStringSubmatch(requestPath)
//double check that the Route matches the URL pattern.
if len(matches[0]) != len(requestPath) {
continue
}
if len(route.params) > 0 {
//add url parameters to the query param map
values := r.URL.Query()
for i, match := range matches[1:] {
values.Add(route.params[i], match)
}
//reassemble query params and add to RawQuery
r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery
//r.URL.RawQuery = url.Values(values).Encode()
}
//execute middleware filters
for _, filter := range m.filters {
filter(w, r)
if w.started {
return
}
}
//Invoke the request handler
route.handler(w, r)
break
}
//if no matches to url, throw a not found exception
if w.started == false {
http.NotFound(w, r)
}
}
// -----------------------------------------------------------------------------
// Simple wrapper around a ResponseWriter
// responseWriter is a wrapper for the http.ResponseWriter
// to track if response was written to. It also allows us
// to automatically set certain headers, such as Content-Type,
// Access-Control-Allow-Origin, etc.
type responseWriter struct {
writer http.ResponseWriter
started bool
status int
}
// Header returns the header map that will be sent by WriteHeader.
func (w *responseWriter) Header() http.Header {
return w.writer.Header()
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true
func (w *responseWriter) Write(p []byte) (int, error) {
w.started = true
return w.writer.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true
func (w *responseWriter) WriteHeader(code int) {
w.status = code
w.started = true
w.writer.WriteHeader(code)
}
// -----------------------------------------------------------------------------
// Below are helper functions to replace boilerplate
// code that serializes resources and writes to the
// http response.
// ServeJson replies to the request with a JSON
// representation of resource v.
func ServeJson(w http.ResponseWriter, v interface{}) {
content, err := json.MarshalIndent(v, "", " ")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
w.Header().Set("Content-Type", applicationJson)
w.Write(content)
}
// ReadJson will parses the JSON-encoded data in the http
// Request object and stores the result in the value
// pointed to by v.
func ReadJson(r *http.Request, v interface{}) error {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
return json.Unmarshal(body, v)
}
// ServeXml replies to the request with an XML
// representation of resource v.
func ServeXml(w http.ResponseWriter, v interface{}) {
content, err := xml.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
w.Write(content)
}
// ReadXml will parses the XML-encoded data in the http
// Request object and stores the result in the value
// pointed to by v.
func ReadXml(r *http.Request, v interface{}) error {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
return xml.Unmarshal(body, v)
}
// ServeFormatted replies to the request with
// a formatted representation of resource v, in the
// format requested by the client specified in the
// Accept header.
func ServeFormatted(w http.ResponseWriter, r *http.Request, v interface{}) {
accept := r.Header.Get("Accept")
switch accept {
case applicationJson:
ServeJson(w, v)
case applicationXml, textXml:
ServeXml(w, v)
default:
ServeJson(w, v)
}
return
}

View File

@ -0,0 +1,193 @@
package routes
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
)
var HandlerOk = func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
w.WriteHeader(http.StatusOK)
}
var HandlerErr = func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
var FilterUser = func(w http.ResponseWriter, r *http.Request) {
if r.URL.User == nil || r.URL.User.Username() != "admin" {
http.Error(w, "", http.StatusUnauthorized)
}
}
var FilterId = func(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get(":id")
if id == "admin" {
http.Error(w, "", http.StatusUnauthorized)
}
}
// TestAuthOk tests that an Auth handler will append the
// username and password to to the request URL, and will
// continue processing the request by invoking the handler.
func TestRouteOk(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler := new(RouteMux)
handler.Get("/person/:last/:first", HandlerOk)
handler.ServeHTTP(w, r)
lastNameParam := r.URL.Query().Get(":last")
firstNameParam := r.URL.Query().Get(":first")
learnParam := r.URL.Query().Get("learn")
if lastNameParam != "anderson" {
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
}
if firstNameParam != "thomas" {
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
}
if learnParam != "kungfu" {
t.Errorf("url param set to [%s]; want [%s]", learnParam, "kungfu")
}
}
// TestNotFound tests that a 404 code is returned in the
// response if no route matches the request url.
func TestNotFound(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler := new(RouteMux)
handler.ServeHTTP(w, r)
if w.Code != http.StatusNotFound {
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound)
}
}
// TestStatic tests the ability to serve static
// content from the filesystem
func TestStatic(t *testing.T) {
r, _ := http.NewRequest("GET", "/routes_test.go", nil)
w := httptest.NewRecorder()
pwd, _ := os.Getwd()
handler := new(RouteMux)
handler.Static("/", pwd)
handler.ServeHTTP(w, r)
testFile, _ := ioutil.ReadFile(pwd + "/routes_test.go")
if w.Body.String() != string(testFile) {
t.Errorf("handler.Static failed to serve file")
}
}
// TestFilter tests the ability to apply middleware function
// to filter all routes
func TestFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler := new(RouteMux)
handler.Get("/", HandlerOk)
handler.Filter(FilterUser)
handler.ServeHTTP(w, r)
if w.Code != http.StatusUnauthorized {
t.Errorf("Did not apply Filter. Code set to [%v]; want [%v]", w.Code, http.StatusUnauthorized)
}
r, _ = http.NewRequest("GET", "/", nil)
r.URL.User = url.User("admin")
w = httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusOK)
}
}
// TestFilterParam tests the ability to apply middleware
// function to filter all routes with specified parameter
// in the REST url
func TestFilterParam(t *testing.T) {
r, _ := http.NewRequest("GET", "/:id", nil)
w := httptest.NewRecorder()
// first test that the param filter does not trigger
handler := new(RouteMux)
handler.Get("/", HandlerOk)
handler.Get("/:id", HandlerOk)
handler.FilterParam("id", FilterId)
handler.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusOK)
}
// now test the param filter does trigger
r, _ = http.NewRequest("GET", "/admin", nil)
w = httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != http.StatusUnauthorized {
t.Errorf("Did not apply Param Filter. Code set to [%v]; want [%v]", w.Code, http.StatusUnauthorized)
}
}
// Benchmark_RoutedHandler runs a benchmark against
// the RouteMux using the default settings.
func Benchmark_RoutedHandler(b *testing.B) {
handler := new(RouteMux)
handler.Get("/", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_RoutedHandler runs a benchmark against
// the RouteMux using the default settings with REST
// URL params.
func Benchmark_RoutedHandlerParams(b *testing.B) {
handler := new(RouteMux)
handler.Get("/:user", HandlerOk)
for i := 0; i < b.N; i++ {
r, _ := http.NewRequest("GET", "/admin", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}
// Benchmark_ServeMux runs a benchmark against
// the ServeMux Go function. We use this to determine
// performance impact of our library, when compared
// to the out-of-the-box Mux provided by Go.
func Benchmark_ServeMux(b *testing.B) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
mux := http.NewServeMux()
mux.HandleFunc("/", HandlerOk)
for i := 0; i < b.N; i++ {
mux.ServeHTTP(w, r)
}
}

View File

@ -0,0 +1,175 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

View File

@ -0,0 +1,44 @@
Redigo
======
Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
Features
-------
* A [Print-like](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Executing_Commands) API with support for all Redis commands.
* [Pipelining](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Pipelining), including pipelined transactions.
* [Publish/Subscribe](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Publish_and_Subscribe).
* [Connection pooling](http://godoc.org/github.com/garyburd/redigo/redis#Pool).
* [Script helper type](http://godoc.org/github.com/garyburd/redigo/redis#Script) with optimistic use of EVALSHA.
* [Helper functions](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Reply_Helpers) for working with command replies.
Documentation
-------------
- [API Reference](http://godoc.org/github.com/garyburd/redigo/redis)
- [FAQ](https://github.com/garyburd/redigo/wiki/FAQ)
Installation
------------
Install Redigo using the "go get" command:
go get github.com/garyburd/redigo/redis
The Go distribution is Redigo's only dependency.
Contributing
------------
Contributions are welcome.
Before writing code, send mail to gary@beagledreams.com to discuss what you
plan to do. This gives me a chance to validate the design, avoid duplication of
effort and ensure that the changes fit the goals of the project. Do not start
the discussion with a pull request.
License
-------
Redigo is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.html).

View File

@ -0,0 +1,54 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package internal // import "github.com/garyburd/redigo/internal"
import (
"strings"
)
const (
WatchState = 1 << iota
MultiState
SubscribeState
MonitorState
)
type CommandInfo struct {
Set, Clear int
}
var commandInfos = map[string]CommandInfo{
"WATCH": {Set: WatchState},
"UNWATCH": {Clear: WatchState},
"MULTI": {Set: MultiState},
"EXEC": {Clear: WatchState | MultiState},
"DISCARD": {Clear: WatchState | MultiState},
"PSUBSCRIBE": {Set: SubscribeState},
"SUBSCRIBE": {Set: SubscribeState},
"MONITOR": {Set: MonitorState},
}
func init() {
for n, ci := range commandInfos {
commandInfos[strings.ToLower(n)] = ci
}
}
func LookupCommandInfo(commandName string) CommandInfo {
if ci, ok := commandInfos[commandName]; ok {
return ci
}
return commandInfos[strings.ToUpper(commandName)]
}

View File

@ -0,0 +1,27 @@
package internal
import "testing"
func TestLookupCommandInfo(t *testing.T) {
for _, n := range []string{"watch", "WATCH", "wAtch"} {
if LookupCommandInfo(n) == (CommandInfo{}) {
t.Errorf("LookupCommandInfo(%q) = CommandInfo{}, expected non-zero value", n)
}
}
}
func benchmarkLookupCommandInfo(b *testing.B, names ...string) {
for i := 0; i < b.N; i++ {
for _, c := range names {
LookupCommandInfo(c)
}
}
}
func BenchmarkLookupCommandInfoCorrectCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "watch", "WATCH", "monitor", "MONITOR")
}
func BenchmarkLookupCommandInfoMixedCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "wAtch", "WeTCH", "monItor", "MONiTOR")
}

View File

@ -0,0 +1,65 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redistest contains utilities for writing Redigo tests.
package redistest
import (
"errors"
"time"
"github.com/garyburd/redigo/redis"
)
type testConn struct {
redis.Conn
}
func (t testConn) Close() error {
_, err := t.Conn.Do("SELECT", "9")
if err != nil {
return nil
}
_, err = t.Conn.Do("FLUSHDB")
if err != nil {
return err
}
return t.Conn.Close()
}
// Dial dials the local Redis server and selects database 9. To prevent
// stomping on real data, DialTestDB fails if database 9 contains data. The
// returned connection flushes database 9 on close.
func Dial() (redis.Conn, error) {
c, err := redis.DialTimeout("tcp", ":6379", 0, 1*time.Second, 1*time.Second)
if err != nil {
return nil, err
}
_, err = c.Do("SELECT", "9")
if err != nil {
return nil, err
}
n, err := redis.Int(c.Do("DBSIZE"))
if err != nil {
return nil, err
}
if n != 0 {
return nil, errors.New("database #9 is not empty, test can not continue")
}
return testConn{c}, nil
}

View File

@ -0,0 +1,455 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
)
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
pending int
err error
conn net.Conn
// Read
readTimeout time.Duration
br *bufio.Reader
// Write
writeTimeout time.Duration
bw *bufio.Writer
// Scratch space for formatting argument length.
// '*' or '$', length, "\r\n"
lenScratch [32]byte
// Scratch space for formatting integers and floats.
numScratch [40]byte
}
// Dial connects to the Redis server at the given network and address.
func Dial(network, address string) (Conn, error) {
dialer := xDialer{}
return dialer.Dial(network, address)
}
// DialTimeout acts like Dial but takes timeouts for establishing the
// connection to the server, writing a command and reading a reply.
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
netDialer := net.Dialer{Timeout: connectTimeout}
dialer := xDialer{
NetDial: netDialer.Dial,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
}
return dialer.Dial(network, address)
}
// A Dialer specifies options for connecting to a Redis server.
type xDialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, then net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// ReadTimeout specifies the timeout for reading a single command
// reply. If ReadTimeout is zero, then no timeout is used.
ReadTimeout time.Duration
// WriteTimeout specifies the timeout for writing a single command. If
// WriteTimeout is zero, then no timeout is used.
WriteTimeout time.Duration
}
// Dial connects to the Redis server at address on the named network.
func (d *xDialer) Dial(network, address string) (Conn, error) {
dial := d.NetDial
if dial == nil {
dial = net.Dial
}
netConn, err := dial(network, address)
if err != nil {
return nil, err
}
return &conn{
conn: netConn,
bw: bufio.NewWriter(netConn),
br: bufio.NewReader(netConn),
readTimeout: d.ReadTimeout,
writeTimeout: d.WriteTimeout,
}, nil
}
// NewConn returns a new Redigo connection for the given net connection.
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
return &conn{
conn: netConn,
bw: bufio.NewWriter(netConn),
br: bufio.NewReader(netConn),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = errors.New("redigo: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = err
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) writeLen(prefix byte, n int) error {
c.lenScratch[len(c.lenScratch)-1] = '\n'
c.lenScratch[len(c.lenScratch)-2] = '\r'
i := len(c.lenScratch) - 3
for {
c.lenScratch[i] = byte('0' + n%10)
i -= 1
n = n / 10
if n == 0 {
break
}
}
c.lenScratch[i] = prefix
_, err := c.bw.Write(c.lenScratch[i:])
return err
}
func (c *conn) writeString(s string) error {
c.writeLen('$', len(s))
c.bw.WriteString(s)
_, err := c.bw.WriteString("\r\n")
return err
}
func (c *conn) writeBytes(p []byte) error {
c.writeLen('$', len(p))
c.bw.Write(p)
_, err := c.bw.WriteString("\r\n")
return err
}
func (c *conn) writeInt64(n int64) error {
return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
}
func (c *conn) writeFloat64(n float64) error {
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
}
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
c.writeLen('*', 1+len(args))
err = c.writeString(cmd)
for _, arg := range args {
if err != nil {
break
}
switch arg := arg.(type) {
case string:
err = c.writeString(arg)
case []byte:
err = c.writeBytes(arg)
case int:
err = c.writeInt64(int64(arg))
case int64:
err = c.writeInt64(arg)
case float64:
err = c.writeFloat64(arg)
case bool:
if arg {
err = c.writeString("1")
} else {
err = c.writeString("0")
}
case nil:
err = c.writeString("")
default:
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
err = c.writeBytes(buf.Bytes())
}
}
return err
}
type protocolError string
func (pe protocolError) Error() string {
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
func (c *conn) readLine() ([]byte, error) {
p, err := c.br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
return nil, protocolError("long response line")
}
if err != nil {
return nil, err
}
i := len(p) - 2
if i < 0 || p[i] != '\r' {
return nil, protocolError("bad response line terminator")
}
return p[:i], nil
}
// parseLen parses bulk string and array lengths.
func parseLen(p []byte) (int, error) {
if len(p) == 0 {
return -1, protocolError("malformed length")
}
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
// handle $-1 and $-1 null replies.
return -1, nil
}
var n int
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return -1, protocolError("illegal bytes in length")
}
n += int(b - '0')
}
return n, nil
}
// parseInt parses an integer reply.
func parseInt(p []byte) (interface{}, error) {
if len(p) == 0 {
return 0, protocolError("malformed integer")
}
var negate bool
if p[0] == '-' {
negate = true
p = p[1:]
if len(p) == 0 {
return 0, protocolError("malformed integer")
}
}
var n int64
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return 0, protocolError("illegal bytes in length")
}
n += int64(b - '0')
}
if negate {
n = -n
}
return n, nil
}
var (
okReply interface{} = "OK"
pongReply interface{} = "PONG"
)
func (c *conn) readReply() (interface{}, error) {
line, err := c.readLine()
if err != nil {
return nil, err
}
if len(line) == 0 {
return nil, protocolError("short response line")
}
switch line[0] {
case '+':
switch {
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
// Avoid allocation for frequent "+OK" response.
return okReply, nil
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
// Avoid allocation in PING command benchmarks :)
return pongReply, nil
default:
return string(line[1:]), nil
}
case '-':
return Error(string(line[1:])), nil
case ':':
return parseInt(line[1:])
case '$':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(c.br, p)
if err != nil {
return nil, err
}
if line, err := c.readLine(); err != nil {
return nil, err
} else if len(line) != 0 {
return nil, protocolError("bad bulk string format")
}
return p, nil
case '*':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
r := make([]interface{}, n)
for i := range r {
r[i], err = c.readReply()
if err != nil {
return nil, err
}
}
return r, nil
}
return nil, protocolError("unexpected response line")
}
func (c *conn) Send(cmd string, args ...interface{}) error {
c.mu.Lock()
c.pending += 1
c.mu.Unlock()
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err := c.writeCommand(cmd, args); err != nil {
return c.fatal(err)
}
return nil
}
func (c *conn) Flush() error {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err := c.bw.Flush(); err != nil {
return c.fatal(err)
}
return nil
}
func (c *conn) Receive() (reply interface{}, err error) {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err)
}
// When using pub/sub, the number of receives can be greater than the
// number of sends. To enable normal use of the connection after
// unsubscribing from all channels, we do not decrement pending to a
// negative value.
//
// The pending field is decremented after the reply is read to handle the
// case where Receive is called before Send.
c.mu.Lock()
if c.pending > 0 {
c.pending -= 1
}
c.mu.Unlock()
if err, ok := reply.(Error); ok {
return nil, err
}
return
}
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
c.pending = 0
c.mu.Unlock()
if cmd == "" && pending == 0 {
return nil, nil
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if cmd != "" {
c.writeCommand(cmd, args)
}
if err := c.bw.Flush(); err != nil {
return nil, c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if cmd == "" {
reply := make([]interface{}, pending)
for i := range reply {
r, e := c.readReply()
if e != nil {
return nil, c.fatal(e)
}
reply[i] = r
}
return reply, nil
}
var err error
var reply interface{}
for i := 0; i <= pending; i++ {
var e error
if reply, e = c.readReply(); e != nil {
return nil, c.fatal(e)
}
if e, ok := reply.(Error); ok && err == nil {
err = e
}
}
return reply, err
}

View File

@ -0,0 +1,542 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"bufio"
"bytes"
"math"
"net"
"reflect"
"strings"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
var writeTests = []struct {
args []interface{}
expected string
}{
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", byte(100)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", 100},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", int64(math.MinInt64)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n",
},
{
[]interface{}{"SET", "key", float64(1349673917.939762)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n",
},
{
[]interface{}{"SET", "key", ""},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"SET", "key", nil},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"ECHO", true, false},
"*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n",
},
}
func TestWrite(t *testing.T) {
for _, tt := range writeTests {
var buf bytes.Buffer
rw := bufio.ReadWriter{Writer: bufio.NewWriter(&buf)}
c := redis.NewConnBufio(rw)
err := c.Send(tt.args[0].(string), tt.args[1:]...)
if err != nil {
t.Errorf("Send(%v) returned error %v", tt.args, err)
continue
}
rw.Flush()
actual := buf.String()
if actual != tt.expected {
t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
}
}
}
var errorSentinel = &struct{}{}
var readTests = []struct {
reply string
expected interface{}
}{
{
"+OK\r\n",
"OK",
},
{
"+PONG\r\n",
"PONG",
},
{
"@OK\r\n",
errorSentinel,
},
{
"$6\r\nfoobar\r\n",
[]byte("foobar"),
},
{
"$-1\r\n",
nil,
},
{
":1\r\n",
int64(1),
},
{
":-2\r\n",
int64(-2),
},
{
"*0\r\n",
[]interface{}{},
},
{
"*-1\r\n",
nil,
},
{
"*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n",
[]interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")},
},
{
"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n",
[]interface{}{[]byte("foo"), nil, []byte("bar")},
},
{
// "x" is not a valid length
"$x\r\nfoobar\r\n",
errorSentinel,
},
{
// -2 is not a valid length
"$-2\r\n",
errorSentinel,
},
{
// "x" is not a valid integer
":x\r\n",
errorSentinel,
},
{
// missing \r\n following value
"$6\r\nfoobar",
errorSentinel,
},
{
// short value
"$6\r\nxx",
errorSentinel,
},
{
// long value
"$6\r\nfoobarx\r\n",
errorSentinel,
},
}
func TestRead(t *testing.T) {
for _, tt := range readTests {
rw := bufio.ReadWriter{
Reader: bufio.NewReader(strings.NewReader(tt.reply)),
Writer: bufio.NewWriter(nil), // writer need to support Flush
}
c := redis.NewConnBufio(rw)
actual, err := c.Receive()
if tt.expected == errorSentinel {
if err == nil {
t.Errorf("Receive(%q) did not return expected error", tt.reply)
}
} else {
if err != nil {
t.Errorf("Receive(%q) returned error %v", tt.reply, err)
continue
}
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected)
}
}
}
}
var testCommands = []struct {
args []interface{}
expected interface{}
}{
{
[]interface{}{"PING"},
"PONG",
},
{
[]interface{}{"SET", "foo", "bar"},
"OK",
},
{
[]interface{}{"GET", "foo"},
[]byte("bar"),
},
{
[]interface{}{"GET", "nokey"},
nil,
},
{
[]interface{}{"MGET", "nokey", "foo"},
[]interface{}{nil, []byte("bar")},
},
{
[]interface{}{"INCR", "mycounter"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "foo"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "bar"},
int64(2),
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
[]interface{}{[]byte("bar"), []byte("foo")},
},
{
[]interface{}{"MULTI"},
"OK",
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
"QUEUED",
},
{
[]interface{}{"PING"},
"QUEUED",
},
{
[]interface{}{"EXEC"},
[]interface{}{
[]interface{}{[]byte("bar"), []byte("foo")},
"PONG",
},
},
}
func TestDoCommands(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...)
if err != nil {
t.Errorf("Do(%v) returned error %v", cmd.args, err)
continue
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestPipelineCommands(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
if err := c.Flush(); err != nil {
t.Errorf("Flush() returned error %v", err)
}
for _, cmd := range testCommands {
actual, err := c.Receive()
if err != nil {
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestBlankCommmand(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
reply, err := redis.Values(c.Do(""))
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
if len(reply) != len(testCommands) {
t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands))
}
for i, cmd := range testCommands {
actual := reply[i]
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestRecvBeforeSend(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
c.Receive()
close(done)
}()
time.Sleep(time.Millisecond)
c.Send("PING")
c.Flush()
<-done
_, err = c.Do("")
if err != nil {
t.Fatalf("error=%v", err)
}
}
func TestError(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
c.Do("SET", "key", "val")
_, err = c.Do("HSET", "key", "fld", "val")
if err == nil {
t.Errorf("Expected err for HSET on string key.")
}
if c.Err() != nil {
t.Errorf("Conn has Err()=%v, expect nil", c.Err())
}
_, err = c.Do("SET", "key", "val")
if err != nil {
t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err)
}
}
func TestReadDeadline(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen returned %v", err)
}
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
go func() {
time.Sleep(time.Second)
c.Write([]byte("+OK\r\n"))
c.Close()
}()
}
}()
c1, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
if err != nil {
t.Fatalf("redis.Dial returned %v", err)
}
defer c1.Close()
_, err = c1.Do("PING")
if err == nil {
t.Fatalf("c1.Do() returned nil, expect error")
}
if c1.Err() == nil {
t.Fatalf("c1.Err() = nil, expect error")
}
c2, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
if err != nil {
t.Fatalf("redis.Dial returned %v", err)
}
defer c2.Close()
c2.Send("PING")
c2.Flush()
_, err = c2.Receive()
if err == nil {
t.Fatalf("c2.Receive() returned nil, expect error")
}
if c2.Err() == nil {
t.Fatalf("c2.Err() = nil, expect error")
}
}
// Connect to local instance of Redis running on the default port.
func ExampleDial(x int) {
c, err := redis.Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
}
// TextExecError tests handling of errors in a transaction. See
// http://redis.io/topics/transactions for information on how Redis handles
// errors in a transaction.
func TestExecError(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// Execute commands that fail before EXEC is called.
c.Do("ZADD", "k0", 0, 0)
c.Send("MULTI")
c.Send("NOTACOMMAND", "k0", 0, 0)
c.Send("ZINCRBY", "k0", 0, 0)
v, err := c.Do("EXEC")
if err == nil {
t.Fatalf("EXEC returned values %v, expected error", v)
}
// Execute commands that fail after EXEC is called. The first command
// returns an error.
c.Do("ZADD", "k1", 0, 0)
c.Send("MULTI")
c.Send("HSET", "k1", 0, 0)
c.Send("ZINCRBY", "k1", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err := redis.Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].(error); !ok {
t.Fatalf("first result is type %T, expected error", vs[0])
}
if _, ok := vs[1].([]byte); !ok {
t.Fatalf("second result is type %T, expected []byte", vs[2])
}
// Execute commands that fail after EXEC is called. The second command
// returns an error.
c.Do("ZADD", "k2", 0, 0)
c.Send("MULTI")
c.Send("ZINCRBY", "k2", 0, 0)
c.Send("HSET", "k2", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err = redis.Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].([]byte); !ok {
t.Fatalf("first result is type %T, expected []byte", vs[0])
}
if _, ok := vs[1].(error); !ok {
t.Fatalf("second result is type %T, expected error", vs[2])
}
}
func BenchmarkDoEmpty(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do(""); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkDoPing(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
}
}

View File

@ -0,0 +1,169 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redis is a client for the Redis database.
//
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
// documentation about this package.
//
// Connections
//
// The Conn interface is the primary interface for working with Redis.
// Applications create connections by calling the Dial, DialWithTimeout or
// NewConn functions. In the future, functions will be added for creating
// sharded and other types of connections.
//
// The application must call the connection Close method when the application
// is done with the connection.
//
// Executing Commands
//
// The Conn interface has a generic method for executing Redis commands:
//
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
//
// The Redis command reference (http://redis.io/commands) lists the available
// commands. An example of using the Redis APPEND command is:
//
// n, err := conn.Do("APPEND", "key", "value")
//
// The Do method converts command arguments to binary strings for transmission
// to the server as follows:
//
// Go Type Conversion
// []byte Sent as is
// string Sent as is
// int, int64 strconv.FormatInt(v)
// float64 strconv.FormatFloat(v, 'g', -1, 64)
// bool true -> "1", false -> "0"
// nil ""
// all other types fmt.Print(v)
//
// Redis command reply types are represented using the following Go types:
//
// Redis type Go type
// error redis.Error
// integer int64
// simple string string
// bulk string []byte or nil if value not present.
// array []interface{} or nil if value not present.
//
// Use type assertions or the reply helper functions to convert from
// interface{} to the specific Go type for the command result.
//
// Pipelining
//
// Connections support pipelining using the Send, Flush and Receive methods.
//
// Send(commandName string, args ...interface{}) error
// Flush() error
// Receive() (reply interface{}, err error)
//
// Send writes the command to the connection's output buffer. Flush flushes the
// connection's output buffer to the server. Receive reads a single reply from
// the server. The following example shows a simple pipeline.
//
// c.Send("SET", "foo", "bar")
// c.Send("GET", "foo")
// c.Flush()
// c.Receive() // reply from SET
// v, err = c.Receive() // reply from GET
//
// The Do method combines the functionality of the Send, Flush and Receive
// methods. The Do method starts by writing the command and flushing the output
// buffer. Next, the Do method receives all pending replies including the reply
// for the command just sent by Do. If any of the received replies is an error,
// then Do returns the error. If there are no errors, then Do returns the last
// reply. If the command argument to the Do method is "", then the Do method
// will flush the output buffer and receive pending replies without sending a
// command.
//
// Use the Send and Do methods to implement pipelined transactions.
//
// c.Send("MULTI")
// c.Send("INCR", "foo")
// c.Send("INCR", "bar")
// r, err := c.Do("EXEC")
// fmt.Println(r) // prints [1, 1]
//
// Concurrency
//
// Connections do not support concurrent calls to the write methods (Send,
// Flush) or concurrent calls to the read method (Receive). Connections do
// allow a concurrent reader and writer.
//
// Because the Do method combines the functionality of Send, Flush and Receive,
// the Do method cannot be called concurrently with the other methods.
//
// For full concurrent access to Redis, use the thread-safe Pool to get and
// release connections from within a goroutine.
//
// Publish and Subscribe
//
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
//
// c.Send("SUBSCRIBE", "example")
// c.Flush()
// for {
// reply, err := c.Receive()
// if err != nil {
// return err
// }
// // process pushed message
// }
//
// The PubSubConn type wraps a Conn with convenience methods for implementing
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
// send and flush a subscription management command. The receive method
// converts a pushed message to convenient types for use in a type switch.
//
// psc := redis.PubSubConn{c}
// psc.Subscribe("example")
// for {
// switch v := psc.Receive().(type) {
// case redis.Message:
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
// case redis.Subscription:
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
// case error:
// return v
// }
// }
//
// Reply Helpers
//
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
// to a value of a specific type. To allow convenient wrapping of calls to the
// connection Do and Receive methods, the functions take a second argument of
// type error. If the error is non-nil, then the helper function returns the
// error. If the error is nil, the function converts the reply to the specified
// type:
//
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
// if err != nil {
// // handle error return from c.Do or type conversion error.
// }
//
// The Scan function converts elements of a array reply to Go types:
//
// var value1 int
// var value2 string
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
// if err != nil {
// // handle error
// }
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
// // handle error
// }
package redis // import "github.com/garyburd/redigo/redis"

View File

@ -0,0 +1,117 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"fmt"
"log"
)
// NewLoggingConn returns a logging wrapper around a connection.
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix}
}
type loggingConn struct {
Conn
logger *log.Logger
prefix string
}
func (c *loggingConn) Close() error {
err := c.Conn.Close()
var buf bytes.Buffer
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
c.logger.Output(2, buf.String())
return err
}
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
const chop = 32
switch v := v.(type) {
case []byte:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case string:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case []interface{}:
if len(v) == 0 {
buf.WriteString("[]")
} else {
sep := "["
fin := "]"
if len(v) > chop {
v = v[:chop]
fin = "...]"
}
for _, vv := range v {
buf.WriteString(sep)
c.printValue(buf, vv)
sep = ", "
}
buf.WriteString(fin)
}
default:
fmt.Fprint(buf, v)
}
}
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
if method != "Receive" {
buf.WriteString(commandName)
for _, arg := range args {
buf.WriteString(", ")
c.printValue(&buf, arg)
}
}
buf.WriteString(") -> (")
if method != "Send" {
c.printValue(&buf, reply)
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%v)", err)
c.logger.Output(3, buf.String())
}
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
reply, err := c.Conn.Do(commandName, args...)
c.print("Do", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err)
return err
}
func (c *loggingConn) Receive() (interface{}, error) {
reply, err := c.Conn.Receive()
c.print("Receive", "", nil, reply, err)
return reply, err
}

View File

@ -0,0 +1,389 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"container/list"
"crypto/rand"
"crypto/sha1"
"errors"
"io"
"strconv"
"sync"
"time"
"github.com/garyburd/redigo/internal"
)
var nowFunc = time.Now // for testing
// ErrPoolExhausted is returned from a pool connection method (Do, Send,
// Receive, Flush, Err) when the maximum number of database connections in the
// pool has been reached.
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
var (
errPoolClosed = errors.New("redigo: connection pool closed")
errConnClosed = errors.New("redigo: connection closed")
)
// Pool maintains a pool of connections. The application calls the Get method
// to get a connection from the pool and the connection's Close method to
// return the connection's resources to the pool.
//
// The following example shows how to use a pool in a web application. The
// application creates a pool at application startup and makes it available to
// request handlers using a global variable.
//
// func newPool(server, password string) *redis.Pool {
// return &redis.Pool{
// MaxIdle: 3,
// IdleTimeout: 240 * time.Second,
// Dial: func () (redis.Conn, error) {
// c, err := redis.Dial("tcp", server)
// if err != nil {
// return nil, err
// }
// if _, err := c.Do("AUTH", password); err != nil {
// c.Close()
// return nil, err
// }
// return c, err
// },
// TestOnBorrow: func(c redis.Conn, t time.Time) error {
// _, err := c.Do("PING")
// return err
// },
// }
// }
//
// var (
// pool *redis.Pool
// redisServer = flag.String("redisServer", ":6379", "")
// redisPassword = flag.String("redisPassword", "", "")
// )
//
// func main() {
// flag.Parse()
// pool = newPool(*redisServer, *redisPassword)
// ...
// }
//
// A request handler gets a connection from the pool and closes the connection
// when the handler is done:
//
// func serveHome(w http.ResponseWriter, r *http.Request) {
// conn := pool.Get()
// defer conn.Close()
// ....
// }
//
type Pool struct {
// Dial is an application supplied function for creating and configuring a
// connection
Dial func() (Conn, error)
// TestOnBorrow is an optional application supplied function for checking
// the health of an idle connection before the connection is used again by
// the application. Argument t is the time that the connection was returned
// to the pool. If the function returns an error, then the connection is
// closed.
TestOnBorrow func(c Conn, t time.Time) error
// Maximum number of idle connections in the pool.
MaxIdle int
// Maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
MaxActive int
// Close connections after remaining idle for this duration. If the value
// is zero, then idle connections are not closed. Applications should set
// the timeout to a value less than the server's timeout.
IdleTimeout time.Duration
// If Wait is true and the pool is at the MaxIdle limit, then Get() waits
// for a connection to be returned to the pool before returning.
Wait bool
// mu protects fields defined below.
mu sync.Mutex
cond *sync.Cond
closed bool
active int
// Stack of idleConn with most recently used at the front.
idle list.List
}
type idleConn struct {
c Conn
t time.Time
}
// NewPool creates a new pool. This function is deprecated. Applications should
// initialize the Pool fields directly as shown in example.
func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
return &Pool{Dial: newFn, MaxIdle: maxIdle}
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get() Conn {
c, err := p.get()
if err != nil {
return errorConnection{err}
}
return &pooledConnection{p: p, c: c}
}
// ActiveCount returns the number of active connections in the pool.
func (p *Pool) ActiveCount() int {
p.mu.Lock()
active := p.active
p.mu.Unlock()
return active
}
// Close releases the resources used by the pool.
func (p *Pool) Close() error {
p.mu.Lock()
idle := p.idle
p.idle.Init()
p.closed = true
p.active -= idle.Len()
if p.cond != nil {
p.cond.Broadcast()
}
p.mu.Unlock()
for e := idle.Front(); e != nil; e = e.Next() {
e.Value.(idleConn).c.Close()
}
return nil
}
// release decrements the active count and signals waiters. The caller must
// hold p.mu during the call.
func (p *Pool) release() {
p.active -= 1
if p.cond != nil {
p.cond.Signal()
}
}
// get prunes stale connections and returns a connection from the idle list or
// creates a new connection.
func (p *Pool) get() (Conn, error) {
p.mu.Lock()
// Prune stale connections.
if timeout := p.IdleTimeout; timeout > 0 {
for i, n := 0, p.idle.Len(); i < n; i++ {
e := p.idle.Back()
if e == nil {
break
}
ic := e.Value.(idleConn)
if ic.t.Add(timeout).After(nowFunc()) {
break
}
p.idle.Remove(e)
p.release()
p.mu.Unlock()
ic.c.Close()
p.mu.Lock()
}
}
for {
// Get idle connection.
for i, n := 0, p.idle.Len(); i < n; i++ {
e := p.idle.Front()
if e == nil {
break
}
ic := e.Value.(idleConn)
p.idle.Remove(e)
test := p.TestOnBorrow
p.mu.Unlock()
if test == nil || test(ic.c, ic.t) == nil {
return ic.c, nil
}
ic.c.Close()
p.mu.Lock()
p.release()
}
// Check for pool closed before dialing a new connection.
if p.closed {
p.mu.Unlock()
return nil, errors.New("redigo: get on closed pool")
}
// Dial new connection if under limit.
if p.MaxActive == 0 || p.active < p.MaxActive {
dial := p.Dial
p.active += 1
p.mu.Unlock()
c, err := dial()
if err != nil {
p.mu.Lock()
p.release()
p.mu.Unlock()
c = nil
}
return c, err
}
if !p.Wait {
p.mu.Unlock()
return nil, ErrPoolExhausted
}
if p.cond == nil {
p.cond = sync.NewCond(&p.mu)
}
p.cond.Wait()
}
}
func (p *Pool) put(c Conn, forceClose bool) error {
err := c.Err()
p.mu.Lock()
if !p.closed && err == nil && !forceClose {
p.idle.PushFront(idleConn{t: nowFunc(), c: c})
if p.idle.Len() > p.MaxIdle {
c = p.idle.Remove(p.idle.Back()).(idleConn).c
} else {
c = nil
}
}
if c == nil {
if p.cond != nil {
p.cond.Signal()
}
p.mu.Unlock()
return nil
}
p.release()
p.mu.Unlock()
return c.Close()
}
type pooledConnection struct {
p *Pool
c Conn
state int
}
var (
sentinel []byte
sentinelOnce sync.Once
)
func initSentinel() {
p := make([]byte, 64)
if _, err := rand.Read(p); err == nil {
sentinel = p
} else {
h := sha1.New()
io.WriteString(h, "Oops, rand failed. Use time instead.")
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
sentinel = h.Sum(nil)
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{errConnClosed}
if pc.state&internal.MultiState != 0 {
c.Send("DISCARD")
pc.state &^= (internal.MultiState | internal.WatchState)
} else if pc.state&internal.WatchState != 0 {
c.Send("UNWATCH")
pc.state &^= internal.WatchState
}
if pc.state&internal.SubscribeState != 0 {
c.Send("UNSUBSCRIBE")
c.Send("PUNSUBSCRIBE")
// To detect the end of the message stream, ask the server to echo
// a sentinel value and read until we see that value.
sentinelOnce.Do(initSentinel)
c.Send("ECHO", sentinel)
c.Flush()
for {
p, err := c.Receive()
if err != nil {
break
}
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
pc.state &^= internal.SubscribeState
break
}
}
}
c.Do("")
pc.p.put(c, pc.state != 0)
return nil
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
return pc.c.Do(commandName, args...)
}
func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
return pc.c.Send(commandName, args...)
}
func (pc *pooledConnection) Flush() error {
return pc.c.Flush()
}
func (pc *pooledConnection) Receive() (reply interface{}, err error) {
return pc.c.Receive()
}
type errorConnection struct{ err error }
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Flush() error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }

View File

@ -0,0 +1,674 @@
// Copyright 2011 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"errors"
"io"
"reflect"
"sync"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
type poolTestConn struct {
d *poolDialer
err error
redis.Conn
}
func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
func (c *poolTestConn) Err() error { return c.err }
func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.Conn.Do(commandName, args...)
}
func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.Conn.Send(commandName, args...)
}
type poolDialer struct {
t *testing.T
dialed int
open int
commands []string
dialErr error
}
func (d *poolDialer) dial() (redis.Conn, error) {
d.dialed += 1
if d.dialErr != nil {
return nil, d.dialErr
}
c, err := redistest.Dial()
if err != nil {
return nil, err
}
d.open += 1
return &poolTestConn{d: d, Conn: c}, nil
}
func (d *poolDialer) check(message string, p *redis.Pool, dialed, open int) {
if d.dialed != dialed {
d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
}
if d.open != open {
d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
}
if active := p.ActiveCount(); active != open {
d.t.Errorf("%s: active=%d, want %d", message, active, open)
}
}
func TestPoolReuse(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
for i := 0; i < 10; i++ {
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c1.Close()
c2.Close()
}
d.check("before close", p, 2, 2)
p.Close()
d.check("after close", p, 2, 0)
}
func TestPoolMaxIdle(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
for i := 0; i < 10; i++ {
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c3 := p.Get()
c3.Do("PING")
c1.Close()
c2.Close()
c3.Close()
}
d.check("before close", p, 12, 2)
p.Close()
d.check("after close", p, 12, 0)
}
func TestPoolError(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
c := p.Get()
c.Do("ERR", io.EOF)
if c.Err() == nil {
t.Errorf("expected c.Err() != nil")
}
c.Close()
c = p.Get()
c.Do("ERR", io.EOF)
c.Close()
d.check(".", p, 2, 0)
}
func TestPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c3 := p.Get()
c3.Do("PING")
c1.Close()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection closed")
}
c2.Close()
c2.Close()
p.Close()
d.check("after pool close", p, 3, 1)
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection and pool closed")
}
c3.Close()
d.check("after conn close", p, 3, 0)
c1 = p.Get()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after pool closed")
}
}
func TestPoolTimeout(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
IdleTimeout: 300 * time.Second,
Dial: d.dial,
}
now := time.Now()
redis.SetNowFunc(func() time.Time { return now })
defer redis.SetNowFunc(time.Now)
c := p.Get()
c.Do("PING")
c.Close()
d.check("1", p, 1, 1)
now = now.Add(p.IdleTimeout)
c = p.Get()
c.Do("PING")
c.Close()
d.check("2", p, 2, 1)
p.Close()
}
func TestPoolConcurrenSendReceive(t *testing.T) {
p := &redis.Pool{
Dial: redistest.Dial,
}
c := p.Get()
done := make(chan error, 1)
go func() {
_, err := c.Receive()
done <- err
}()
c.Send("PING")
c.Flush()
err := <-done
if err != nil {
t.Fatalf("Receive() returned error %v", err)
}
_, err = c.Do("")
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
c.Close()
p.Close()
}
func TestPoolBorrowCheck(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
for i := 0; i < 10; i++ {
c := p.Get()
c.Do("PING")
c.Close()
}
d.check("1", p, 10, 1)
p.Close()
}
func TestPoolMaxActive(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
d.check("1", p, 2, 2)
c3 := p.Get()
if _, err := c3.Do("PING"); err != redis.ErrPoolExhausted {
t.Errorf("expected pool exhausted")
}
c3.Close()
d.check("2", p, 2, 2)
c2.Close()
d.check("3", p, 2, 2)
c3 = p.Get()
if _, err := c3.Do("PING"); err != nil {
t.Errorf("expected good channel, err=%v", err)
}
c3.Close()
d.check("4", p, 2, 2)
p.Close()
}
func TestPoolMonitorCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Send("MONITOR")
c.Close()
d.check("", p, 1, 0)
p.Close()
}
func TestPoolPubSubCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Send("SUBSCRIBE", "x")
c.Close()
want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Send("PSUBSCRIBE", "x*")
c.Close()
want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
p.Close()
}
func TestPoolTransactionCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Do("WATCH", "key")
c.Do("PING")
c.Close()
want := []string{"WATCH", "PING", "UNWATCH"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("UNWATCH")
c.Do("PING")
c.Close()
want = []string{"WATCH", "UNWATCH", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("DISCARD")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("EXEC")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "EXEC", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
p.Close()
}
func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
errs := make(chan error, 10)
for i := 0; i < cap(errs); i++ {
go func() {
c := p.Get()
_, err := c.Do(cmd, args...)
errs <- err
c.Close()
}()
}
// Wait for goroutines to block.
time.Sleep(time.Second / 4)
return errs
}
func TestWaitPool(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "PING")
d.check("before close", p, 1, 1)
c.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
d.check("done", p, 1, 1)
}
func TestWaitPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
c := p.Get()
if _, err := c.Do("PING"); err != nil {
t.Fatal(err)
}
errs := startGoroutines(p, "PING")
d.check("before close", p, 1, 1)
p.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
t.Fatal("blocked goroutine did not get error")
case redis.ErrPoolExhausted:
t.Fatal("blocked goroutine got pool exhausted error")
}
case <-timeout:
t.Fatal("timeout waiting for blocked goroutine")
}
}
c.Close()
d.check("done", p, 1, 0)
}
func TestWaitPoolCommandError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
c.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
d.check("done", p, cap(errs), 0)
}
func TestWaitPoolDialError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
d.dialErr = errors.New("dial")
c.Close()
nilCount := 0
errCount := 0
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
nilCount++
case d.dialErr:
errCount++
default:
t.Fatalf("expected dial error or nil, got %v", err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
if nilCount != 1 {
t.Errorf("expected one nil error, got %d", nilCount)
}
if errCount != cap(errs)-1 {
t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
}
d.check("done", p, cap(errs), 0)
}
// Borrowing requires us to iterate over the idle connections, unlock the pool,
// and perform a blocking operation to check the connection still works. If
// TestOnBorrow fails, we must reacquire the lock and continue iteration. This
// test ensures that iteration will work correctly if multiple threads are
// iterating simultaneously.
func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
count := 100
// First we'll Create a pool where the pilfering of idle connections fails.
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: count,
MaxActive: count,
Dial: d.dial,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
return errors.New("No way back into the real world.")
},
}
defer p.Close()
// Fill the pool with idle connections.
b1 := sync.WaitGroup{}
b1.Add(count)
b2 := sync.WaitGroup{}
b2.Add(count)
for i := 0; i < count; i++ {
go func() {
c := p.Get()
if c.Err() != nil {
t.Errorf("pool get failed: %v", c.Err())
}
b1.Done()
b1.Wait()
c.Close()
b2.Done()
}()
}
b2.Wait()
if d.dialed != count {
t.Errorf("Expected %d dials, got %d", count, d.dialed)
}
// Spawn a bunch of goroutines to thrash the pool.
b2.Add(count)
for i := 0; i < count; i++ {
go func() {
c := p.Get()
if c.Err() != nil {
t.Errorf("pool get failed: %v", c.Err())
}
c.Close()
b2.Done()
}()
}
b2.Wait()
if d.dialed != count*2 {
t.Errorf("Expected %d dials, got %d", count*2, d.dialed)
}
}
func BenchmarkPoolGet(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
c.Close()
}
}
func BenchmarkPoolGetErr(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
}
}
func BenchmarkPoolGetPing(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}

View File

@ -0,0 +1,144 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import "errors"
// Subscription represents a subscribe or unsubscribe notification.
type Subscription struct {
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
Kind string
// The channel that was changed.
Channel string
// The current number of subscriptions for connection.
Count int
}
// Message represents a message notification.
type Message struct {
// The originating channel.
Channel string
// The message data.
Data []byte
}
// PMessage represents a pmessage notification.
type PMessage struct {
// The matched pattern.
Pattern string
// The originating channel.
Channel string
// The message data.
Data []byte
}
// Pong represents a pubsub pong notification.
type Pong struct {
Data string
}
// PubSubConn wraps a Conn with convenience methods for subscribers.
type PubSubConn struct {
Conn Conn
}
// Close closes the connection.
func (c PubSubConn) Close() error {
return c.Conn.Close()
}
// Subscribe subscribes the connection to the specified channels.
func (c PubSubConn) Subscribe(channel ...interface{}) error {
c.Conn.Send("SUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PSubscribe subscribes the connection to the given patterns.
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
c.Conn.Send("PSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Unsubscribe unsubscribes the connection from the given channels, or from all
// of them if none is given.
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
c.Conn.Send("UNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
// of them if none is given.
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
c.Conn.Send("PUNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Ping sends a PING to the server with the specified data.
func (c PubSubConn) Ping(data string) error {
c.Conn.Send("PING", data)
return c.Conn.Flush()
}
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
// or error. The return value is intended to be used directly in a type switch
// as illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive())
if err != nil {
return err
}
var kind string
reply, err = Scan(reply, &kind)
if err != nil {
return err
}
switch kind {
case "message":
var m Message
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
return err
}
return m
case "pmessage":
var pm PMessage
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
return err
}
return pm
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
s := Subscription{Kind: kind}
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
return err
}
return s
case "pong":
var p Pong
if _, err := Scan(reply, &p.Data); err != nil {
return err
}
return p
}
return errors.New("redigo: unknown pubsub notification")
}

View File

@ -0,0 +1,150 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := redis.PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case redis.Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case redis.PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case redis.Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
actual := c.Receive()
if !reflect.DeepEqual(actual, expected) {
t.Errorf("%s = %v, want %v", message, actual, expected)
}
}
func TestPushed(t *testing.T) {
pc, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer pc.Close()
nc, err := net.Dial("tcp", ":6379")
if err != nil {
t.Fatal(err)
}
defer nc.Close()
nc.SetReadDeadline(time.Now().Add(4 * time.Second))
c := redis.PubSubConn{Conn: redis.NewConn(nc, 0, 0)}
c.Subscribe("c1")
expectPushed(t, c, "Subscribe(c1)", redis.Subscription{Kind: "subscribe", Channel: "c1", Count: 1})
c.Subscribe("c2")
expectPushed(t, c, "Subscribe(c2)", redis.Subscription{Kind: "subscribe", Channel: "c2", Count: 2})
c.PSubscribe("p1")
expectPushed(t, c, "PSubscribe(p1)", redis.Subscription{Kind: "psubscribe", Channel: "p1", Count: 3})
c.PSubscribe("p2")
expectPushed(t, c, "PSubscribe(p2)", redis.Subscription{Kind: "psubscribe", Channel: "p2", Count: 4})
c.PUnsubscribe()
expectPushed(t, c, "Punsubscribe(p1)", redis.Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3})
expectPushed(t, c, "Punsubscribe()", redis.Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2})
pc.Do("PUBLISH", "c1", "hello")
expectPushed(t, c, "PUBLISH c1 hello", redis.Message{Channel: "c1", Data: []byte("hello")})
c.Ping("hello")
expectPushed(t, c, `Ping("hello")`, redis.Pong{"hello"})
c.Conn.Send("PING")
c.Conn.Flush()
expectPushed(t, c, `Send("PING")`, redis.Pong{})
}

View File

@ -0,0 +1,44 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
}

View File

@ -0,0 +1,364 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"strconv"
)
// ErrNil indicates that a reply value is nil.
var ErrNil = errors.New("redigo: nil returned")
// Int is a helper that converts a command reply to an integer. If err is not
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
// reply to an int as follows:
//
// Reply type Result
// integer int(reply), nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int(reply interface{}, err error) (int, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
x := int(reply)
if int64(x) != reply {
return 0, strconv.ErrRange
}
return x, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 0)
return int(n), err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Int, got type %T", reply)
}
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int64(reply interface{}, err error) (int64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
return reply, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply)
}
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Uint64(reply interface{}, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
if reply < 0 {
return 0, errNegativeInt
}
return uint64(reply), nil
case []byte:
n, err := strconv.ParseUint(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
}
// Float64 is a helper that converts a command reply to 64 bit float. If err is
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
// the reply to an int as follows:
//
// Reply type Result
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Float64(reply interface{}, err error) (float64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case []byte:
n, err := strconv.ParseFloat(string(reply), 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, fmt.Errorf("redigo: unexpected type for Float64, got type %T", reply)
}
// String is a helper that converts a command reply to a string. If err is not
// equal to nil, then String returns "", err. Otherwise String converts the
// reply to a string as follows:
//
// Reply type Result
// bulk string string(reply), nil
// simple string reply, nil
// nil "", ErrNil
// other "", error
func String(reply interface{}, err error) (string, error) {
if err != nil {
return "", err
}
switch reply := reply.(type) {
case []byte:
return string(reply), nil
case string:
return reply, nil
case nil:
return "", ErrNil
case Error:
return "", reply
}
return "", fmt.Errorf("redigo: unexpected type for String, got type %T", reply)
}
// Bytes is a helper that converts a command reply to a slice of bytes. If err
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
// the reply to a slice of bytes as follows:
//
// Reply type Result
// bulk string reply, nil
// simple string []byte(reply), nil
// nil nil, ErrNil
// other nil, error
func Bytes(reply interface{}, err error) ([]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []byte:
return reply, nil
case string:
return []byte(reply), nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
}
// Bool is a helper that converts a command reply to a boolean. If err is not
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
// reply to boolean as follows:
//
// Reply type Result
// integer value != 0, nil
// bulk string strconv.ParseBool(reply)
// nil false, ErrNil
// other false, error
func Bool(reply interface{}, err error) (bool, error) {
if err != nil {
return false, err
}
switch reply := reply.(type) {
case int64:
return reply != 0, nil
case []byte:
return strconv.ParseBool(string(reply))
case nil:
return false, ErrNil
case Error:
return false, reply
}
return false, fmt.Errorf("redigo: unexpected type for Bool, got type %T", reply)
}
// MultiBulk is deprecated. Use Values.
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
// Values is a helper that converts an array command reply to a []interface{}.
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
// converts the reply as follows:
//
// Reply type Result
// array reply, nil
// nil nil, ErrNil
// other nil, error
func Values(reply interface{}, err error) ([]interface{}, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
return reply, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply)
}
// Strings is a helper that converts an array command reply to a []string. If
// err is not equal to nil, then Strings returns nil, err. Nil array items are
// converted to "" in the output slice. Strings returns an error if an array
// item is not a bulk string or nil.
func Strings(reply interface{}, err error) ([]string, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([]string, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
}
result[i] = string(p)
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply)
}
// Ints is a helper that converts an array command reply to a []int. If
// err is not equal to nil, then Ints returns nil, err.
func Ints(reply interface{}, err error) ([]int, error) {
var ints []int
if reply == nil {
return ints, ErrNil
}
values, err := Values(reply, err)
if err != nil {
return ints, err
}
if err := ScanSlice(values, &ints); err != nil {
return ints, err
}
return ints, nil
}
// StringMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
// Requires an even number of values in result.
func StringMap(result interface{}, err error) (map[string]string, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: StringMap expects even number of values result")
}
m := make(map[string]string, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, okKey := values[i].([]byte)
value, okValue := values[i+1].([]byte)
if !okKey || !okValue {
return nil, errors.New("redigo: ScanMap key not a bulk string value")
}
m[string(key)] = string(value)
}
return m, nil
}
// IntMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]int. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func IntMap(result interface{}, err error) (map[string]int, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: IntMap expects even number of values result")
}
m := make(map[string]int, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, errors.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}
// Int64Map is a helper that converts an array of strings (alternating key, value)
// into a map[string]int64. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func Int64Map(result interface{}, err error) (map[string]int64, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: Int64Map expects even number of values result")
}
m := make(map[string]int64, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, errors.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int64(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}

View File

@ -0,0 +1,166 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"reflect"
"testing"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
type valueError struct {
v interface{}
err error
}
func ve(v interface{}, err error) valueError {
return valueError{v, err}
}
var replyTests = []struct {
name interface{}
actual valueError
expected valueError
}{
{
"ints([v1, v2])",
ve(redis.Ints([]interface{}{[]byte("4"), []byte("5")}, nil)),
ve([]int{4, 5}, nil),
},
{
"ints(nil)",
ve(redis.Ints(nil, nil)),
ve([]int(nil), redis.ErrNil),
},
{
"strings([v1, v2])",
ve(redis.Strings([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]string{"v1", "v2"}, nil),
},
{
"strings(nil)",
ve(redis.Strings(nil, nil)),
ve([]string(nil), redis.ErrNil),
},
{
"values([v1, v2])",
ve(redis.Values([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]interface{}{[]byte("v1"), []byte("v2")}, nil),
},
{
"values(nil)",
ve(redis.Values(nil, nil)),
ve([]interface{}(nil), redis.ErrNil),
},
{
"float64(1.0)",
ve(redis.Float64([]byte("1.0"), nil)),
ve(float64(1.0), nil),
},
{
"float64(nil)",
ve(redis.Float64(nil, nil)),
ve(float64(0.0), redis.ErrNil),
},
{
"uint64(1)",
ve(redis.Uint64(int64(1), nil)),
ve(uint64(1), nil),
},
{
"uint64(-1)",
ve(redis.Uint64(int64(-1), nil)),
ve(uint64(0), redis.ErrNegativeInt),
},
}
func TestReply(t *testing.T) {
for _, rt := range replyTests {
if rt.actual.err != rt.expected.err {
t.Errorf("%s returned err %v, want %v", rt.name, rt.actual.err, rt.expected.err)
continue
}
if !reflect.DeepEqual(rt.actual.v, rt.expected.v) {
t.Errorf("%s=%+v, want %+v", rt.name, rt.actual.v, rt.expected.v)
}
}
}
// dial wraps DialTestDB() with a more suitable function name for examples.
func dial() (redis.Conn, error) {
return redistest.Dial()
}
func ExampleBool() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("SET", "foo", 1)
exists, _ := redis.Bool(c.Do("EXISTS", "foo"))
fmt.Printf("%#v\n", exists)
// Output:
// true
}
func ExampleInt() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("SET", "k1", 1)
n, _ := redis.Int(c.Do("GET", "k1"))
fmt.Printf("%#v\n", n)
n, _ = redis.Int(c.Do("INCR", "k1"))
fmt.Printf("%#v\n", n)
// Output:
// 1
// 2
}
func ExampleInts() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("SADD", "set_with_integers", 4, 5, 6)
ints, _ := redis.Ints(c.Do("SMEMBERS", "set_with_integers"))
fmt.Printf("%#v\n", ints)
// Output:
// []int{4, 5, 6}
}
func ExampleString() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("SET", "hello", "world")
s, err := redis.String(c.Do("GET", "hello"))
fmt.Printf("%#v\n", s)
// Output:
// "world"
}

View File

@ -0,0 +1,513 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
)
func ensureLen(d reflect.Value, n int) {
if n > d.Cap() {
d.Set(reflect.MakeSlice(d.Type(), n, n))
} else {
d.SetLen(n)
}
}
func cannotConvert(d reflect.Value, s interface{}) error {
return fmt.Errorf("redigo: Scan cannot convert from %s to %s",
reflect.TypeOf(s), d.Type())
}
func convertAssignBytes(d reflect.Value, s []byte) (err error) {
switch d.Type().Kind() {
case reflect.Float32, reflect.Float64:
var x float64
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
d.SetFloat(x)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var x int64
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
d.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var x uint64
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
d.SetUint(x)
case reflect.Bool:
var x bool
x, err = strconv.ParseBool(string(s))
d.SetBool(x)
case reflect.String:
d.SetString(string(s))
case reflect.Slice:
if d.Type().Elem().Kind() != reflect.Uint8 {
err = cannotConvert(d, s)
} else {
d.SetBytes(s)
}
default:
err = cannotConvert(d, s)
}
return
}
func convertAssignInt(d reflect.Value, s int64) (err error) {
switch d.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
d.SetInt(s)
if d.Int() != s {
err = strconv.ErrRange
d.SetInt(0)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if s < 0 {
err = strconv.ErrRange
} else {
x := uint64(s)
d.SetUint(x)
if d.Uint() != x {
err = strconv.ErrRange
d.SetUint(0)
}
}
case reflect.Bool:
d.SetBool(s != 0)
default:
err = cannotConvert(d, s)
}
return
}
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
switch s := s.(type) {
case []byte:
err = convertAssignBytes(d, s)
case int64:
err = convertAssignInt(d, s)
default:
err = cannotConvert(d, s)
}
return err
}
func convertAssignValues(d reflect.Value, s []interface{}) error {
if d.Type().Kind() != reflect.Slice {
return cannotConvert(d, s)
}
ensureLen(d, len(s))
for i := 0; i < len(s); i++ {
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
return err
}
}
return nil
}
func convertAssign(d interface{}, s interface{}) (err error) {
// Handle the most common destination types using type switches and
// fall back to reflection for all other types.
switch s := s.(type) {
case nil:
// ingore
case []byte:
switch d := d.(type) {
case *string:
*d = string(s)
case *int:
*d, err = strconv.Atoi(string(s))
case *bool:
*d, err = strconv.ParseBool(string(s))
case *[]byte:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignBytes(d.Elem(), s)
}
}
case int64:
switch d := d.(type) {
case *int:
x := int(s)
if int64(x) != s {
err = strconv.ErrRange
x = 0
}
*d = x
case *bool:
*d = s != 0
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignInt(d.Elem(), s)
}
}
case []interface{}:
switch d := d.(type) {
case *[]interface{}:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignValues(d.Elem(), s)
}
}
case Error:
err = s
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
return
}
// Scan copies from src to the values pointed at by dest.
//
// The values pointed at by dest must be an integer, float, boolean, string,
// []byte, interface{} or slices of these types. Scan uses the standard strconv
// package to convert bulk strings to numeric and boolean types.
//
// If a dest value is nil, then the corresponding src value is skipped.
//
// If a src element is nil, then the corresponding dest value is not modified.
//
// To enable easy use of Scan in a loop, Scan returns the slice of src
// following the copied values.
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
if len(src) < len(dest) {
return nil, errors.New("redigo: Scan array short")
}
var err error
for i, d := range dest {
err = convertAssign(d, src[i])
if err != nil {
break
}
}
return src[len(dest):], err
}
type fieldSpec struct {
name string
index []int
//omitEmpty bool
}
type structSpec struct {
m map[string]*fieldSpec
l []*fieldSpec
}
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
switch {
case f.PkgPath != "":
// Ignore unexported fields.
case f.Anonymous:
// TODO: Handle pointers. Requires change to decoder and
// protection against infinite recursion.
if f.Type.Kind() == reflect.Struct {
compileStructSpec(f.Type, depth, append(index, i), ss)
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")
p := strings.Split(tag, ",")
if len(p) > 0 {
if p[0] == "-" {
continue
}
if len(p[0]) > 0 {
fs.name = p[0]
}
for _, s := range p[1:] {
switch s {
//case "omitempty":
// fs.omitempty = true
default:
panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name()))
}
}
}
d, found := depth[fs.name]
if !found {
d = 1 << 30
}
switch {
case len(index) == d:
// At same depth, remove from result.
delete(ss.m, fs.name)
j := 0
for i := 0; i < len(ss.l); i++ {
if fs.name != ss.l[i].name {
ss.l[j] = ss.l[i]
j += 1
}
}
ss.l = ss.l[:j]
case len(index) < d:
fs.index = make([]int, len(index)+1)
copy(fs.index, index)
fs.index[len(index)] = i
depth[fs.name] = len(index)
ss.m[fs.name] = fs
ss.l = append(ss.l, fs)
}
}
}
}
var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
defaultFieldSpec = &fieldSpec{}
)
func structSpecForType(t reflect.Type) *structSpec {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
}
structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
}
ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
structSpecCache[t] = ss
return ss
}
var errScanStructValue = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct")
// ScanStruct scans alternating names and values from src to a struct. The
// HGETALL and CONFIG GET commands return replies in this format.
//
// ScanStruct uses exported field names to match values in the response. Use
// 'redis' field tag to override the name:
//
// Field int `redis:"myName"`
//
// Fields with the tag redis:"-" are ignored.
//
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
// standard strconv package to convert bulk string values to numeric and
// boolean types.
//
// If a src element is nil, then the corresponding field is not modified.
func ScanStruct(src []interface{}, dest interface{}) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanStructValue
}
d = d.Elem()
if d.Kind() != reflect.Struct {
return errScanStructValue
}
ss := structSpecForType(d.Type())
if len(src)%2 != 0 {
return errors.New("redigo: ScanStruct expects even number of values in values")
}
for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}
name, ok := src[i].([]byte)
if !ok {
return errors.New("redigo: ScanStruct key not a bulk string value")
}
fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return err
}
}
return nil
}
var (
errScanSliceValue = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct")
)
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
// slice must be integer, float, boolean, string, struct or pointer to struct
// values.
//
// Struct fields must be integer, float, boolean or string values. All struct
// fields are used unless a subset is specified using fieldNames.
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanSliceValue
}
d = d.Elem()
if d.Kind() != reflect.Slice {
return errScanSliceValue
}
isPtr := false
t := d.Type().Elem()
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
ensureLen(d, len(src))
for i, s := range src {
if s == nil {
continue
}
if err := convertAssignValue(d.Index(i), s); err != nil {
return err
}
}
return nil
}
ss := structSpecForType(t)
fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
for i, name := range fieldNames {
fss[i] = ss.m[name]
if fss[i] == nil {
return errors.New("redigo: ScanSlice bad field name " + name)
}
}
}
if len(fss) == 0 {
return errors.New("redigo: ScanSlice no struct fields")
}
n := len(src) / len(fss)
if n*len(fss) != len(src) {
return errors.New("redigo: ScanSlice length not a multiple of struct field count")
}
ensureLen(d, n)
for i := 0; i < n; i++ {
d := d.Index(i)
if isPtr {
if d.IsNil() {
d.Set(reflect.New(t))
}
d = d.Elem()
}
for j, fs := range fss {
s := src[i*len(fss)+j]
if s == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return err
}
}
}
return nil
}
// Args is a helper for constructing command arguments from structured values.
type Args []interface{}
// Add returns the result of appending value to args.
func (args Args) Add(value ...interface{}) Args {
return append(args, value...)
}
// AddFlat returns the result of appending the flattened value of v to args.
//
// Maps are flattened by appending the alternating keys and map values to args.
//
// Slices are flattened by appending the slice elements to args.
//
// Structs are flattened by appending the alternating names and values of
// exported fields to args. If v is a nil struct pointer, then nothing is
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Struct:
args = flattenStruct(args, rv)
case reflect.Slice:
for i := 0; i < rv.Len(); i++ {
args = append(args, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
}
case reflect.Ptr:
if rv.Type().Elem().Kind() == reflect.Struct {
if !rv.IsNil() {
args = flattenStruct(args, rv.Elem())
}
} else {
args = append(args, v)
}
default:
args = append(args, v)
}
return args
}
func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
args = append(args, fs.name, fv.Interface())
}
return args
}

View File

@ -0,0 +1,412 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"github.com/garyburd/redigo/redis"
"math"
"reflect"
"testing"
)
var scanConversionTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("-inf"), math.Inf(-1)},
{[]byte("+inf"), math.Inf(1)},
{[]byte("0"), float64(0)},
{[]byte("3.14159"), float64(3.14159)},
{[]byte("3.14"), float32(3.14)},
{[]byte("-100"), int(-100)},
{[]byte("101"), int(101)},
{int64(102), int(102)},
{[]byte("103"), uint(103)},
{int64(104), uint(104)},
{[]byte("105"), int8(105)},
{int64(106), int8(106)},
{[]byte("107"), uint8(107)},
{int64(108), uint8(108)},
{[]byte("0"), false},
{int64(0), false},
{[]byte("f"), false},
{[]byte("1"), true},
{int64(1), true},
{[]byte("t"), true},
{[]byte("hello"), "hello"},
{[]byte("world"), []byte("world")},
{[]interface{}{[]byte("foo")}, []interface{}{[]byte("foo")}},
{[]interface{}{[]byte("foo")}, []string{"foo"}},
{[]interface{}{[]byte("hello"), []byte("world")}, []string{"hello", "world"}},
{[]interface{}{[]byte("bar")}, [][]byte{[]byte("bar")}},
{[]interface{}{[]byte("1")}, []int{1}},
{[]interface{}{[]byte("1"), []byte("2")}, []int{1, 2}},
{[]interface{}{[]byte("1"), []byte("2")}, []float64{1, 2}},
{[]interface{}{[]byte("1")}, []byte{1}},
{[]interface{}{[]byte("1")}, []bool{true}},
}
func TestScanConversion(t *testing.T) {
for _, tt := range scanConversionTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := redis.Scan(values, dest.Interface())
if err != nil {
t.Errorf("Scan(%v) returned error %v", tt, err)
continue
}
if !reflect.DeepEqual(tt.dest, dest.Elem().Interface()) {
t.Errorf("Scan(%v) returned %v, want %v", tt, dest.Elem().Interface(), tt.dest)
}
}
}
var scanConversionErrorTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("1234"), byte(0)},
{int64(1234), byte(0)},
{[]byte("-1"), byte(0)},
{int64(-1), byte(0)},
{[]byte("junk"), false},
{redis.Error("blah"), false},
}
func TestScanConversionError(t *testing.T) {
for _, tt := range scanConversionErrorTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := redis.Scan(values, dest.Interface())
if err == nil {
t.Errorf("Scan(%v) did not return error", tt)
}
}
}
func ExampleScan() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat")
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := redis.Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
panic(err)
}
for len(values) > 0 {
var title string
rating := -1 // initialize to illegal value to detect nil.
values, err = redis.Scan(values, &title, &rating)
if err != nil {
panic(err)
}
if rating == -1 {
fmt.Println(title, "not-rated")
} else {
fmt.Println(title, rating)
}
}
// Output:
// Beat not-rated
// Earthbound 1
// Red 5
}
type s0 struct {
X int
Y int `redis:"y"`
Bt bool
}
type s1 struct {
X int `redis:"-"`
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
B bool `redis:"b"`
Bt bool
Bf bool
s0
}
var scanStructTests = []struct {
title string
reply []string
value interface{}
}{
{"basic",
[]string{"i", "-1234", "u", "5678", "s", "hello", "p", "world", "b", "t", "Bt", "1", "Bf", "0", "X", "123", "y", "456"},
&s1{I: -1234, U: 5678, S: "hello", P: []byte("world"), B: true, Bt: true, Bf: false, s0: s0{X: 123, Y: 456}},
},
}
func TestScanStruct(t *testing.T) {
for _, tt := range scanStructTests {
var reply []interface{}
for _, v := range tt.reply {
reply = append(reply, []byte(v))
}
value := reflect.New(reflect.ValueOf(tt.value).Type().Elem())
if err := redis.ScanStruct(reply, value.Interface()); err != nil {
t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err)
}
if !reflect.DeepEqual(value.Interface(), tt.value) {
t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value)
}
}
}
func TestBadScanStructArgs(t *testing.T) {
x := []interface{}{"A", "b"}
test := func(v interface{}) {
if err := redis.ScanStruct(x, v); err == nil {
t.Errorf("Expect error for ScanStruct(%T, %T)", x, v)
}
}
test(nil)
var v0 *struct{}
test(v0)
var v1 int
test(&v1)
x = x[:1]
v2 := struct{ A string }{}
test(&v2)
}
var scanSliceTests = []struct {
src []interface{}
fieldNames []string
ok bool
dest interface{}
}{
{
[]interface{}{[]byte("1"), nil, []byte("-1")},
nil,
true,
[]int{1, 0, -1},
},
{
[]interface{}{[]byte("1"), nil, []byte("2")},
nil,
true,
[]uint{1, 0, 2},
},
{
[]interface{}{[]byte("-1")},
nil,
false,
[]uint{1},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[][]byte{[]byte("hello"), nil, []byte("world")},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[]string{"hello", "", "world"},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1")},
nil,
false,
[]struct{ A, B, C string }{{"a1", "b1", ""}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
[]string{"A", "B"},
true,
[]struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
false,
[]struct{}{},
},
}
func TestScanSlice(t *testing.T) {
for _, tt := range scanSliceTests {
typ := reflect.ValueOf(tt.dest).Type()
dest := reflect.New(typ)
err := redis.ScanSlice(tt.src, dest.Interface(), tt.fieldNames...)
if tt.ok != (err == nil) {
t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err)
continue
}
if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) {
t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest)
}
}
}
func ExampleScanSlice() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat", "rating", 4)
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := redis.Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
panic(err)
}
var albums []struct {
Title string
Rating int
}
if err := redis.ScanSlice(values, &albums); err != nil {
panic(err)
}
fmt.Printf("%v\n", albums)
// Output:
// [{Earthbound 1} {Beat 4} {Red 5}]
}
var argsTests = []struct {
title string
actual redis.Args
expected redis.Args
}{
{"struct ptr",
redis.Args{}.AddFlat(&struct {
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
Bt bool
Bf bool
}{
-1234, 5678, "hello", []byte("world"), true, false,
}),
redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "Bt", true, "Bf", false},
},
{"struct",
redis.Args{}.AddFlat(struct{ I int }{123}),
redis.Args{"I", 123},
},
{"slice",
redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2),
redis.Args{1, "a", "b", "c", 2},
},
}
func TestArgs(t *testing.T) {
for _, tt := range argsTests {
if !reflect.DeepEqual(tt.actual, tt.expected) {
t.Fatalf("%s is %v, want %v", tt.title, tt.actual, tt.expected)
}
}
}
func ExampleArgs() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
var p1, p2 struct {
Title string `redis:"title"`
Author string `redis:"author"`
Body string `redis:"body"`
}
p1.Title = "Example"
p1.Author = "Gary"
p1.Body = "Hello"
if _, err := c.Do("HMSET", redis.Args{}.Add("id1").AddFlat(&p1)...); err != nil {
panic(err)
}
m := map[string]string{
"title": "Example2",
"author": "Steve",
"body": "Map",
}
if _, err := c.Do("HMSET", redis.Args{}.Add("id2").AddFlat(m)...); err != nil {
panic(err)
}
for _, id := range []string{"id1", "id2"} {
v, err := redis.Values(c.Do("HGETALL", id))
if err != nil {
panic(err)
}
if err := redis.ScanStruct(v, &p2); err != nil {
panic(err)
}
fmt.Printf("%+v\n", p2)
}
// Output:
// {Title:Example Author:Gary Body:Hello}
// {Title:Example2 Author:Steve Body:Map}
}

View File

@ -0,0 +1,86 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"crypto/sha1"
"encoding/hex"
"io"
"strings"
)
// Script encapsulates the source, hash and key count for a Lua script. See
// http://redis.io/commands/eval for information on scripts in Redis.
type Script struct {
keyCount int
src string
hash string
}
// NewScript returns a new script object. If keyCount is greater than or equal
// to zero, then the count is automatically inserted in the EVAL command
// argument list. If keyCount is less than zero, then the application supplies
// the count as the first value in the keysAndArgs argument to the Do, Send and
// SendHash methods.
func NewScript(keyCount int, src string) *Script {
h := sha1.New()
io.WriteString(h, src)
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
}
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
var args []interface{}
if s.keyCount < 0 {
args = make([]interface{}, 1+len(keysAndArgs))
args[0] = spec
copy(args[1:], keysAndArgs)
} else {
args = make([]interface{}, 2+len(keysAndArgs))
args[0] = spec
args[1] = s.keyCount
copy(args[2:], keysAndArgs)
}
return args
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
// causing the script to load).
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}
// SendHash evaluates the script without waiting for the reply. The script is
// evaluated with the EVALSHA command. The application must ensure that the
// script is loaded by a previous call to Send, Do or Load methods.
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
}
// Send evaluates the script without waiting for the reply.
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
}
// Load loads the script without evaluating it.
func (s *Script) Load(c Conn) error {
_, err := c.Do("SCRIPT", "LOAD", s.src)
return err
}

View File

@ -0,0 +1,93 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
func ExampleScript(c redis.Conn, reply interface{}, err error) {
// Initialize a package-level variable with a script.
var getScript = redis.NewScript(1, `return redis.call('get', KEYS[1])`)
// In a function, use the script Do method to evaluate the script. The Do
// method optimistically uses the EVALSHA command. If the script is not
// loaded, then the Do method falls back to the EVAL command.
reply, err = getScript.Do(c, "foo")
}
func TestScript(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// To test fall back in Do, we make script unique by adding comment with current time.
script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano())
s := redis.NewScript(2, script)
reply := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")}
v, err := s.Do(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Do(c, ...) returned %v", err)
}
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Do(c, ..); = %v, want %v", v, reply)
}
err = s.Load(c)
if err != nil {
t.Errorf("s.Load(c) returned %v", err)
}
err = s.SendHash(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.SendHash(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.SendHash(c, ..); c.Receive() = %v, want %v", v, reply)
}
err = s.Send(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Send(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Send(c, ..); c.Receive() = %v, want %v", v, reply)
}
}

View File

@ -0,0 +1,38 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"net"
"time"
)
func SetNowFunc(f func() time.Time) {
nowFunc = f
}
type nopCloser struct{ net.Conn }
func (nopCloser) Close() error { return nil }
// NewConnBufio is a hook for tests.
func NewConnBufio(rw bufio.ReadWriter) Conn {
return &conn{br: rw.Reader, bw: rw.Writer, conn: nopCloser{}}
}
var (
ErrNegativeInt = errNegativeInt
)

View File

@ -0,0 +1,113 @@
// Copyright 2013 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"github.com/garyburd/redigo/redis"
)
// zpop pops a value from the ZSET key using WATCH/MULTI/EXEC commands.
func zpop(c redis.Conn, key string) (result string, err error) {
defer func() {
// Return connection to normal state on error.
if err != nil {
c.Do("DISCARD")
}
}()
// Loop until transaction is successful.
for {
if _, err := c.Do("WATCH", key); err != nil {
return "", err
}
members, err := redis.Strings(c.Do("ZRANGE", key, 0, 0))
if err != nil {
return "", err
}
if len(members) != 1 {
return "", redis.ErrNil
}
c.Send("MULTI")
c.Send("ZREM", key, members[0])
queued, err := c.Do("EXEC")
if err != nil {
return "", err
}
if queued != nil {
result = members[0]
break
}
}
return result, nil
}
// zpopScript pops a value from a ZSET.
var zpopScript = redis.NewScript(1, `
local r = redis.call('ZRANGE', KEYS[1], 0, 0)
if r ~= nil then
r = r[1]
redis.call('ZREM', KEYS[1], r)
end
return r
`)
// This example implements ZPOP as described at
// http://redis.io/topics/transactions using WATCH/MULTI/EXEC and scripting.
func Example_zpop() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
// Add test data using a pipeline.
for i, member := range []string{"red", "blue", "green"} {
c.Send("ZADD", "zset", i, member)
}
if _, err := c.Do(""); err != nil {
fmt.Println(err)
return
}
// Pop using WATCH/MULTI/EXEC
v, err := zpop(c, "zset")
if err != nil {
fmt.Println(err)
return
}
fmt.Println(v)
// Pop using a script.
v, err = redis.String(zpopScript.Do(c, "zset"))
if err != nil {
fmt.Println(err)
return
}
fmt.Println(v)
// Output:
// red
// blue
}

View File

@ -0,0 +1,152 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redisx
import (
"errors"
"sync"
"github.com/garyburd/redigo/internal"
"github.com/garyburd/redigo/redis"
)
// ConnMux multiplexes one or more connections to a single underlying
// connection. The ConnMux connections do not support concurrency, commands
// that associate server side state with the connection or commands that put
// the connection in a special mode.
type ConnMux struct {
c redis.Conn
sendMu sync.Mutex
sendID uint
recvMu sync.Mutex
recvID uint
recvWait map[uint]chan struct{}
}
func NewConnMux(c redis.Conn) *ConnMux {
return &ConnMux{c: c, recvWait: make(map[uint]chan struct{})}
}
// Get gets a connection. The application must close the returned connection.
func (p *ConnMux) Get() redis.Conn {
c := &muxConn{p: p}
c.ids = c.buf[:0]
return c
}
// Close closes the underlying connection.
func (p *ConnMux) Close() error {
return p.c.Close()
}
type muxConn struct {
p *ConnMux
ids []uint
buf [8]uint
}
func (c *muxConn) send(flush bool, cmd string, args ...interface{}) error {
if internal.LookupCommandInfo(cmd).Set != 0 {
return errors.New("command not supported by mux pool")
}
p := c.p
p.sendMu.Lock()
id := p.sendID
c.ids = append(c.ids, id)
p.sendID++
err := p.c.Send(cmd, args...)
if flush {
err = p.c.Flush()
}
p.sendMu.Unlock()
return err
}
func (c *muxConn) Send(cmd string, args ...interface{}) error {
return c.send(false, cmd, args...)
}
func (c *muxConn) Flush() error {
p := c.p
p.sendMu.Lock()
err := p.c.Flush()
p.sendMu.Unlock()
return err
}
func (c *muxConn) Receive() (interface{}, error) {
if len(c.ids) == 0 {
return nil, errors.New("mux pool underflow")
}
id := c.ids[0]
c.ids = c.ids[1:]
if len(c.ids) == 0 {
c.ids = c.buf[:0]
}
p := c.p
p.recvMu.Lock()
if p.recvID != id {
ch := make(chan struct{})
p.recvWait[id] = ch
p.recvMu.Unlock()
<-ch
p.recvMu.Lock()
if p.recvID != id {
panic("out of sync")
}
}
v, err := p.c.Receive()
id++
p.recvID = id
ch, ok := p.recvWait[id]
if ok {
delete(p.recvWait, id)
}
p.recvMu.Unlock()
if ok {
ch <- struct{}{}
}
return v, err
}
func (c *muxConn) Close() error {
var err error
if len(c.ids) == 0 {
return nil
}
c.Flush()
for _ = range c.ids {
_, err = c.Receive()
}
return err
}
func (c *muxConn) Do(cmd string, args ...interface{}) (interface{}, error) {
if err := c.send(true, cmd, args...); err != nil {
return nil, err
}
return c.Receive()
}
func (c *muxConn) Err() error {
return c.p.c.Err()
}

View File

@ -0,0 +1,259 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redisx_test
import (
"net/textproto"
"sync"
"testing"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
"github.com/garyburd/redigo/redisx"
)
func TestConnMux(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
m := redisx.NewConnMux(c)
defer m.Close()
c1 := m.Get()
c2 := m.Get()
c1.Send("ECHO", "hello")
c2.Send("ECHO", "world")
c1.Flush()
c2.Flush()
s, err := redis.String(c1.Receive())
if err != nil {
t.Fatal(err)
}
if s != "hello" {
t.Fatalf("echo returned %q, want %q", s, "hello")
}
s, err = redis.String(c2.Receive())
if err != nil {
t.Fatal(err)
}
if s != "world" {
t.Fatalf("echo returned %q, want %q", s, "world")
}
c1.Close()
c2.Close()
}
func TestConnMuxClose(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
m := redisx.NewConnMux(c)
defer m.Close()
c1 := m.Get()
c2 := m.Get()
if err := c1.Send("ECHO", "hello"); err != nil {
t.Fatal(err)
}
if err := c1.Close(); err != nil {
t.Fatal(err)
}
if err := c2.Send("ECHO", "world"); err != nil {
t.Fatal(err)
}
if err := c2.Flush(); err != nil {
t.Fatal(err)
}
s, err := redis.String(c2.Receive())
if err != nil {
t.Fatal(err)
}
if s != "world" {
t.Fatalf("echo returned %q, want %q", s, "world")
}
c2.Close()
}
func BenchmarkConn(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatalf("error connection to database, %v", err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkConnMux(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatalf("error connection to database, %v", err)
}
m := redisx.NewConnMux(c)
defer m.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c := m.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}
func BenchmarkPool(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 1}
defer p.Close()
// Fill the pool.
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c := p.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}
const numConcurrent = 10
func BenchmarkConnMuxConcurrent(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatalf("error connection to database, %v", err)
}
defer c.Close()
m := redisx.NewConnMux(c)
var wg sync.WaitGroup
wg.Add(numConcurrent)
b.StartTimer()
for i := 0; i < numConcurrent; i++ {
go func() {
defer wg.Done()
for i := 0; i < b.N; i++ {
c := m.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}()
}
wg.Wait()
}
func BenchmarkPoolConcurrent(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: numConcurrent}
defer p.Close()
// Fill the pool.
conns := make([]redis.Conn, numConcurrent)
for i := range conns {
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
conns[i] = c
}
for _, c := range conns {
c.Close()
}
var wg sync.WaitGroup
wg.Add(numConcurrent)
b.StartTimer()
for i := 0; i < numConcurrent; i++ {
go func() {
defer wg.Done()
for i := 0; i < b.N; i++ {
c := p.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}()
}
wg.Wait()
}
func BenchmarkPipelineConcurrency(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatalf("error connection to database, %v", err)
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(numConcurrent)
var pipeline textproto.Pipeline
b.StartTimer()
for i := 0; i < numConcurrent; i++ {
go func() {
defer wg.Done()
for i := 0; i < b.N; i++ {
id := pipeline.Next()
pipeline.StartRequest(id)
c.Send("PING")
c.Flush()
pipeline.EndRequest(id)
pipeline.StartResponse(id)
_, err := c.Receive()
if err != nil {
b.Fatal(err)
}
pipeline.EndResponse(id)
}
}()
}
wg.Wait()
}

View File

@ -0,0 +1,17 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redisx contains experimental features for Redigo. Features in this
// package may be modified or deleted at any time.
package redisx // import "github.com/garyburd/redigo/redisx"

View File

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,7 @@
context
=======
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
gorilla/context is a general purpose registry for global request variables.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context

View File

@ -0,0 +1,143 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"sync"
"time"
)
var (
mutex sync.RWMutex
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
)
// Set stores a value for a given key in a given request.
func Set(r *http.Request, key, val interface{}) {
mutex.Lock()
if data[r] == nil {
data[r] = make(map[interface{}]interface{})
datat[r] = time.Now().Unix()
}
data[r][key] = val
mutex.Unlock()
}
// Get returns a value stored for a given key in a given request.
func Get(r *http.Request, key interface{}) interface{} {
mutex.RLock()
if ctx := data[r]; ctx != nil {
value := ctx[key]
mutex.RUnlock()
return value
}
mutex.RUnlock()
return nil
}
// GetOk returns stored value and presence state like multi-value return of map access.
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
mutex.RLock()
if _, ok := data[r]; ok {
value, ok := data[r][key]
mutex.RUnlock()
return value, ok
}
mutex.RUnlock()
return nil, false
}
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
func GetAll(r *http.Request) map[interface{}]interface{} {
mutex.RLock()
if context, ok := data[r]; ok {
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result
}
mutex.RUnlock()
return nil
}
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
// the request was registered.
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
mutex.RLock()
context, ok := data[r]
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result, ok
}
// Delete removes a value stored for a given key in a given request.
func Delete(r *http.Request, key interface{}) {
mutex.Lock()
if data[r] != nil {
delete(data[r], key)
}
mutex.Unlock()
}
// Clear removes all values stored for a given request.
//
// This is usually called by a handler wrapper to clean up request
// variables at the end of a request lifetime. See ClearHandler().
func Clear(r *http.Request) {
mutex.Lock()
clear(r)
mutex.Unlock()
}
// clear is Clear without the lock.
func clear(r *http.Request) {
delete(data, r)
delete(datat, r)
}
// Purge removes request data stored for longer than maxAge, in seconds.
// It returns the amount of requests removed.
//
// If maxAge <= 0, all request data is removed.
//
// This is only used for sanity check: in case context cleaning was not
// properly set some request data can be kept forever, consuming an increasing
// amount of memory. In case this is detected, Purge() must be called
// periodically until the problem is fixed.
func Purge(maxAge int) int {
mutex.Lock()
count := 0
if maxAge <= 0 {
count = len(data)
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
} else {
min := time.Now().Unix() - int64(maxAge)
for r := range data {
if datat[r] < min {
clear(r)
count++
}
}
}
mutex.Unlock()
return count
}
// ClearHandler wraps an http.Handler and clears request values at the end
// of a request lifetime.
func ClearHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer Clear(r)
h.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,161 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"testing"
)
type keyType int
const (
key1 keyType = iota
key2
)
func TestContext(t *testing.T) {
assertEqual := func(val interface{}, exp interface{}) {
if val != exp {
t.Errorf("Expected %v, got %v.", exp, val)
}
}
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
emptyR, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
// Get()
assertEqual(Get(r, key1), nil)
// Set()
Set(r, key1, "1")
assertEqual(Get(r, key1), "1")
assertEqual(len(data[r]), 1)
Set(r, key2, "2")
assertEqual(Get(r, key2), "2")
assertEqual(len(data[r]), 2)
//GetOk
value, ok := GetOk(r, key1)
assertEqual(value, "1")
assertEqual(ok, true)
value, ok = GetOk(r, "not exists")
assertEqual(value, nil)
assertEqual(ok, false)
Set(r, "nil value", nil)
value, ok = GetOk(r, "nil value")
assertEqual(value, nil)
assertEqual(ok, true)
// GetAll()
values := GetAll(r)
assertEqual(len(values), 3)
// GetAll() for empty request
values = GetAll(emptyR)
if values != nil {
t.Error("GetAll didn't return nil value for invalid request")
}
// GetAllOk()
values, ok = GetAllOk(r)
assertEqual(len(values), 3)
assertEqual(ok, true)
// GetAllOk() for empty request
values, ok = GetAllOk(emptyR)
assertEqual(value, nil)
assertEqual(ok, false)
// Delete()
Delete(r, key1)
assertEqual(Get(r, key1), nil)
assertEqual(len(data[r]), 2)
Delete(r, key2)
assertEqual(Get(r, key2), nil)
assertEqual(len(data[r]), 1)
// Clear()
Clear(r)
assertEqual(len(data), 0)
}
func parallelReader(r *http.Request, key string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Get(r, key)
}
done <- struct{}{}
}
func parallelWriter(r *http.Request, key, value string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Set(r, key, value)
}
done <- struct{}{}
}
func benchmarkMutex(b *testing.B, numReaders, numWriters, iterations int) {
b.StopTimer()
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
done := make(chan struct{})
b.StartTimer()
for i := 0; i < b.N; i++ {
wait := make(chan struct{})
for i := 0; i < numReaders; i++ {
go parallelReader(r, "test", iterations, wait, done)
}
for i := 0; i < numWriters; i++ {
go parallelWriter(r, "test", "123", iterations, wait, done)
}
close(wait)
for i := 0; i < numReaders+numWriters; i++ {
<-done
}
}
}
func BenchmarkMutexSameReadWrite1(b *testing.B) {
benchmarkMutex(b, 1, 1, 32)
}
func BenchmarkMutexSameReadWrite2(b *testing.B) {
benchmarkMutex(b, 2, 2, 32)
}
func BenchmarkMutexSameReadWrite4(b *testing.B) {
benchmarkMutex(b, 4, 4, 32)
}
func BenchmarkMutex1(b *testing.B) {
benchmarkMutex(b, 2, 8, 32)
}
func BenchmarkMutex2(b *testing.B) {
benchmarkMutex(b, 16, 4, 64)
}
func BenchmarkMutex3(b *testing.B) {
benchmarkMutex(b, 1, 2, 128)
}
func BenchmarkMutex4(b *testing.B) {
benchmarkMutex(b, 128, 32, 256)
}
func BenchmarkMutex5(b *testing.B) {
benchmarkMutex(b, 1024, 2048, 64)
}
func BenchmarkMutex6(b *testing.B) {
benchmarkMutex(b, 2048, 1024, 512)
}

View File

@ -0,0 +1,82 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package context stores values shared during a request lifetime.
For example, a router can set variables extracted from the URL and later
application handlers can access those values, or it can be used to store
sessions values to be saved at the end of a request. There are several
others common uses.
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
Here's the basic usage: first define the keys that you will need. The key
type is interface{} so a key can be of any type that supports equality.
Here we define a key using a custom int type to avoid name collisions:
package foo
import (
"github.com/gorilla/context"
)
type key int
const MyKey key = 0
Then set a variable. Variables are bound to an http.Request object, so you
need a request instance to set a value:
context.Set(r, MyKey, "bar")
The application can later access the variable using the same key you provided:
func MyHandler(w http.ResponseWriter, r *http.Request) {
// val is "bar".
val := context.Get(r, foo.MyKey)
// returns ("bar", true)
val, ok := context.GetOk(r, foo.MyKey)
// ...
}
And that's all about the basic usage. We discuss some other ideas below.
Any type can be stored in the context. To enforce a given type, make the key
private and wrap Get() and Set() to accept and return values of a specific
type:
type key int
const mykey key = 0
// GetMyKey returns a value for this package from the request values.
func GetMyKey(r *http.Request) SomeType {
if rv := context.Get(r, mykey); rv != nil {
return rv.(SomeType)
}
return nil
}
// SetMyKey sets a value for this package in the request values.
func SetMyKey(r *http.Request, val SomeType) {
context.Set(r, mykey, val)
}
Variables must be cleared at the end of a request, to remove all values
that were stored. This can be done in an http.Handler, after a request was
served. Just call Clear() passing the request:
context.Clear(r)
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
variables at the end of a request lifetime.
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
so if you are using either of them you don't need to clear the context manually.
*/
package context

View File

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,7 @@
mux
===
[![Build Status](https://travis-ci.org/gorilla/mux.png?branch=master)](https://travis-ci.org/gorilla/mux)
gorilla/mux is a powerful URL router and dispatcher.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/mux

Some files were not shown because too many files have changed in this diff Show More