Merge branch 'Xe/feat/unit-and-functional-testing'

Adds unit and functional testing to routed.
This commit is contained in:
Cadey Ratio 2017-10-06 16:27:03 -07:00
commit 278af58424
20 changed files with 1127 additions and 441 deletions

View File

@ -1 +1 @@
route
route-cli

View File

@ -35,7 +35,7 @@ func main() {
client, _ := tun2.NewClient(cfg)
for {
err := client.Connect()
err := client.Connect(context.Background())
if err != nil {
ln.Error(context.Background(), err, ln.Action("client connection failed"))
}

View File

@ -0,0 +1,10 @@
package elfs
import "testing"
// TestMakeName ensures that generated names are never empty.
func TestMakeName(t *testing.T) {
	if name := MakeName(); name == "" {
		t.Fatalf("MakeName had a zero output")
	}
}

View File

@ -0,0 +1,33 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
// TestTrace verifies that the Trace middleware calls through to the handler
// it wraps instead of swallowing the request.
func TestTrace(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	called := false
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	})

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("error when creating request: %v", err)
	}

	rec := httptest.NewRecorder()
	Trace(inner).ServeHTTP(rec, req.WithContext(ctx))

	if !called {
		t.Fatal("middleware Trace doesn't pass through to underlying handler")
	}
}

View File

