Merge branch 'Xe/feat/unit-and-functional-testing'
Adds unit and functional testing to routed.
This commit is contained in:
commit
278af58424
|
@ -1 +1 @@
|
|||
route
|
||||
route-cli
|
||||
|
|
|
@ -35,7 +35,7 @@ func main() {
|
|||
client, _ := tun2.NewClient(cfg)
|
||||
|
||||
for {
|
||||
err := client.Connect()
|
||||
err := client.Connect(context.Background())
|
||||
if err != nil {
|
||||
ln.Error(context.Background(), err, ln.Action("client connection failed"))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
package elfs
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMakeName(t *testing.T) {
|
||||
n := MakeName()
|
||||
if len(n) == 0 {
|
||||
t.Fatalf("MakeName had a zero output")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTrace(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var executed bool
|
||||
var handler http.Handler = Trace(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
executed = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error when creating request: %v", err)
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
if !executed {
|
||||
t.Fatal("middleware Trace doesn't pass through to underlying handler")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package routecrypto
|
||||
|
||||
import "testing"
|
||||
|
||||
var (
|
||||
rsaPrivKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQC6C94euSI3GAbszcTVvuBI4ejM/fugqe/uUyXz2bUIGemkADBh
|
||||
OOkNWXFi/gnYylHRrFKOH06wxhzZWpsBMacmwx6tD7a7nKktcw7HsVFL8is0PPnp
|
||||
syhWfW+DF6vMDZxkgI3iKrr9/WY/3/qUg7ga17s1JXb3SmQ2sMDTh5I6DQIET4Bo
|
||||
LwKBgCBG2EmsLiVPCXwN+Mk8IGck7BHKhVpcm955VDDiuKNMuFK4F9ak3tbsKOza
|
||||
UDC+JhqhB1U7/J8zABM+qVqHBwse1sJMZUEXPuGbIuw4vmEHFA+scAuwkpmRx4gA
|
||||
/Ghi9eWr1rDlrRFMEF5vs18GObY7Z07GxTx/nZPx7FZ+6FqZAkEA24zob4NMKGUj
|
||||
efHggZ4DFiIGDEbfbRS6a/w7VicJwI41pwhbGj7KCPZEwXYhnXR3H9UXSrowsm14
|
||||
D0Wbsw4gRwJBANjvAbFVBAW8TWxLCgKx7uyHehygEBl5NY2in/8QHMjJpE7fQX5U
|
||||
qutOL68A6+8P0lrtoz4VJZSnAxwkaifM8QsCQA37iRRm+Qd64OetQrHj+FhiZlrJ
|
||||
LAT0CUWmADJ5KYX49B2lfNXDrXOsUG9sZ4tHKRGDt51KC/0KjMgq9BGx41MCQF0y
|
||||
FxOL0s2EtXz/33V4QA9twe9xUBDY4CMts4Eyq3xlscbBBe4IjwrcKuntJ3POkGPS
|
||||
Xotb9TDONmrANIqlmbECQCD8Uo0bgt8kR5bShqkbW1e5qVNz5w4+tM7Uh+oQMIGB
|
||||
bC3xLJD4u2NPTwTdqKxxkeicFMKpuiGvX200M/CcoVc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
func TestRSA(t *testing.T) {
|
||||
pk, err := PemToRSAPrivateKey(rsaPrivKey)
|
||||
if err != nil {
|
||||
t.Fatalf("can't parse key: %v", err)
|
||||
}
|
||||
|
||||
pkd := RSAPrivateKeyToPem(pk)
|
||||
|
||||
pk2, err := PemToRSAPrivateKey(pkd)
|
||||
if err != nil {
|
||||
t.Fatalf("can't parse key: %v", err)
|
||||
}
|
||||
|
||||
pkd2 := RSAPrivateKeyToPem(pk2)
|
||||
|
||||
if string(pkd) != string(pkd2) {
|
||||
t.Fatalf("functions are not 1:1")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package routecrypto
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSecretBox(t *testing.T) {
|
||||
var (
|
||||
key *[32]byte
|
||||
sk string
|
||||
)
|
||||
|
||||
t.Run("generate key", func(t *testing.T) {
|
||||
var err error
|
||||
key, err = GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("can't generate key: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if key == nil {
|
||||
t.Fatal("can't continue")
|
||||
}
|
||||
|
||||
t.Run("show key", func(t *testing.T) {
|
||||
sk = ShowKey(key)
|
||||
if len(sk) == 0 {
|
||||
t.Fatal("expected output to be a nonzero length string")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("read key", func(t *testing.T) {
|
||||
readKey, err := ParseKey(sk)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if *key != *readKey {
|
||||
t.Fatal("key did not parse out correctly")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -11,8 +11,12 @@ import (
|
|||
"git.xeserv.us/xena/route/internal/database"
|
||||
"git.xeserv.us/xena/route/internal/tun2"
|
||||
proto "git.xeserv.us/xena/route/proto"
|
||||
"github.com/Xe/ln"
|
||||
"github.com/mtneug/pkg/ulid"
|
||||
"github.com/oxtoacart/bpool"
|
||||
kcp "github.com/xtaci/kcp-go"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
@ -46,6 +50,56 @@ type Config struct {
|
|||
CertKey *[32]byte
|
||||
}
|
||||
|
||||
func (s *Server) listenTCP(ctx context.Context, addr string, tcfg *tls.Config) {
|
||||
l, err := tls.Listen("tcp", addr, tcfg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.Action("tcp+tls listening"), ln.F{"addr": l.Addr()})
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.Action("accept backend client socket"))
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new backend client",
|
||||
"addr": conn.RemoteAddr(),
|
||||
"network": conn.RemoteAddr().Network(),
|
||||
})
|
||||
|
||||
go s.ts.HandleConn(conn, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenKCP(ctx context.Context, addr string, tcfg *tls.Config) {
|
||||
l, err := kcp.Listen(addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.Action("kcp+tls listening"), ln.F{"addr": l.Addr()})
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"network": conn.RemoteAddr().Network(),
|
||||
"addr": conn.RemoteAddr(),
|
||||
})
|
||||
|
||||
tc := tls.Server(conn, tcfg)
|
||||
|
||||
go s.ts.HandleConn(tc, true)
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Server
|
||||
func New(cfg Config) (*Server, error) {
|
||||
if cfg.CertKey == nil {
|
||||
|
@ -65,11 +119,6 @@ func New(cfg Config) (*Server, error) {
|
|||
}
|
||||
|
||||
tcfg := &tun2.ServerConfig{
|
||||
TCPAddr: cfg.BackendTCPAddr,
|
||||
KCPAddr: cfg.BackendKCPAddr,
|
||||
TLSConfig: &tls.Config{
|
||||
GetCertificate: m.GetCertificate,
|
||||
},
|
||||
Storage: &storageWrapper{
|
||||
Storage: db,
|
||||
},
|
||||
|
@ -79,6 +128,7 @@ func New(cfg Config) (*Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cfg: &cfg,
|
||||
db: db,
|
||||
|
@ -87,13 +137,15 @@ func New(cfg Config) (*Server, error) {
|
|||
Manager: m,
|
||||
}
|
||||
|
||||
s.ts = ts
|
||||
go ts.ListenAndServe()
|
||||
tc := &tls.Config{
|
||||
GetCertificate: m.GetCertificate,
|
||||
}
|
||||
|
||||
gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(&tls.Config{
|
||||
GetCertificate: m.GetCertificate,
|
||||
InsecureSkipVerify: true,
|
||||
})))
|
||||
go s.listenKCP(context.Background(), cfg.BackendKCPAddr, tc)
|
||||
go s.listenTCP(context.Background(), cfg.BackendTCPAddr, tc)
|
||||
|
||||
// gRPC setup
|
||||
gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(tc)))
|
||||
|
||||
proto.RegisterBackendsServer(gs, &Backend{Server: s})
|
||||
proto.RegisterRoutesServer(gs, &Route{Server: s})
|
||||
|
@ -140,7 +192,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
Director: s.Director,
|
||||
Transport: s.ts,
|
||||
FlushInterval: 1 * time.Second,
|
||||
//BufferPool: bpool.NewBytePool(256, 4096),
|
||||
BufferPool: bpool.NewBytePool(256, 4096),
|
||||
}
|
||||
|
||||
rp.ServeHTTP(w, r)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDirector(t *testing.T) {
|
||||
s := &Server{}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://cetacean.club/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Header.Add("X-Forwarded-For", "Rick-James")
|
||||
req.Header.Add("X-Client-Ip", "56.32.51.84")
|
||||
|
||||
s.Director(req)
|
||||
|
||||
for _, header := range []string{"X-Forwarded-For", "X-Client-Ip"} {
|
||||
t.Run(header, func(t *testing.T) {
|
||||
val := req.Header.Get(header)
|
||||
if val != "" {
|
||||
t.Fatalf("expected header %q to have no value, got: %v", header, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,5 +1,7 @@
|
|||
package tun2
|
||||
|
||||
import "time"
|
||||
|
||||
// Backend is the public state of an individual Connection.
|
||||
type Backend struct {
|
||||
ID string
|
||||
|
@ -10,3 +12,72 @@ type Backend struct {
|
|||
Host string
|
||||
Usable bool
|
||||
}
|
||||
|
||||
type backendMatcher func(*Connection) bool
|
||||
|
||||
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
var result []Backend
|
||||
|
||||
for _, c := range s.conns {
|
||||
if !bm(c) {
|
||||
continue
|
||||
}
|
||||
|
||||
protocol := "tcp"
|
||||
if c.isKCP {
|
||||
protocol = "kcp"
|
||||
}
|
||||
|
||||
result = append(result, Backend{
|
||||
ID: c.id,
|
||||
Proto: protocol,
|
||||
User: c.user,
|
||||
Domain: c.domain,
|
||||
Phi: float32(c.detector.Phi(time.Now())),
|
||||
Host: c.conn.RemoteAddr().String(),
|
||||
Usable: c.usable,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// KillBackend forcibly disconnects a given backend but doesn't offer a way to
|
||||
// "ban" it from reconnecting.
|
||||
func (s *Server) KillBackend(id string) error {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
for _, c := range s.conns {
|
||||
if c.id == id {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrNoSuchBackend
|
||||
}
|
||||
|
||||
// GetBackendsForDomain fetches all backends connected to this server associated
|
||||
// to a single public domain name.
|
||||
func (s *Server) GetBackendsForDomain(domain string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.domain == domain
|
||||
})
|
||||
}
|
||||
|
||||
// GetBackendsForUser fetches all backends connected to this server owned by a
|
||||
// given user by username.
|
||||
func (s *Server) GetBackendsForUser(uname string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.user == uname
|
||||
})
|
||||
}
|
||||
|
||||
// GetAllBackends fetches every backend connected to this server.
|
||||
func (s *Server) GetAllBackends() []Backend {
|
||||
return s.getBackendsForMatcher(func(*Connection) bool { return true })
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -14,10 +15,14 @@ import (
|
|||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
// Client connects to a remote tun2 server and sets up authentication before routing
|
||||
// individual HTTP requests to discrete streams that are reverse proxied to the eventual
|
||||
// backend.
|
||||
type Client struct {
|
||||
cfg *ClientConfig
|
||||
}
|
||||
|
||||
// ClientConfig configures client with settings that the user provides.
|
||||
type ClientConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
ConnType string
|
||||
|
@ -25,8 +30,12 @@ type ClientConfig struct {
|
|||
Token string
|
||||
Domain string
|
||||
BackendURL string
|
||||
|
||||
// internal use only
|
||||
forceTCPClear bool
|
||||
}
|
||||
|
||||
// NewClient constructs an instance of Client with a given ClientConfig.
|
||||
func NewClient(cfg *ClientConfig) (*Client, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("tun2: client config needed")
|
||||
|
@ -39,7 +48,12 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Connect() error {
|
||||
// Connect dials the remote server and negotiates a client session with its
|
||||
// configured server address. This will then continuously proxy incoming HTTP
|
||||
// requests to the backend HTTP server.
|
||||
//
|
||||
// This is a blocking function.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
return c.connect(c.cfg.ServerAddr)
|
||||
}
|
||||
|
||||
|
@ -57,7 +71,12 @@ func (c *Client) connect(serverAddr string) error {
|
|||
|
||||
switch c.cfg.ConnType {
|
||||
case "tcp":
|
||||
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
|
||||
if c.cfg.forceTCPClear {
|
||||
conn, err = net.Dial("tcp", serverAddr)
|
||||
} else {
|
||||
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -117,15 +136,12 @@ func (c *Client) connect(serverAddr string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// smuxListener wraps a smux session as a net.Listener.
|
||||
type smuxListener struct {
|
||||
conn net.Conn
|
||||
session *smux.Session
|
||||
}
|
||||
|
||||
var (
|
||||
_ net.Listener = &smuxListener{} // interface check
|
||||
)
|
||||
|
||||
func (sl *smuxListener) Accept() (net.Conn, error) {
|
||||
return sl.session.AcceptStream()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewClientNullConfig(t *testing.T) {
|
||||
_, err := NewClient(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected NewClient(nil) to fail, got non-failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmuxListenerIsNetListener(t *testing.T) {
|
||||
var sl interface{} = &smuxListener{}
|
||||
_, ok := sl.(net.Listener)
|
||||
if !ok {
|
||||
t.Fatalf("smuxListener does not implement net.Listener")
|
||||
}
|
||||
}
|
|
@ -2,11 +2,10 @@ package tun2
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Xe/ln"
|
||||
|
@ -127,7 +126,12 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, ErrCantOpenSessionStream.Error())
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
go func() {
|
||||
time.Sleep(30 * time.Minute)
|
||||
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
err = req.Write(stream)
|
||||
if err != nil {
|
||||
|
@ -142,13 +146,13 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
cl := resp.Header.Get("Content-Length")
|
||||
asInt, err := strconv.Atoi(cl)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "can't read response body")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
resp.ContentLength = int64(asInt)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
|
|
@ -3,10 +3,10 @@ package tun2
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
@ -19,7 +19,6 @@ import (
|
|||
failure "github.com/dgryski/go-failure"
|
||||
"github.com/mtneug/pkg/ulid"
|
||||
cmap "github.com/streamrail/concurrent-map"
|
||||
kcp "github.com/xtaci/kcp-go"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
|
@ -30,412 +29,7 @@ var (
|
|||
ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
|
||||
)
|
||||
|
||||
// ServerConfig ...
|
||||
type ServerConfig struct {
|
||||
TCPAddr string
|
||||
KCPAddr string
|
||||
TLSConfig *tls.Config
|
||||
|
||||
SmuxConf *smux.Config
|
||||
Storage Storage
|
||||
}
|
||||
|
||||
// Storage is the minimal subset of features that tun2's Server needs out of a
|
||||
// persistence layer.
|
||||
type Storage interface {
|
||||
HasToken(token string) (user string, scopes []string, err error)
|
||||
HasRoute(domain string) (user string, err error)
|
||||
}
|
||||
|
||||
// Server routes frontend HTTP traffic to backend TCP traffic.
|
||||
type Server struct {
|
||||
cfg *ServerConfig
|
||||
|
||||
connlock sync.Mutex
|
||||
conns map[net.Conn]*Connection
|
||||
|
||||
domains cmap.ConcurrentMap
|
||||
}
|
||||
|
||||
// NewServer creates a new Server instance with a given config, acquiring all
|
||||
// relevant resources.
|
||||
func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("tun2: config must be specified")
|
||||
}
|
||||
|
||||
if cfg.SmuxConf == nil {
|
||||
cfg.SmuxConf = smux.DefaultConfig()
|
||||
}
|
||||
|
||||
cfg.SmuxConf.KeepAliveInterval = time.Second
|
||||
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
|
||||
|
||||
server := &Server{
|
||||
cfg: cfg,
|
||||
|
||||
conns: map[net.Conn]*Connection{},
|
||||
domains: cmap.New(),
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type backendMatcher func(*Connection) bool
|
||||
|
||||
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
var result []Backend
|
||||
|
||||
for _, c := range s.conns {
|
||||
if !bm(c) {
|
||||
continue
|
||||
}
|
||||
|
||||
protocol := "tcp"
|
||||
if c.isKCP {
|
||||
protocol = "kcp"
|
||||
}
|
||||
|
||||
result = append(result, Backend{
|
||||
ID: c.id,
|
||||
Proto: protocol,
|
||||
User: c.user,
|
||||
Domain: c.domain,
|
||||
Phi: float32(c.detector.Phi(time.Now())),
|
||||
Host: c.conn.RemoteAddr().String(),
|
||||
Usable: c.usable,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Server) KillBackend(id string) error {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
for _, c := range s.conns {
|
||||
if c.id == id {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrNoSuchBackend
|
||||
}
|
||||
|
||||
func (s *Server) GetBackendsForDomain(domain string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.domain == domain
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) GetBackendsForUser(uname string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.user == uname
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) GetAllBackends() []Backend {
|
||||
return s.getBackendsForMatcher(func(*Connection) bool { return true })
|
||||
}
|
||||
|
||||
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
|
||||
// traffic to and from them.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "listen_and_serve_called",
|
||||
})
|
||||
|
||||
if s.cfg.TCPAddr != "" {
|
||||
go func() {
|
||||
l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "tcp+tls_listening",
|
||||
"addr": l.Addr(),
|
||||
})
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()})
|
||||
continue
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"kcp": false,
|
||||
"addr": conn.RemoteAddr(),
|
||||
})
|
||||
|
||||
go s.HandleConn(conn, false)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.cfg.KCPAddr != "" {
|
||||
go func() {
|
||||
l, err := kcp.Listen(s.cfg.KCPAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "kcp+tls_listening",
|
||||
"addr": l.Addr(),
|
||||
})
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"kcp": true,
|
||||
"addr": conn.RemoteAddr(),
|
||||
})
|
||||
|
||||
tc := tls.Server(conn, s.cfg.TLSConfig)
|
||||
|
||||
go s.HandleConn(tc, true)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// XXX experimental, might get rid of this inside this process
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
s.connlock.Lock()
|
||||
for _, c := range s.conns {
|
||||
failureChance := c.detector.Phi(now)
|
||||
|
||||
if failureChance > 0.8 {
|
||||
ln.Log(ctx, c.F(), ln.F{
|
||||
"action": "phi_failure_detection",
|
||||
"value": failureChance,
|
||||
})
|
||||
}
|
||||
}
|
||||
s.connlock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
|
||||
// the currently connected backend.
|
||||
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||
// XXX TODO clean this up it's really ugly.
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session, err := smux.Server(c, s.cfg.SmuxConf)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "session_failure",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
c.Close()
|
||||
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
controlStream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "control_stream_failure",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
defer controlStream.Close()
|
||||
|
||||
csd := json.NewDecoder(controlStream)
|
||||
auth := &Auth{}
|
||||
err = csd.Decode(auth)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "control_stream_auth_decoding_failure",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "nosuch_domain",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "nosuch_token",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ok := false
|
||||
for _, sc := range scopes {
|
||||
if sc == "connect" {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrAuthMismatch, ln.F{
|
||||
"action": "token_not_authorized",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
}
|
||||
|
||||
if routeUser != tokenUser {
|
||||
ln.Error(ctx, ErrAuthMismatch, ln.F{
|
||||
"action": "auth_mismatch",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
connection := &Connection{
|
||||
id: ulid.New().String(),
|
||||
conn: c,
|
||||
isKCP: isKCP,
|
||||
session: session,
|
||||
user: tokenUser,
|
||||
domain: auth.Domain,
|
||||
cf: cancel,
|
||||
detector: failure.New(15, 1),
|
||||
Auth: auth,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
|
||||
}
|
||||
}()
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "backend_connected",
|
||||
}, connection.F())
|
||||
|
||||
s.connlock.Lock()
|
||||
s.conns[c] = connection
|
||||
s.connlock.Unlock()
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(auth.Domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
conns = nil
|
||||
|
||||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
conns = append(conns, connection)
|
||||
|
||||
s.domains.Set(auth.Domain, conns)
|
||||
connection.usable = true
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := connection.Ping()
|
||||
if err != nil {
|
||||
connection.cancel()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
s.RemoveConn(ctx, connection)
|
||||
connection.Close()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveConn removes a connection.
|
||||
func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
||||
s.connlock.Lock()
|
||||
delete(s.conns, connection.conn)
|
||||
s.connlock.Unlock()
|
||||
|
||||
auth := connection.Auth
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(auth.Domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
|
||||
"action": "looking_up_for_disconnect_removal",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i, cntn := range conns {
|
||||
if cntn.id == connection.id {
|
||||
conns[i] = conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
}
|
||||
}
|
||||
|
||||
if len(conns) != 0 {
|
||||
s.domains.Set(auth.Domain, conns)
|
||||
} else {
|
||||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
|
||||
ln.Log(ctx, connection.F(), ln.F{
|
||||
"action": "client_disconnecting",
|
||||
})
|
||||
}
|
||||
|
||||
// gen502Page creates the page that is shown when a backend is not connected to a given route.
|
||||
func gen502Page(req *http.Request) *http.Response {
|
||||
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
|
||||
|
||||
|
@ -469,6 +63,347 @@ func gen502Page(req *http.Request) *http.Response {
|
|||
return resp
|
||||
}
|
||||
|
||||
// ServerConfig ...
|
||||
type ServerConfig struct {
|
||||
SmuxConf *smux.Config
|
||||
Storage Storage
|
||||
}
|
||||
|
||||
// Storage is the minimal subset of features that tun2's Server needs out of a
|
||||
// persistence layer.
|
||||
type Storage interface {
|
||||
HasToken(token string) (user string, scopes []string, err error)
|
||||
HasRoute(domain string) (user string, err error)
|
||||
}
|
||||
|
||||
// Server routes frontend HTTP traffic to backend TCP traffic.
|
||||
type Server struct {
|
||||
cfg *ServerConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
connlock sync.Mutex
|
||||
conns map[net.Conn]*Connection
|
||||
|
||||
domains cmap.ConcurrentMap
|
||||
}
|
||||
|
||||
// NewServer creates a new Server instance with a given config, acquiring all
|
||||
// relevant resources.
|
||||
func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("tun2: config must be specified")
|
||||
}
|
||||
|
||||
if cfg.SmuxConf == nil {
|
||||
cfg.SmuxConf = smux.DefaultConfig()
|
||||
|
||||
cfg.SmuxConf.KeepAliveInterval = time.Second
|
||||
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
server := &Server{
|
||||
cfg: cfg,
|
||||
|
||||
conns: map[net.Conn]*Connection{},
|
||||
domains: cmap.New(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
go server.phiDetectionLoop(ctx)
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Close stops the background tasks for this Server.
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// Wait blocks until the server context is cancelled.
|
||||
func (s *Server) Wait() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Listen passes this Server a given net.Listener to accept backend connections.
|
||||
func (s *Server) Listen(l net.Listener, isKCP bool) {
|
||||
ctx := s.ctx
|
||||
|
||||
f := ln.F{
|
||||
"listener_addr": l.Addr(),
|
||||
"listener_network": l.Addr().Network(),
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("accept connection"))
|
||||
continue
|
||||
}
|
||||
|
||||
ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{
|
||||
"conn_addr": conn.RemoteAddr(),
|
||||
"conn_network": conn.RemoteAddr().Network(),
|
||||
})
|
||||
|
||||
go s.HandleConn(conn, isKCP)
|
||||
}
|
||||
}
|
||||
|
||||
// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector]
|
||||
// for each of the backends connected to the Server. This is fairly experimental and
|
||||
// may be removed.
|
||||
//
|
||||
// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784
|
||||
func (s *Server) phiDetectionLoop(ctx context.Context) {
|
||||
t := time.NewTicker(5 * time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
now := time.Now()
|
||||
|
||||
s.connlock.Lock()
|
||||
for _, c := range s.conns {
|
||||
failureChance := c.detector.Phi(now)
|
||||
const thresh = 0.9 // the threshold for phi failure detection causing logs
|
||||
|
||||
if failureChance > thresh {
|
||||
ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{
|
||||
"value": failureChance,
|
||||
"threshold": thresh,
|
||||
})
|
||||
}
|
||||
}
|
||||
s.connlock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backendAuthv1 runs a simple backend authentication check. It expects the
|
||||
// client to write a json-encoded instance of Auth. This is then checked
|
||||
// for token validity and domain matching.
|
||||
//
|
||||
// This returns the user that was authenticated and the domain they identified
|
||||
// with.
|
||||
func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) {
|
||||
f := ln.F{
|
||||
"action": "backend authentication",
|
||||
"backend_auth_version": 1,
|
||||
}
|
||||
|
||||
f["stage"] = "json decoding"
|
||||
|
||||
d := json.NewDecoder(st)
|
||||
var auth Auth
|
||||
err := d.Decode(&auth)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["auth_domain"] = auth.Domain
|
||||
f["stage"] = "checking domain"
|
||||
|
||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["route_user"] = routeUser
|
||||
f["stage"] = "checking token"
|
||||
|
||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["token_user"] = tokenUser
|
||||
f["stage"] = "checking token scopes"
|
||||
|
||||
ok := false
|
||||
for _, sc := range scopes {
|
||||
if sc == "connect" {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrAuthMismatch, f)
|
||||
return "", nil, ErrAuthMismatch
|
||||
}
|
||||
|
||||
f["stage"] = "user verification"
|
||||
|
||||
if routeUser != tokenUser {
|
||||
ln.Error(ctx, ErrAuthMismatch, f)
|
||||
return "", nil, ErrAuthMismatch
|
||||
}
|
||||
|
||||
return routeUser, &auth, nil
|
||||
}
|
||||
|
||||
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
|
||||
// the currently connected backend.
|
||||
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
f := ln.F{
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
}
|
||||
|
||||
session, err := smux.Server(c, s.cfg.SmuxConf)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
|
||||
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
controlStream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("opening control stream"))
|
||||
|
||||
return
|
||||
}
|
||||
defer controlStream.Close()
|
||||
|
||||
user, auth, err := s.backendAuthv1(ctx, controlStream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
connection := &Connection{
|
||||
id: ulid.New().String(),
|
||||
conn: c,
|
||||
isKCP: isKCP,
|
||||
session: session,
|
||||
user: user,
|
||||
domain: auth.Domain,
|
||||
cf: cancel,
|
||||
detector: failure.New(15, 1),
|
||||
Auth: auth,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
|
||||
}
|
||||
}()
|
||||
|
||||
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
|
||||
|
||||
s.addConn(ctx, connection)
|
||||
|
||||
connection.usable = true // XXX set this to true once health checks pass?
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := connection.Ping()
|
||||
if err != nil {
|
||||
connection.cancel()
|
||||
}
|
||||
// case <-s.ctx.Done():
|
||||
// ln.Log(ctx, connection, ln.Action("server context finished"))
|
||||
// s.removeConn(ctx, connection)
|
||||
// connection.Close()
|
||||
|
||||
// return
|
||||
case <-ctx.Done():
|
||||
ln.Log(ctx, connection, ln.Action("client context finished"))
|
||||
s.removeConn(ctx, connection)
|
||||
connection.Close()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addConn adds a connection to the pool of backend connections.
|
||||
func (s *Server) addConn(ctx context.Context, connection *Connection) {
|
||||
s.connlock.Lock()
|
||||
s.conns[connection.conn] = connection
|
||||
s.connlock.Unlock()
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(connection.domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
conns = nil
|
||||
|
||||
s.domains.Remove(connection.domain)
|
||||
}
|
||||
}
|
||||
|
||||
conns = append(conns, connection)
|
||||
|
||||
s.domains.Set(connection.domain, conns)
|
||||
}
|
||||
|
||||
// removeConn removes a connection from pool of backend connections.
|
||||
func (s *Server) removeConn(ctx context.Context, connection *Connection) {
|
||||
s.connlock.Lock()
|
||||
delete(s.conns, connection.conn)
|
||||
s.connlock.Unlock()
|
||||
|
||||
auth := connection.Auth
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(auth.Domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i, cntn := range conns {
|
||||
if cntn.id == connection.id {
|
||||
conns[i] = conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
}
|
||||
}
|
||||
|
||||
if len(conns) != 0 {
|
||||
s.domains.Set(auth.Domain, conns)
|
||||
} else {
|
||||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip sends a HTTP request to a backend and then returns its response.
|
||||
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
var conns []*Connection
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Xe/uuid"
|
||||
)
|
||||
|
||||
// testing constants
|
||||
const (
|
||||
user = "shachi"
|
||||
token = "orcaz r kewl"
|
||||
noPermToken = "aw heck"
|
||||
otherUserToken = "even more heck"
|
||||
domain = "cetacean.club"
|
||||
)
|
||||
|
||||
func TestNewServerNullConfig(t *testing.T) {
|
||||
_, err := NewServer(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected NewServer(nil) to fail, got non-failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGen502Page(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cetacean.club", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
substring := uuid.New()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Add("X-Request-Id", substring)
|
||||
req.Host = "cetacean.club"
|
||||
|
||||
resp := gen502Page(req)
|
||||
if resp == nil {
|
||||
t.Fatalf("expected response to be non-nil")
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), substring) {
|
||||
fmt.Println(string(data))
|
||||
t.Fatalf("502 page did not contain needed substring %q", substring)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendAuthV1(t *testing.T) {
|
||||
st := MockStorage()
|
||||
|
||||
s, err := NewServer(&ServerConfig{
|
||||
Storage: st,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
st.AddRoute(domain, user)
|
||||
st.AddToken(token, user, []string{"connect"})
|
||||
st.AddToken(noPermToken, user, nil)
|
||||
st.AddToken(otherUserToken, "cadey", []string{"connect"})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
auth Auth
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic everything should work",
|
||||
auth: Auth{
|
||||
Token: token,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid domain",
|
||||
auth: Auth{
|
||||
Token: token,
|
||||
Domain: "aw.heck",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
auth: Auth{
|
||||
Token: "asdfwtweg",
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token scopes",
|
||||
auth: Auth{
|
||||
Token: noPermToken,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user token doesn't match domain owner",
|
||||
auth: Auth{
|
||||
Token: otherUserToken,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
data, err := json.Marshal(cs.auth)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data))
|
||||
|
||||
if cs.wantErr && err == nil {
|
||||
t.Fatalf("auth did not err as expected")
|
||||
}
|
||||
|
||||
if !cs.wantErr && err != nil {
|
||||
t.Fatalf("unexpected auth err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendRouting(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
st := MockStorage()
|
||||
|
||||
st.AddRoute(domain, user)
|
||||
st.AddToken(token, user, []string{"connect"})
|
||||
|
||||
s, err := NewServer(&ServerConfig{
|
||||
Storage: st,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go s.Listen(l, false)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
should200 bool
|
||||
handler http.HandlerFunc
|
||||
}{
|
||||
{
|
||||
name: "200 everything's okay",
|
||||
should200: true,
|
||||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK)
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(cs.handler)
|
||||
defer ts.Close()
|
||||
|
||||
cc := &ClientConfig{
|
||||
ConnType: "tcp",
|
||||
ServerAddr: l.Addr().String(),
|
||||
Token: token,
|
||||
BackendURL: ts.URL,
|
||||
Domain: domain,
|
||||
|
||||
forceTCPClear: true,
|
||||
}
|
||||
|
||||
c, err := NewClient(cc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := s.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("error in doing round trip: %v", err)
|
||||
}
|
||||
|
||||
if cs.should200 && resp.StatusCode != http.StatusOK {
|
||||
resp.Write(os.Stdout)
|
||||
t.Fatalf("got status %d instead of StatusOK", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestServer() (*Server, *mockStorage, net.Listener, error) {
|
||||
st := MockStorage()
|
||||
|
||||
st.AddRoute(domain, user)
|
||||
st.AddToken(token, user, []string{"connect"})
|
||||
|
||||
s, err := NewServer(&ServerConfig{
|
||||
Storage: st,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
go s.Listen(l, false)
|
||||
|
||||
return s, st, l, nil
|
||||
}
|
||||
|
||||
func BenchmarkHTTP200(b *testing.B) {
|
||||
b.Skip("this benchmark doesn't work yet")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, _, l, err := setupTestServer()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
defer l.Close()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer ts.Close()
|
||||
|
||||
cc := &ClientConfig{
|
||||
ConnType: "tcp",
|
||||
ServerAddr: l.Addr().String(),
|
||||
Token: token,
|
||||
BackendURL: ts.URL,
|
||||
Domain: domain,
|
||||
|
||||
forceTCPClear: true,
|
||||
}
|
||||
|
||||
c, err := NewClient(cc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up
|
||||
|
||||
for {
|
||||
r := s.GetBackendsForDomain(domain)
|
||||
if len(r) == 0 {
|
||||
time.Sleep(125 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = s.RoundTrip(req)
|
||||
if err != nil {
|
||||
b.Fatalf("got error on initial request exchange: %v", err)
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
resp, err := s.RoundTrip(req)
|
||||
if err != nil {
|
||||
b.Fatalf("got error on %d: %v", n, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b.Fail()
|
||||
b.Logf("got %d instead of 200", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func MockStorage() *mockStorage {
|
||||
return &mockStorage{
|
||||
tokens: make(map[string]mockToken),
|
||||
domains: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
type mockToken struct {
|
||||
user string
|
||||
scopes []string
|
||||
}
|
||||
|
||||
// mockStorage is a simple mock of the Storage interface suitable for testing.
|
||||
type mockStorage struct {
|
||||
sync.Mutex
|
||||
tokens map[string]mockToken
|
||||
domains map[string]string
|
||||
}
|
||||
|
||||
func (ms *mockStorage) AddToken(token, user string, scopes []string) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
ms.tokens[token] = mockToken{user: user, scopes: scopes}
|
||||
}
|
||||
|
||||
func (ms *mockStorage) AddRoute(domain, user string) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
ms.domains[domain] = user
|
||||
}
|
||||
|
||||
func (ms *mockStorage) HasToken(token string) (string, []string, error) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
tok, ok := ms.tokens[token]
|
||||
if !ok {
|
||||
return "", nil, errors.New("no such token")
|
||||
}
|
||||
|
||||
return tok.user, tok.scopes, nil
|
||||
}
|
||||
|
||||
func (ms *mockStorage) HasRoute(domain string) (string, error) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
user, ok := ms.domains[domain]
|
||||
if !ok {
|
||||
return "", errors.New("no such route")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func TestMockStorage(t *testing.T) {
|
||||
ms := MockStorage()
|
||||
|
||||
t.Run("token", func(t *testing.T) {
|
||||
ms.AddToken(token, user, []string{"connect"})
|
||||
|
||||
us, sc, err := ms.HasToken(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if us != user {
|
||||
t.Fatalf("username was %q, expected %q", us, user)
|
||||
}
|
||||
|
||||
if sc[0] != "connect" {
|
||||
t.Fatalf("token expected to only have one scope, connect")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domain", func(t *testing.T) {
|
||||
ms.AddRoute(domain, user)
|
||||
|
||||
us, err := ms.HasRoute(domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if us != user {
|
||||
t.Fatalf("username was %q, expected %q", us, user)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
9
mage.go
9
mage.go
|
@ -187,8 +187,17 @@ func Package() {
|
|||
}
|
||||
}
|
||||
|
||||
// Version is the version as git reports.
|
||||
func Version() {
|
||||
ver, err := gitTag()
|
||||
qod.ANE(err)
|
||||
qod.Printlnf("route-%s", ver)
|
||||
}
|
||||
|
||||
// Test runs all of the functional and unit tests for the project.
|
||||
func Test() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
shouldWork(ctx, nil, wd, "go", "test", "-v", "./...")
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func mustEnv(key string, def string) string {
|
|||
return val
|
||||
}
|
||||
|
||||
func doHttpAgent() {
|
||||
func doHTTPAgent() {
|
||||
go func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
@ -47,17 +47,17 @@ func doHttpAgent() {
|
|||
}
|
||||
|
||||
client, _ := tun2.NewClient(cfg)
|
||||
err := client.Connect()
|
||||
err := client.Connect(ctx)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.Action("client connection error, restarting"))
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
doHttpAgent()
|
||||
doHTTPAgent()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func init() {
|
||||
doHttpAgent()
|
||||
doHTTPAgent()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
package main
|
||||
|
||||
func main() {}
|
|
@ -1 +0,0 @@
|
|||
xena@greedo.xeserv.us.17867:1486865539
|
Loading…
Reference in New Issue