tun2: some experimenting on the core
This commit is contained in:
parent
a47fd75c5f
commit
59a3f45150
|
@ -1,6 +1,7 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -29,6 +30,9 @@ type ClientConfig struct {
|
|||
Token string
|
||||
Domain string
|
||||
BackendURL string
|
||||
|
||||
// internal use only
|
||||
forceTCPClear bool
|
||||
}
|
||||
|
||||
// NewClient constructs an instance of Client with a given ClientConfig.
|
||||
|
@ -49,7 +53,7 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
|
|||
// requests to the backend HTTP server.
|
||||
//
|
||||
// This is a blocking function.
|
||||
func (c *Client) Connect() error {
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
return c.connect(c.cfg.ServerAddr)
|
||||
}
|
||||
|
||||
|
@ -67,7 +71,12 @@ func (c *Client) connect(serverAddr string) error {
|
|||
|
||||
switch c.cfg.ConnType {
|
||||
case "tcp":
|
||||
if c.cfg.forceTCPClear {
|
||||
conn, err = net.Dial("tcp", serverAddr)
|
||||
} else {
|
||||
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,10 +3,10 @@ package tun2
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
@ -17,9 +17,9 @@ import (
|
|||
|
||||
"github.com/Xe/ln"
|
||||
failure "github.com/dgryski/go-failure"
|
||||
"github.com/kr/pretty"
|
||||
"github.com/mtneug/pkg/ulid"
|
||||
cmap "github.com/streamrail/concurrent-map"
|
||||
kcp "github.com/xtaci/kcp-go"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
|
@ -66,10 +66,6 @@ func gen502Page(req *http.Request) *http.Response {
|
|||
|
||||
// ServerConfig ...
|
||||
type ServerConfig struct {
|
||||
TCPAddr string
|
||||
KCPAddr string
|
||||
TLSConfig *tls.Config
|
||||
|
||||
SmuxConf *smux.Config
|
||||
Storage Storage
|
||||
}
|
||||
|
@ -84,6 +80,8 @@ type Storage interface {
|
|||
// Server routes frontend HTTP traffic to backend TCP traffic.
|
||||
type Server struct {
|
||||
cfg *ServerConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
connlock sync.Mutex
|
||||
conns map[net.Conn]*Connection
|
||||
|
@ -100,146 +98,174 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
|
|||
|
||||
if cfg.SmuxConf == nil {
|
||||
cfg.SmuxConf = smux.DefaultConfig()
|
||||
}
|
||||
|
||||
cfg.SmuxConf.KeepAliveInterval = time.Second
|
||||
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
server := &Server{
|
||||
cfg: cfg,
|
||||
|
||||
conns: map[net.Conn]*Connection{},
|
||||
domains: cmap.New(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
go server.phiDetectionLoop(ctx)
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Close stops the background tasks for this Server.
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// Wait blocks until the server context is cancelled.
|
||||
func (s *Server) Wait() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Listen passes this Server a given net.Listener to accept backend connections.
|
||||
func (s *Server) Listen(l net.Listener, isKCP bool) {
|
||||
ctx := context.Background()
|
||||
|
||||
f := ln.F{
|
||||
"listener_addr": l.Addr(),
|
||||
"listener_network": l.Addr().Network(),
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"addr": l.Addr().String(),
|
||||
"network": l.Addr().Network(),
|
||||
})
|
||||
ln.Error(ctx, err, f, ln.Action("accept connection"))
|
||||
continue
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"network": conn.RemoteAddr().Network(),
|
||||
"addr": conn.RemoteAddr(),
|
||||
"list": conn.LocalAddr(),
|
||||
ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{
|
||||
"conn_addr": conn.RemoteAddr(),
|
||||
"conn_network": conn.RemoteAddr().Network(),
|
||||
})
|
||||
|
||||
go s.HandleConn(conn, isKCP)
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
|
||||
// traffic to and from them.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "listen_and_serve_called",
|
||||
})
|
||||
|
||||
if s.cfg.TCPAddr != "" {
|
||||
go func() {
|
||||
l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "tcp+tls_listening",
|
||||
"addr": l.Addr(),
|
||||
})
|
||||
// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector]
|
||||
// for each of the backends connected to the Server. This is fairly experimental and
|
||||
// may be removed.
|
||||
//
|
||||
// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784
|
||||
func (s *Server) phiDetectionLoop(ctx context.Context) {
|
||||
t := time.NewTicker(5 * time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()})
|
||||
continue
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"kcp": false,
|
||||
"addr": conn.RemoteAddr(),
|
||||
})
|
||||
|
||||
go s.HandleConn(conn, false)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.cfg.KCPAddr != "" {
|
||||
go func() {
|
||||
l, err := kcp.Listen(s.cfg.KCPAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "kcp+tls_listening",
|
||||
"addr": l.Addr(),
|
||||
})
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
|
||||
}
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"kcp": true,
|
||||
"addr": conn.RemoteAddr(),
|
||||
})
|
||||
|
||||
tc := tls.Server(conn, s.cfg.TLSConfig)
|
||||
|
||||
go s.HandleConn(tc, true)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// XXX experimental, might get rid of this inside this process
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
now := time.Now()
|
||||
|
||||
s.connlock.Lock()
|
||||
for _, c := range s.conns {
|
||||
failureChance := c.detector.Phi(now)
|
||||
const thresh = 0.9 // the threshold for phi failure detection causing logs
|
||||
|
||||
if failureChance > 0.8 {
|
||||
ln.Log(ctx, c.F(), ln.F{
|
||||
"action": "phi_failure_detection",
|
||||
if failureChance > thresh {
|
||||
ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{
|
||||
"value": failureChance,
|
||||
"threshold": thresh,
|
||||
})
|
||||
}
|
||||
}
|
||||
s.connlock.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// backendAuthv1 runs a simple backend authentication check. It expects the
|
||||
// client to write a json-encoded instance of Auth. This is then checked
|
||||
// for token validity and domain matching.
|
||||
//
|
||||
// This returns the user that was authenticated and the domain they identified
|
||||
// with.
|
||||
func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) {
|
||||
f := ln.F{
|
||||
"action": "backend authentication",
|
||||
"backend_auth_version": 1,
|
||||
}
|
||||
|
||||
f["stage"] = "json decoding"
|
||||
ln.Log(ctx, f)
|
||||
|
||||
d := json.NewDecoder(st)
|
||||
var auth Auth
|
||||
err := d.Decode(&auth)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["stage"] = "checking domain"
|
||||
ln.Log(ctx, f)
|
||||
|
||||
pretty.Println(s.cfg.Storage)
|
||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["route_user"] = routeUser
|
||||
f["stage"] = "checking token"
|
||||
ln.Log(ctx, f)
|
||||
|
||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
f["token_user"] = tokenUser
|
||||
f["stage"] = "checking token scopes"
|
||||
ln.Log(ctx, f)
|
||||
|
||||
ok := false
|
||||
for _, sc := range scopes {
|
||||
if sc == "connect" {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrAuthMismatch, f)
|
||||
return "", nil, ErrAuthMismatch
|
||||
}
|
||||
|
||||
f["stage"] = "user verification"
|
||||
ln.Log(ctx, f)
|
||||
|
||||
if routeUser != tokenUser {
|
||||
ln.Error(ctx, ErrAuthMismatch, f)
|
||||
return "", nil, ErrAuthMismatch
|
||||
}
|
||||
|
||||
return routeUser, &auth, nil
|
||||
}
|
||||
|
||||
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
|
||||
// the currently connected backend.
|
||||
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||
// XXX TODO clean this up it's really ugly.
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -258,8 +284,6 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
}
|
||||
defer session.Close()
|
||||
|
||||
f["stage"] = "smux_setup"
|
||||
|
||||
controlStream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("opening control stream"))
|
||||
|
@ -268,58 +292,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
}
|
||||
defer controlStream.Close()
|
||||
|
||||
f["stage"] = "control_stream_open"
|
||||
|
||||
csd := json.NewDecoder(controlStream)
|
||||
auth := &Auth{}
|
||||
err = csd.Decode(auth)
|
||||
user, auth, err := s.backendAuthv1(ctx, controlStream)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("decode control stream authenication message"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["stage"] = "checking_domain"
|
||||
|
||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["route_user"] = routeUser
|
||||
f["stage"] = "checking_token"
|
||||
|
||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("no such token exists or other token error"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["token_user"] = tokenUser
|
||||
f["stage"] = "checking_token_scopes"
|
||||
|
||||
ok := false
|
||||
for _, sc := range scopes {
|
||||
if sc == "connect" {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["stage"] = "user_verification"
|
||||
|
||||
if routeUser != tokenUser {
|
||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -328,7 +302,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
conn: c,
|
||||
isKCP: isKCP,
|
||||
session: session,
|
||||
user: tokenUser,
|
||||
user: user,
|
||||
domain: auth.Domain,
|
||||
cf: cancel,
|
||||
detector: failure.New(15, 1),
|
||||
|
@ -343,26 +317,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
|
||||
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
|
||||
|
||||
// TODO: put these lines in a function?
|
||||
s.connlock.Lock()
|
||||
s.conns[c] = connection
|
||||
s.connlock.Unlock()
|
||||
s.addConn(ctx, connection)
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(auth.Domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
conns = nil
|
||||
|
||||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
conns = append(conns, connection)
|
||||
|
||||
s.domains.Set(auth.Domain, conns)
|
||||
connection.usable = true // XXX set this to true once health checks pass?
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
|
@ -375,8 +331,13 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
if err != nil {
|
||||
connection.cancel()
|
||||
}
|
||||
case <-s.ctx.Done():
|
||||
s.removeConn(ctx, connection)
|
||||
connection.Close()
|
||||
|
||||
return
|
||||
case <-ctx.Done():
|
||||
s.RemoveConn(ctx, connection)
|
||||
s.removeConn(ctx, connection)
|
||||
connection.Close()
|
||||
|
||||
return
|
||||
|
@ -384,8 +345,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
}
|
||||
}
|
||||
|
||||
// RemoveConn removes a connection.
|
||||
func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
||||
// addConn adds a connection to the pool of backend connections.
|
||||
func (s *Server) addConn(ctx context.Context, connection *Connection) {
|
||||
s.connlock.Lock()
|
||||
s.conns[connection.conn] = connection
|
||||
s.connlock.Unlock()
|
||||
|
||||
var conns []*Connection
|
||||
|
||||
val, ok := s.domains.Get(connection.domain)
|
||||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
conns = nil
|
||||
|
||||
s.domains.Remove(connection.domain)
|
||||
}
|
||||
}
|
||||
|
||||
conns = append(conns, connection)
|
||||
|
||||
s.domains.Set(connection.domain, conns)
|
||||
}
|
||||
|
||||
// removeConn removes a connection from pool of backend connections.
|
||||
func (s *Server) removeConn(ctx context.Context, connection *Connection) {
|
||||
s.connlock.Lock()
|
||||
delete(s.conns, connection.conn)
|
||||
s.connlock.Unlock()
|
||||
|
@ -416,8 +400,6 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
|||
} else {
|
||||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
|
||||
ln.Log(ctx, connection, ln.Action("backend disconnect"))
|
||||
}
|
||||
|
||||
// RoundTrip sends a HTTP request to a backend and then returns its response.
|
||||
|
|
|
@ -1,16 +1,31 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Xe/uuid"
|
||||
)
|
||||
|
||||
// testing constants
|
||||
const (
|
||||
user = "shachi"
|
||||
token = "orcaz r kewl"
|
||||
noPermToken = "aw heck"
|
||||
otherUserToken = "even more heck"
|
||||
domain = "cetacean.club"
|
||||
)
|
||||
|
||||
func TestNewServerNullConfig(t *testing.T) {
|
||||
_, err := NewServer(nil)
|
||||
if err == nil {
|
||||
|
@ -51,3 +66,164 @@ func TestGen502Page(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendAuthV1(t *testing.T) {
|
||||
st := MockStorage()
|
||||
|
||||
s, err := NewServer(&ServerConfig{
|
||||
Storage: st,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
st.AddRoute(domain, user)
|
||||
st.AddToken(token, user, []string{"connect"})
|
||||
st.AddToken(noPermToken, user, nil)
|
||||
st.AddToken(otherUserToken, "cadey", []string{"connect"})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
auth Auth
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic everything should work",
|
||||
auth: Auth{
|
||||
Token: token,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid domain",
|
||||
auth: Auth{
|
||||
Token: token,
|
||||
Domain: "aw.heck",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
auth: Auth{
|
||||
Token: "asdfwtweg",
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token scopes",
|
||||
auth: Auth{
|
||||
Token: noPermToken,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user token doesn't match domain owner",
|
||||
auth: Auth{
|
||||
Token: otherUserToken,
|
||||
Domain: domain,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
data, err := json.Marshal(cs.auth)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data))
|
||||
|
||||
if cs.wantErr && err == nil {
|
||||
t.Fatalf("auth did not err as expected")
|
||||
}
|
||||
|
||||
if !cs.wantErr && err != nil {
|
||||
t.Fatalf("unexpected auth err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendRouting(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
st := MockStorage()
|
||||
|
||||
st.AddRoute(domain, user)
|
||||
st.AddToken(token, user, []string{"connect"})
|
||||
|
||||
s, err := NewServer(&ServerConfig{
|
||||
Storage: st,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go s.Listen(l, false)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
should200 bool
|
||||
handler http.HandlerFunc
|
||||
}{
|
||||
{
|
||||
name: "200 everything's okay",
|
||||
should200: true,
|
||||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK)
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(cs.handler)
|
||||
defer ts.Close()
|
||||
|
||||
cc := &ClientConfig{
|
||||
ConnType: "tcp",
|
||||
ServerAddr: l.Addr().String(),
|
||||
Token: token,
|
||||
BackendURL: ts.URL,
|
||||
}
|
||||
|
||||
c, err := NewClient(cc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go c.Connect(ctx) //
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := s.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("error in doing round trip: %v", err)
|
||||
}
|
||||
|
||||
if cs.should200 && resp.StatusCode != http.StatusOK {
|
||||
resp.Write(os.Stdout)
|
||||
t.Fatalf("got status %d instead of StatusOK", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func MockStorage() *mockStorage {
|
||||
return &mockStorage{
|
||||
tokens: make(map[string]mockToken),
|
||||
domains: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
type mockToken struct {
|
||||
user string
|
||||
scopes []string
|
||||
}
|
||||
|
||||
// mockStorage is a simple mock of the Storage interface suitable for testing.
|
||||
type mockStorage struct {
|
||||
sync.Mutex
|
||||
tokens map[string]mockToken
|
||||
domains map[string]string
|
||||
}
|
||||
|
||||
func (ms *mockStorage) AddToken(token, user string, scopes []string) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
ms.tokens[token] = mockToken{user: user, scopes: scopes}
|
||||
}
|
||||
|
||||
func (ms *mockStorage) AddRoute(domain, user string) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
ms.domains[domain] = user
|
||||
}
|
||||
|
||||
func (ms *mockStorage) HasToken(token string) (string, []string, error) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
tok, ok := ms.tokens[token]
|
||||
if !ok {
|
||||
return "", nil, errors.New("no such token")
|
||||
}
|
||||
|
||||
return tok.user, tok.scopes, nil
|
||||
}
|
||||
|
||||
func (ms *mockStorage) HasRoute(domain string) (string, error) {
|
||||
ms.Lock()
|
||||
defer ms.Unlock()
|
||||
|
||||
user, ok := ms.domains[domain]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func TestMockStorage(t *testing.T) {
|
||||
ms := MockStorage()
|
||||
|
||||
t.Run("token", func(t *testing.T) {
|
||||
ms.AddToken(token, user, []string{"connect"})
|
||||
|
||||
us, sc, err := ms.HasToken(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if us != user {
|
||||
t.Fatalf("username was %q, expected %q", us, user)
|
||||
}
|
||||
|
||||
if sc[0] != "connect" {
|
||||
t.Fatalf("token expected to only have one scope, connect")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domain", func(t *testing.T) {
|
||||
ms.AddRoute(domain, user)
|
||||
|
||||
us, err := ms.HasRoute(domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if us != user {
|
||||
t.Fatalf("username was %q, expected %q", us, user)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
Loading…
Reference in New Issue