@ -0,0 +1,41 @@
package routecrypto
import "testing"
var (
rsaPrivKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC6C94euSI3GAbszcTVvuBI4ejM/fugqe/uUyXz2bUIGemkADBh
OOkNWXFi/gnYylHRrFKOH06wxhzZWpsBMacmwx6tD7a7nKktcw7HsVFL8is0PPnp
syhWfW+DF6vMDZxkgI3iKrr9/WY/3/qUg7ga17s1JXb3SmQ2sMDTh5I6DQIET4Bo
LwKBgCBG2EmsLiVPCXwN+Mk8IGck7BHKhVpcm955VDDiuKNMuFK4F9ak3tbsKOza
UDC+JhqhB1U7/J8zABM+qVqHBwse1sJMZUEXPuGbIuw4vmEHFA+scAuwkpmRx4gA
/Ghi9eWr1rDlrRFMEF5vs18GObY7Z07GxTx/nZPx7FZ+6FqZAkEA24zob4NMKGUj
efHggZ4DFiIGDEbfbRS6a/w7VicJwI41pwhbGj7KCPZEwXYhnXR3H9UXSrowsm14
D0Wbsw4gRwJBANjvAbFVBAW8TWxLCgKx7uyHehygEBl5NY2in/8QHMjJpE7fQX5U
qutOL68A6+8P0lrtoz4VJZSnAxwkaifM8QsCQA37iRRm+Qd64OetQrHj+FhiZlrJ
LAT0CUWmADJ5KYX49B2lfNXDrXOsUG9sZ4tHKRGDt51KC/0KjMgq9BGx41MCQF0y
FxOL0s2EtXz/33V4QA9twe9xUBDY4CMts4Eyq3xlscbBBe4IjwrcKuntJ3POkGPS
Xotb9TDONmrANIqlmbECQCD8Uo0bgt8kR5bShqkbW1e5qVNz5w4+tM7Uh+oQMIGB
bC3xLJD4u2NPTwTdqKxxkeicFMKpuiGvX200M/CcoVc=
-----END RSA PRIVATE KEY-----`)
)
// TestRSA round-trips a PEM-encoded RSA private key through
// PemToRSAPrivateKey and RSAPrivateKeyToPem and checks that the encoding is
// stable (the two functions are inverses of each other).
func TestRSA(t *testing.T) {
	firstKey, err := PemToRSAPrivateKey(rsaPrivKey)
	if err != nil {
		t.Fatalf("can't parse key: %v", err)
	}

	firstPem := RSAPrivateKeyToPem(firstKey)

	secondKey, err := PemToRSAPrivateKey(firstPem)
	if err != nil {
		t.Fatalf("can't parse key: %v", err)
	}

	if secondPem := RSAPrivateKeyToPem(secondKey); string(firstPem) != string(secondPem) {
		t.Fatalf("functions are not 1:1")
	}
}

View File

@ -0,0 +1,40 @@
package routecrypto
import "testing"
// TestSecretBox exercises the secretbox key helpers end to end: generating a
// key, rendering it as a string with ShowKey, and parsing that string back
// with ParseKey. The subtests deliberately share state, so later steps guard
// against earlier failures (t.Fatal inside t.Run only aborts that subtest,
// not the parent).
func TestSecretBox(t *testing.T) {
	var (
		key *[32]byte // produced by the first subtest, read by the others
		sk  string    // string form of key, produced by ShowKey
	)

	t.Run("generate key", func(t *testing.T) {
		var err error
		key, err = GenerateKey()
		if err != nil {
			t.Fatalf("can't generate key: %v", err)
		}
	})

	// If generation failed above, the remaining subtests cannot run.
	if key == nil {
		t.Fatal("can't continue")
	}

	t.Run("show key", func(t *testing.T) {
		sk = ShowKey(key)
		if len(sk) == 0 {
			t.Fatal("expected output to be a nonzero length string")
		}
	})

	t.Run("read key", func(t *testing.T) {
		readKey, err := ParseKey(sk)
		if err != nil {
			t.Fatal(err)
		}

		// Compare array values, not pointers.
		if *key != *readKey {
			t.Fatal("key did not parse out correctly")
		}
	})
}

View File

@ -11,8 +11,12 @@ import (
"git.xeserv.us/xena/route/internal/database"
"git.xeserv.us/xena/route/internal/tun2"
proto "git.xeserv.us/xena/route/proto"
"github.com/Xe/ln"
"github.com/mtneug/pkg/ulid"
"github.com/oxtoacart/bpool"
kcp "github.com/xtaci/kcp-go"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
@ -46,6 +50,56 @@ type Config struct {
CertKey *[32]byte
}
// listenTCP accepts TLS-wrapped TCP connections on addr and hands each one to
// the tun2 server. It blocks forever and panics if the listener cannot be
// created, so it is expected to be run in its own goroutine at startup.
func (s *Server) listenTCP(ctx context.Context, addr string, tcfg *tls.Config) {
	l, err := tls.Listen("tcp", addr, tcfg)
	if err != nil {
		panic(err)
	}

	ln.Log(ctx, ln.Action("tcp+tls listening"), ln.F{"addr": l.Addr()})

	for {
		conn, err := l.Accept()
		if err != nil {
			ln.Error(ctx, err, ln.Action("accept backend client socket"))
			// conn is nil on error; touching it below would panic.
			continue
		}

		ln.Log(ctx, ln.F{
			"action":  "new backend client",
			"addr":    conn.RemoteAddr(),
			"network": conn.RemoteAddr().Network(),
		})

		go s.ts.HandleConn(conn, false)
	}
}
// listenKCP accepts KCP connections on addr, wraps each one in server-side
// TLS, and hands it to the tun2 server. It blocks forever and panics if the
// listener cannot be created, so it is expected to be run in its own
// goroutine at startup.
func (s *Server) listenKCP(ctx context.Context, addr string, tcfg *tls.Config) {
	l, err := kcp.Listen(addr)
	if err != nil {
		panic(err)
	}

	ln.Log(ctx, ln.Action("kcp+tls listening"), ln.F{"addr": l.Addr()})

	for {
		conn, err := l.Accept()
		if err != nil {
			ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
			// conn is nil on error; touching it below would panic.
			continue
		}

		ln.Log(ctx, ln.F{
			"action":  "new_client",
			"network": conn.RemoteAddr().Network(),
			"addr":    conn.RemoteAddr(),
		})

		// KCP carries raw packets; TLS is layered on top here.
		tc := tls.Server(conn, tcfg)
		go s.ts.HandleConn(tc, true)
	}
}
// New creates a new Server
func New(cfg Config) (*Server, error) {
if cfg.CertKey == nil {
@ -65,11 +119,6 @@ func New(cfg Config) (*Server, error) {
}
tcfg := &tun2.ServerConfig{
TCPAddr: cfg.BackendTCPAddr,
KCPAddr: cfg.BackendKCPAddr,
TLSConfig: &tls.Config{
GetCertificate: m.GetCertificate,
},
Storage: &storageWrapper{
Storage: db,
},
@ -79,6 +128,7 @@ func New(cfg Config) (*Server, error) {
if err != nil {
return nil, err
}
s := &Server{
cfg: &cfg,
db: db,
@ -87,13 +137,15 @@ func New(cfg Config) (*Server, error) {
Manager: m,
}
s.ts = ts
go ts.ListenAndServe()
tc := &tls.Config{
GetCertificate: m.GetCertificate,
}
gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(&tls.Config{
GetCertificate: m.GetCertificate,
InsecureSkipVerify: true,
})))
go s.listenKCP(context.Background(), cfg.BackendKCPAddr, tc)
go s.listenTCP(context.Background(), cfg.BackendTCPAddr, tc)
// gRPC setup
gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(tc)))
proto.RegisterBackendsServer(gs, &Backend{Server: s})
proto.RegisterRoutesServer(gs, &Route{Server: s})
@ -140,7 +192,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Director: s.Director,
Transport: s.ts,
FlushInterval: 1 * time.Second,
//BufferPool: bpool.NewBytePool(256, 4096),
BufferPool: bpool.NewBytePool(256, 4096),
}
rp.ServeHTTP(w, r)

View File

@ -0,0 +1,29 @@
package server
import (
"net/http"
"testing"
)
// TestDirector checks that Server.Director strips client-identifying headers
// from requests before they are proxied to a backend.
func TestDirector(t *testing.T) {
	srv := &Server{}

	req, err := http.NewRequest("GET", "https://cetacean.club/", nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Add("X-Forwarded-For", "Rick-James")
	req.Header.Add("X-Client-Ip", "56.32.51.84")

	srv.Director(req)

	for _, header := range []string{"X-Forwarded-For", "X-Client-Ip"} {
		header := header
		t.Run(header, func(t *testing.T) {
			if val := req.Header.Get(header); val != "" {
				t.Fatalf("expected header %q to have no value, got: %v", header, val)
			}
		})
	}
}

View File

@ -1,5 +1,7 @@
package tun2
import "time"
// Backend is the public state of an individual Connection.
type Backend struct {
ID string
@ -10,3 +12,72 @@ type Backend struct {
Host string
Usable bool
}
// backendMatcher is a predicate over backend connections, used to select
// which connections getBackendsForMatcher reports on.
type backendMatcher func(*Connection) bool

// getBackendsForMatcher snapshots every connection accepted by bm into a
// public Backend value. It holds the connection lock for the whole scan, so
// matchers must be cheap and must not block.
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	var result []Backend

	for _, c := range s.conns {
		if !bm(c) {
			continue
		}

		// The wire protocol is part of the public Backend record.
		protocol := "tcp"
		if c.isKCP {
			protocol = "kcp"
		}

		result = append(result, Backend{
			ID:     c.id,
			Proto:  protocol,
			User:   c.user,
			Domain: c.domain,
			// Phi is the failure detector's current suspicion level for
			// this connection; higher means more likely dead.
			Phi:    float32(c.detector.Phi(time.Now())),
			Host:   c.conn.RemoteAddr().String(),
			Usable: c.usable,
		})
	}

	return result
}
// KillBackend forcibly disconnects a given backend but doesn't offer a way to
// "ban" it from reconnecting.
func (s *Server) KillBackend(id string) error {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	for _, conn := range s.conns {
		if conn.id != id {
			continue
		}

		conn.cancel()
		return nil
	}

	return ErrNoSuchBackend
}
// GetBackendsForDomain fetches all backends connected to this server associated
// to a single public domain name.
func (s *Server) GetBackendsForDomain(domain string) []Backend {
	byDomain := func(c *Connection) bool { return c.domain == domain }
	return s.getBackendsForMatcher(byDomain)
}

// GetBackendsForUser fetches all backends connected to this server owned by a
// given user by username.
func (s *Server) GetBackendsForUser(uname string) []Backend {
	byUser := func(c *Connection) bool { return c.user == uname }
	return s.getBackendsForMatcher(byUser)
}

// GetAllBackends fetches every backend connected to this server.
func (s *Server) GetAllBackends() []Backend {
	everything := func(*Connection) bool { return true }
	return s.getBackendsForMatcher(everything)
}

View File

@ -1,6 +1,7 @@
package tun2
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
@ -14,10 +15,14 @@ import (
"github.com/xtaci/smux"
)
// Client connects to a remote tun2 server and sets up authentication before routing
// individual HTTP requests to discrete streams that are reverse proxied to the eventual
// backend.
type Client struct {
cfg *ClientConfig
}
// ClientConfig configures client with settings that the user provides.
type ClientConfig struct {
TLSConfig *tls.Config
ConnType string
@ -25,8 +30,12 @@ type ClientConfig struct {
Token string
Domain string
BackendURL string
// internal use only
forceTCPClear bool
}
// NewClient constructs an instance of Client with a given ClientConfig.
func NewClient(cfg *ClientConfig) (*Client, error) {
if cfg == nil {
return nil, errors.New("tun2: client config needed")
@ -39,7 +48,12 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
return c, nil
}
func (c *Client) Connect() error {
// Connect dials the remote server and negotiates a client session with its
// configured server address. This will then continuously proxy incoming HTTP
// requests to the backend HTTP server.
//
// This is a blocking function.
func (c *Client) Connect(ctx context.Context) error {
return c.connect(c.cfg.ServerAddr)
}
@ -57,7 +71,12 @@ func (c *Client) connect(serverAddr string) error {
switch c.cfg.ConnType {
case "tcp":
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
if c.cfg.forceTCPClear {
conn, err = net.Dial("tcp", serverAddr)
} else {
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
}
if err != nil {
return err
}
@ -117,15 +136,12 @@ func (c *Client) connect(serverAddr string) error {
return nil
}
// smuxListener wraps a smux session as a net.Listener.
type smuxListener struct {
conn net.Conn
session *smux.Session
}
var (
_ net.Listener = &smuxListener{} // interface check
)
func (sl *smuxListener) Accept() (net.Conn, error) {
return sl.session.AcceptStream()
}

View File

@ -0,0 +1,21 @@
package tun2
import (
"net"
"testing"
)
func TestNewClientNullConfig(t *testing.T) {
_, err := NewClient(nil)
if err == nil {
t.Fatalf("expected NewClient(nil) to fail, got non-failure")
}
}
func TestSmuxListenerIsNetListener(t *testing.T) {
var sl interface{} = &smuxListener{}
_, ok := sl.(net.Listener)
if !ok {
t.Fatalf("smuxListener does not implement net.Listener")
}
}

View File

@ -2,11 +2,10 @@ package tun2
import (
"bufio"
"bytes"
"context"
"io/ioutil"
"net"
"net/http"
"strconv"
"time"
"github.com/Xe/ln"
@ -127,7 +126,12 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil {
return nil, errors.Wrap(err, ErrCantOpenSessionStream.Error())
}
defer stream.Close()
go func() {
time.Sleep(30 * time.Minute)
stream.Close()
}()
err = req.Write(stream)
if err != nil {
@ -142,13 +146,13 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) {
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
cl := resp.Header.Get("Content-Length")
asInt, err := strconv.Atoi(cl)
if err != nil {
return nil, errors.Wrap(err, "can't read response body")
return nil, err
}
resp.Body = ioutil.NopCloser(bytes.NewBuffer(body))
resp.ContentLength = int64(len(body))
resp.ContentLength = int64(asInt)
return resp, nil
}

View File

@ -3,10 +3,10 @@ package tun2
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
@ -19,7 +19,6 @@ import (
failure "github.com/dgryski/go-failure"
"github.com/mtneug/pkg/ulid"
cmap "github.com/streamrail/concurrent-map"
kcp "github.com/xtaci/kcp-go"
"github.com/xtaci/smux"
)
@ -30,412 +29,7 @@ var (
ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
)
// ServerConfig ...
type ServerConfig struct {
TCPAddr string
KCPAddr string
TLSConfig *tls.Config
SmuxConf *smux.Config
Storage Storage
}
// Storage is the minimal subset of features that tun2's Server needs out of a
// persistence layer.
type Storage interface {
HasToken(token string) (user string, scopes []string, err error)
HasRoute(domain string) (user string, err error)
}
// Server routes frontend HTTP traffic to backend TCP traffic.
type Server struct {
cfg *ServerConfig
connlock sync.Mutex
conns map[net.Conn]*Connection
domains cmap.ConcurrentMap
}
// NewServer creates a new Server instance with a given config, acquiring all
// relevant resources.
func NewServer(cfg *ServerConfig) (*Server, error) {
if cfg == nil {
return nil, errors.New("tun2: config must be specified")
}
if cfg.SmuxConf == nil {
cfg.SmuxConf = smux.DefaultConfig()
}
cfg.SmuxConf.KeepAliveInterval = time.Second
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
server := &Server{
cfg: cfg,
conns: map[net.Conn]*Connection{},
domains: cmap.New(),
}
return server, nil
}
type backendMatcher func(*Connection) bool
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
s.connlock.Lock()
defer s.connlock.Unlock()
var result []Backend
for _, c := range s.conns {
if !bm(c) {
continue
}
protocol := "tcp"
if c.isKCP {
protocol = "kcp"
}
result = append(result, Backend{
ID: c.id,
Proto: protocol,
User: c.user,
Domain: c.domain,
Phi: float32(c.detector.Phi(time.Now())),
Host: c.conn.RemoteAddr().String(),
Usable: c.usable,
})
}
return result
}
func (s *Server) KillBackend(id string) error {
s.connlock.Lock()
defer s.connlock.Unlock()
for _, c := range s.conns {
if c.id == id {
c.cancel()
return nil
}
}
return ErrNoSuchBackend
}
func (s *Server) GetBackendsForDomain(domain string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.domain == domain
})
}
func (s *Server) GetBackendsForUser(uname string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.user == uname
})
}
func (s *Server) GetAllBackends() []Backend {
return s.getBackendsForMatcher(func(*Connection) bool { return true })
}
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
// traffic to and from them.
func (s *Server) ListenAndServe() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ln.Log(ctx, ln.F{
"action": "listen_and_serve_called",
})
if s.cfg.TCPAddr != "" {
go func() {
l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig)
if err != nil {
panic(err)
}
ln.Log(ctx, ln.F{
"action": "tcp+tls_listening",
"addr": l.Addr(),
})
for {
conn, err := l.Accept()
if err != nil {
ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()})
continue
}
ln.Log(ctx, ln.F{
"action": "new_client",
"kcp": false,
"addr": conn.RemoteAddr(),
})
go s.HandleConn(conn, false)
}
}()
}
if s.cfg.KCPAddr != "" {
go func() {
l, err := kcp.Listen(s.cfg.KCPAddr)
if err != nil {
panic(err)
}
ln.Log(ctx, ln.F{
"action": "kcp+tls_listening",
"addr": l.Addr(),
})
for {
conn, err := l.Accept()
if err != nil {
ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
}
ln.Log(ctx, ln.F{
"action": "new_client",
"kcp": true,
"addr": conn.RemoteAddr(),
})
tc := tls.Server(conn, s.cfg.TLSConfig)
go s.HandleConn(tc, true)
}
}()
}
// XXX experimental, might get rid of this inside this process
go func() {
for {
time.Sleep(time.Second)
now := time.Now()
s.connlock.Lock()
for _, c := range s.conns {
failureChance := c.detector.Phi(now)
if failureChance > 0.8 {
ln.Log(ctx, c.F(), ln.F{
"action": "phi_failure_detection",
"value": failureChance,
})
}
}
s.connlock.Unlock()
}
}()
return nil
}
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
// the currently connected backend.
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
// XXX TODO clean this up it's really ugly.
defer c.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
session, err := smux.Server(c, s.cfg.SmuxConf)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "session_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
c.Close()
return
}
defer session.Close()
controlStream, err := session.OpenStream()
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "control_stream_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
return
}
defer controlStream.Close()
csd := json.NewDecoder(controlStream)
auth := &Auth{}
err = csd.Decode(auth)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "control_stream_auth_decoding_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
return
}
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "nosuch_domain",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
return
}
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "nosuch_token",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
return
}
ok := false
for _, sc := range scopes {
if sc == "connect" {
ok = true
break
}
}
if !ok {
ln.Error(ctx, ErrAuthMismatch, ln.F{
"action": "token_not_authorized",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
}
if routeUser != tokenUser {
ln.Error(ctx, ErrAuthMismatch, ln.F{
"action": "auth_mismatch",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
return
}
connection := &Connection{
id: ulid.New().String(),
conn: c,
isKCP: isKCP,
session: session,
user: tokenUser,
domain: auth.Domain,
cf: cancel,
detector: failure.New(15, 1),
Auth: auth,
}
defer func() {
if r := recover(); r != nil {
ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
}
}()
ln.Log(ctx, ln.F{
"action": "backend_connected",
}, connection.F())
s.connlock.Lock()
s.conns[c] = connection
s.connlock.Unlock()
var conns []*Connection
val, ok := s.domains.Get(auth.Domain)
if ok {
conns, ok = val.([]*Connection)
if !ok {
conns = nil
s.domains.Remove(auth.Domain)
}
}
conns = append(conns, connection)
s.domains.Set(auth.Domain, conns)
connection.usable = true
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := connection.Ping()
if err != nil {
connection.cancel()
}
case <-ctx.Done():
s.RemoveConn(ctx, connection)
connection.Close()
return
}
}
}
// RemoveConn removes a connection.
func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
s.connlock.Lock()
delete(s.conns, connection.conn)
s.connlock.Unlock()
auth := connection.Auth
var conns []*Connection
val, ok := s.domains.Get(auth.Domain)
if ok {
conns, ok = val.([]*Connection)
if !ok {
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
"action": "looking_up_for_disconnect_removal",
})
return
}
}
for i, cntn := range conns {
if cntn.id == connection.id {
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
}
}
if len(conns) != 0 {
s.domains.Set(auth.Domain, conns)
} else {
s.domains.Remove(auth.Domain)
}
ln.Log(ctx, connection.F(), ln.F{
"action": "client_disconnecting",
})
}
// gen502Page creates the page that is shown when a backend is not connected to a given route.
func gen502Page(req *http.Request) *http.Response {
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
@ -469,6 +63,347 @@ func gen502Page(req *http.Request) *http.Response {
return resp
}
// ServerConfig ...
type ServerConfig struct {
SmuxConf *smux.Config
Storage Storage
}
// Storage is the minimal subset of features that tun2's Server needs out of a
// persistence layer.
type Storage interface {
HasToken(token string) (user string, scopes []string, err error)
HasRoute(domain string) (user string, err error)
}
// Server routes frontend HTTP traffic to backend TCP traffic.
type Server struct {
cfg *ServerConfig
ctx context.Context
cancel context.CancelFunc
connlock sync.Mutex
conns map[net.Conn]*Connection
domains cmap.ConcurrentMap
}
// NewServer creates a new Server instance with a given config, acquiring all
// relevant resources.
func NewServer(cfg *ServerConfig) (*Server, error) {
if cfg == nil {
return nil, errors.New("tun2: config must be specified")
}
if cfg.SmuxConf == nil {
cfg.SmuxConf = smux.DefaultConfig()
cfg.SmuxConf.KeepAliveInterval = time.Second
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
}
ctx, cancel := context.WithCancel(context.Background())
server := &Server{
cfg: cfg,
conns: map[net.Conn]*Connection{},
domains: cmap.New(),
ctx: ctx,
cancel: cancel,
}
go server.phiDetectionLoop(ctx)
return server, nil
}
// Close stops the background tasks for this Server.
func (s *Server) Close() {
s.cancel()
}
// Wait blocks until the server context is cancelled, e.g. by Close.
func (s *Server) Wait() {
	// A bare receive is equivalent to (and simpler than) looping over a
	// single-case select (staticcheck S1000).
	<-s.ctx.Done()
}
// Listen passes this Server a given net.Listener to accept backend connections.
//
// It blocks, accepting connections until the server context is cancelled, so
// it is meant to be run in its own goroutine, once per listener.
func (s *Server) Listen(l net.Listener, isKCP bool) {
	ctx := s.ctx

	f := ln.F{
		"listener_addr":    l.Addr(),
		"listener_network": l.Addr().Network(),
	}

	for {
		// Stop accepting once the server is shutting down. NOTE(review): a
		// blocked Accept below is only interrupted when the listener itself
		// is closed; cancellation alone does not unblock it.
		select {
		case <-ctx.Done():
			return
		default:
		}

		conn, err := l.Accept()
		if err != nil {
			ln.Error(ctx, err, f, ln.Action("accept connection"))
			continue
		}

		ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{
			"conn_addr":    conn.RemoteAddr(),
			"conn_network": conn.RemoteAddr().Network(),
		})

		// Each backend connection is serviced on its own goroutine.
		go s.HandleConn(conn, isKCP)
	}
}
// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector]
// for each of the backends connected to the Server. This is fairly experimental and
// may be removed.
//
// It wakes every five seconds and only logs suspicious connections; it never
// disconnects them itself. It exits when ctx is cancelled.
//
// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784
func (s *Server) phiDetectionLoop(ctx context.Context) {
	t := time.NewTicker(5 * time.Second)
	defer t.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			now := time.Now()

			// Holds the connection lock for the whole sweep.
			s.connlock.Lock()
			for _, c := range s.conns {
				failureChance := c.detector.Phi(now)
				const thresh = 0.9 // the threshold for phi failure detection causing logs
				if failureChance > thresh {
					ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{
						"value":     failureChance,
						"threshold": thresh,
					})
				}
			}
			s.connlock.Unlock()
		}
	}
}
// backendAuthv1 runs a simple backend authentication check. It expects the
// client to write a json-encoded instance of Auth. This is then checked
// for token validity and domain matching.
//
// This returns the user that was authenticated and the domain they identified
// with.
func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) {
	// f accumulates log context; "stage" is overwritten as the check
	// progresses so a failure log shows how far authentication got.
	f := ln.F{
		"action":               "backend authentication",
		"backend_auth_version": 1,
	}

	f["stage"] = "json decoding"
	d := json.NewDecoder(st)
	var auth Auth
	err := d.Decode(&auth)
	if err != nil {
		ln.Error(ctx, err, f)
		return "", nil, err
	}

	f["auth_domain"] = auth.Domain

	// The requested domain must already be routed to some user.
	f["stage"] = "checking domain"
	routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
	if err != nil {
		ln.Error(ctx, err, f)
		return "", nil, err
	}

	f["route_user"] = routeUser

	// The presented token must exist and belong to some user.
	f["stage"] = "checking token"
	tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
	if err != nil {
		ln.Error(ctx, err, f)
		return "", nil, err
	}

	f["token_user"] = tokenUser

	// The token must carry the "connect" scope to attach a backend.
	f["stage"] = "checking token scopes"
	ok := false
	for _, sc := range scopes {
		if sc == "connect" {
			ok = true
			break
		}
	}

	if !ok {
		ln.Error(ctx, ErrAuthMismatch, f)
		return "", nil, ErrAuthMismatch
	}

	// Finally, the token's owner must also own the route.
	f["stage"] = "user verification"
	if routeUser != tokenUser {
		ln.Error(ctx, ErrAuthMismatch, f)
		return "", nil, ErrAuthMismatch
	}

	return routeUser, &auth, nil
}
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
// the currently connected backend.
//
// It authenticates the backend over a smux control stream, registers the
// resulting Connection, then pings it every five seconds until the
// per-connection context is cancelled. This blocks for the life of the
// connection.
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
	defer c.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	f := ln.F{
		"local":  c.LocalAddr().String(),
		"remote": c.RemoteAddr().String(),
	}

	session, err := smux.Server(c, s.cfg.SmuxConf)
	if err != nil {
		ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
		return
	}
	defer session.Close()

	// The server opens the control stream; the client is expected to answer
	// on it with its Auth payload.
	controlStream, err := session.OpenStream()
	if err != nil {
		ln.Error(ctx, err, f, ln.Action("opening control stream"))
		return
	}
	defer controlStream.Close()

	// backendAuthv1 logs its own failures, so just bail here.
	user, auth, err := s.backendAuthv1(ctx, controlStream)
	if err != nil {
		return
	}

	connection := &Connection{
		id:       ulid.New().String(),
		conn:     c,
		isKCP:    isKCP,
		session:  session,
		user:     user,
		domain:   auth.Domain,
		cf:       cancel,
		detector: failure.New(15, 1),
		Auth:     auth,
	}

	// Don't let a panic in this handler take down the whole process.
	defer func() {
		if r := recover(); r != nil {
			ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
		}
	}()

	ln.Log(ctx, connection, ln.Action("backend successfully connected"))

	s.addConn(ctx, connection)

	connection.usable = true // XXX set this to true once health checks pass?

	// Ping on an interval; a failed ping cancels the connection's context,
	// which triggers the cleanup branch below.
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			err := connection.Ping()
			if err != nil {
				connection.cancel()
			}

		// case <-s.ctx.Done():
		// 	ln.Log(ctx, connection, ln.Action("server context finished"))
		// 	s.removeConn(ctx, connection)
		// 	connection.Close()
		// 	return

		case <-ctx.Done():
			ln.Log(ctx, connection, ln.Action("client context finished"))
			s.removeConn(ctx, connection)
			connection.Close()
			return
		}
	}
}
// addConn adds a connection to the pool of backend connections.
func (s *Server) addConn(ctx context.Context, connection *Connection) {
s.connlock.Lock()
s.conns[connection.conn] = connection
s.connlock.Unlock()
var conns []*Connection
val, ok := s.domains.Get(connection.domain)
if ok {
conns, ok = val.([]*Connection)
if !ok {
conns = nil
s.domains.Remove(connection.domain)
}
}
conns = append(conns, connection)
s.domains.Set(connection.domain, conns)
}
// removeConn removes a connection from pool of backend connections.
//
// It unregisters the connection from both the conn table and the per-domain
// routing list, dropping the domain entry entirely once its last backend is
// gone.
func (s *Server) removeConn(ctx context.Context, connection *Connection) {
	s.connlock.Lock()
	delete(s.conns, connection.conn)
	s.connlock.Unlock()

	auth := connection.Auth

	var conns []*Connection

	val, ok := s.domains.Get(auth.Domain)
	if ok {
		conns, ok = val.([]*Connection)
		if !ok {
			// The domain map held something that wasn't a []*Connection;
			// there is nothing sane to remove from it.
			ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
			return
		}
	}

	// Swap-delete: overwrite the match with the last element and shrink the
	// slice. Ordering of a domain's backends is not significant.
	for i, cntn := range conns {
		if cntn.id == connection.id {
			conns[i] = conns[len(conns)-1]
			conns = conns[:len(conns)-1]
		}
	}

	if len(conns) != 0 {
		s.domains.Set(auth.Domain, conns)
	} else {
		s.domains.Remove(auth.Domain)
	}
}
// RoundTrip sends a HTTP request to a backend and then returns its response.
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
var conns []*Connection

View File

@ -0,0 +1,324 @@
package tun2
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/Xe/uuid"
)
// testing constants
const (
user = "shachi"
token = "orcaz r kewl"
noPermToken = "aw heck"
otherUserToken = "even more heck"
domain = "cetacean.club"
)
func TestNewServerNullConfig(t *testing.T) {
_, err := NewServer(nil)
if err == nil {
t.Fatalf("expected NewServer(nil) to fail, got non-failure")
}
}
func TestGen502Page(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req, err := http.NewRequest("GET", "http://cetacean.club", nil)
if err != nil {
t.Fatal(err)
}
substring := uuid.New()
req = req.WithContext(ctx)
req.Header.Add("X-Request-Id", substring)
req.Host = "cetacean.club"
resp := gen502Page(req)
if resp == nil {
t.Fatalf("expected response to be non-nil")
}
if resp.Body != nil {
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(data), substring) {
fmt.Println(string(data))
t.Fatalf("502 page did not contain needed substring %q", substring)
}
}
}
func TestBackendAuthV1(t *testing.T) {
st := MockStorage()
s, err := NewServer(&ServerConfig{
Storage: st,
})
if err != nil {
t.Fatal(err)
}
defer s.Close()
st.AddRoute(domain, user)
st.AddToken(token, user, []string{"connect"})
st.AddToken(noPermToken, user, nil)
st.AddToken(otherUserToken, "cadey", []string{"connect"})
cases := []struct {
name string
auth Auth
wantErr bool
}{
{
name: "basic everything should work",
auth: Auth{
Token: token,
Domain: domain,
},
wantErr: false,
},
{
name: "invalid domain",
auth: Auth{
Token: token,
Domain: "aw.heck",
},
wantErr: true,
},
{
name: "invalid token",
auth: Auth{
Token: "asdfwtweg",
Domain: domain,
},
wantErr: true,
},
{
name: "invalid token scopes",
auth: Auth{
Token: noPermToken,
Domain: domain,
},
wantErr: true,
},
{
name: "user token doesn't match domain owner",
auth: Auth{
Token: otherUserToken,
Domain: domain,
},
wantErr: true,
},
}
for _, cs := range cases {
t.Run(cs.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
data, err := json.Marshal(cs.auth)
if err != nil {
t.Fatal(err)
}
_, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data))
if cs.wantErr && err == nil {
t.Fatalf("auth did not err as expected")
}
if !cs.wantErr && err != nil {
t.Fatalf("unexpected auth err: %v", err)
}
})
}
}
func TestBackendRouting(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
st := MockStorage()
st.AddRoute(domain, user)
st.AddToken(token, user, []string{"connect"})
s, err := NewServer(&ServerConfig{
Storage: st,
})
if err != nil {
t.Fatal(err)
}
defer s.Close()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
go s.Listen(l, false)
cases := []struct {
name string
should200 bool
handler http.HandlerFunc
}{
{
name: "200 everything's okay",
should200: true,
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK)
}),
},
}
for _, cs := range cases {
t.Run(cs.name, func(t *testing.T) {
ts := httptest.NewServer(cs.handler)
defer ts.Close()
cc := &ClientConfig{
ConnType: "tcp",
ServerAddr: l.Addr().String(),
Token: token,
BackendURL: ts.URL,
Domain: domain,
forceTCPClear: true,
}
c, err := NewClient(cc)
if err != nil {
t.Fatal(err)
}
go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up
time.Sleep(time.Second)
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
if err != nil {
t.Fatal(err)
}
resp, err := s.RoundTrip(req)
if err != nil {
t.Fatalf("error in doing round trip: %v", err)
}
if cs.should200 && resp.StatusCode != http.StatusOK {
resp.Write(os.Stdout)
t.Fatalf("got status %d instead of StatusOK", resp.StatusCode)
}
})
}
}
// setupTestServer constructs a Server backed by fresh mock storage with the
// default test route and token installed, and starts it listening on an
// ephemeral localhost TCP port.
//
// The caller owns the returned server and listener and must close both.
// Closing the server here (the old `defer s.Close()`) cancelled the server's
// context before the caller ever got to use it.
func setupTestServer() (*Server, *mockStorage, net.Listener, error) {
	st := MockStorage()
	st.AddRoute(domain, user)
	st.AddToken(token, user, []string{"connect"})

	s, err := NewServer(&ServerConfig{
		Storage: st,
	})
	if err != nil {
		return nil, nil, nil, err
	}

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		// The server was created but won't be returned; release it.
		s.Close()
		return nil, nil, nil, err
	}

	go s.Listen(l, false)

	return s, st, l, nil
}
func BenchmarkHTTP200(b *testing.B) {
b.Skip("this benchmark doesn't work yet")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s, _, l, err := setupTestServer()
if err != nil {
b.Fatal(err)
}
defer s.Close()
defer l.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer ts.Close()
cc := &ClientConfig{
ConnType: "tcp",
ServerAddr: l.Addr().String(),
Token: token,
BackendURL: ts.URL,
Domain: domain,
forceTCPClear: true,
}
c, err := NewClient(cc)
if err != nil {
b.Fatal(err)
}
go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up
for {
r := s.GetBackendsForDomain(domain)
if len(r) == 0 {
time.Sleep(125 * time.Millisecond)
continue
}
break
}
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
if err != nil {
b.Fatal(err)
}
_, err = s.RoundTrip(req)
if err != nil {
b.Fatalf("got error on initial request exchange: %v", err)
}
for n := 0; n < b.N; n++ {
resp, err := s.RoundTrip(req)
if err != nil {
b.Fatalf("got error on %d: %v", n, err)
}
if resp.StatusCode != http.StatusOK {
b.Fail()
b.Logf("got %d instead of 200", resp.StatusCode)
}
}
}

View File

@ -0,0 +1,99 @@
package tun2
import (
"errors"
"sync"
"testing"
)
// MockStorage returns an empty, ready-to-use in-memory Storage implementation
// for use in tests.
func MockStorage() *mockStorage {
	ms := &mockStorage{}
	ms.tokens = make(map[string]mockToken)
	ms.domains = make(map[string]string)
	return ms
}

// mockToken records the owner and permission scopes of a single token.
type mockToken struct {
	user   string
	scopes []string
}

// mockStorage is a simple mock of the Storage interface suitable for testing.
// All methods are safe for concurrent use.
type mockStorage struct {
	sync.Mutex

	tokens  map[string]mockToken
	domains map[string]string
}

// AddToken registers a token belonging to user with the given scopes.
func (ms *mockStorage) AddToken(token, user string, scopes []string) {
	ms.Lock()
	defer ms.Unlock()

	ms.tokens[token] = mockToken{
		user:   user,
		scopes: scopes,
	}
}

// AddRoute registers a domain as owned by user.
func (ms *mockStorage) AddRoute(domain, user string) {
	ms.Lock()
	defer ms.Unlock()

	ms.domains[domain] = user
}

// HasToken reports the owner and scopes of token, or an error if it was
// never registered.
func (ms *mockStorage) HasToken(token string) (string, []string, error) {
	ms.Lock()
	defer ms.Unlock()

	tok, found := ms.tokens[token]
	if !found {
		return "", nil, errors.New("no such token")
	}

	return tok.user, tok.scopes, nil
}

// HasRoute reports the owner of domain, or an error if it was never
// registered.
func (ms *mockStorage) HasRoute(domain string) (string, error) {
	ms.Lock()
	defer ms.Unlock()

	owner, found := ms.domains[domain]
	if !found {
		return "", errors.New("no such route")
	}

	return owner, nil
}
// TestMockStorage sanity-checks the mockStorage test double itself: values
// registered with AddToken/AddRoute must round-trip through
// HasToken/HasRoute.
func TestMockStorage(t *testing.T) {
	ms := MockStorage()

	t.Run("token", func(t *testing.T) {
		ms.AddToken(token, user, []string{"connect"})

		us, sc, err := ms.HasToken(token)
		if err != nil {
			t.Fatal(err)
		}

		if us != user {
			t.Fatalf("username was %q, expected %q", us, user)
		}

		if sc[0] != "connect" {
			t.Fatalf("token expected to only have one scope, connect")
		}
	})

	t.Run("domain", func(t *testing.T) {
		ms.AddRoute(domain, user)

		us, err := ms.HasRoute(domain)
		if err != nil {
			t.Fatal(err)
		}

		if us != user {
			t.Fatalf("username was %q, expected %q", us, user)
		}
	})
}

View File

@ -187,8 +187,17 @@ func Package() {
}
}
// Version is the version as git reports.
func Version() {
ver, err := gitTag()
qod.ANE(err)
qod.Printlnf("route-%s", ver)
}
// Test runs all of the functional and unit tests for the project.
func Test() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
shouldWork(ctx, nil, wd, "go", "test", "-v", "./...")
}

View File

@ -32,7 +32,7 @@ func mustEnv(key string, def string) string {
return val
}
func doHttpAgent() {
func doHTTPAgent() {
go func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -47,17 +47,17 @@ func doHttpAgent() {
}
client, _ := tun2.NewClient(cfg)
err := client.Connect()
err := client.Connect(ctx)
if err != nil {
ln.Error(ctx, err, ln.Action("client connection error, restarting"))
time.Sleep(5 * time.Second)
doHttpAgent()
doHTTPAgent()
}
}()
}
func init() {
doHttpAgent()
doHTTPAgent()
}

View File

@ -0,0 +1,3 @@
package main
func main() {}

View File

@ -1 +0,0 @@
xena@greedo.xeserv.us.17867:1486865539