tun2: some experimenting on the core
This commit is contained in:
parent
a47fd75c5f
commit
59a3f45150
|
@ -1,6 +1,7 @@
|
||||||
package tun2
|
package tun2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -29,6 +30,9 @@ type ClientConfig struct {
|
||||||
Token string
|
Token string
|
||||||
Domain string
|
Domain string
|
||||||
BackendURL string
|
BackendURL string
|
||||||
|
|
||||||
|
// internal use only
|
||||||
|
forceTCPClear bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient constructs an instance of Client with a given ClientConfig.
|
// NewClient constructs an instance of Client with a given ClientConfig.
|
||||||
|
@ -49,7 +53,7 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
|
||||||
// requests to the backend HTTP server.
|
// requests to the backend HTTP server.
|
||||||
//
|
//
|
||||||
// This is a blocking function.
|
// This is a blocking function.
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
return c.connect(c.cfg.ServerAddr)
|
return c.connect(c.cfg.ServerAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +71,12 @@ func (c *Client) connect(serverAddr string) error {
|
||||||
|
|
||||||
switch c.cfg.ConnType {
|
switch c.cfg.ConnType {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
|
if c.cfg.forceTCPClear {
|
||||||
|
conn, err = net.Dial("tcp", serverAddr)
|
||||||
|
} else {
|
||||||
|
conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,10 @@ package tun2
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
@ -17,9 +17,9 @@ import (
|
||||||
|
|
||||||
"github.com/Xe/ln"
|
"github.com/Xe/ln"
|
||||||
failure "github.com/dgryski/go-failure"
|
failure "github.com/dgryski/go-failure"
|
||||||
|
"github.com/kr/pretty"
|
||||||
"github.com/mtneug/pkg/ulid"
|
"github.com/mtneug/pkg/ulid"
|
||||||
cmap "github.com/streamrail/concurrent-map"
|
cmap "github.com/streamrail/concurrent-map"
|
||||||
kcp "github.com/xtaci/kcp-go"
|
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -66,10 +66,6 @@ func gen502Page(req *http.Request) *http.Response {
|
||||||
|
|
||||||
// ServerConfig ...
|
// ServerConfig ...
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
TCPAddr string
|
|
||||||
KCPAddr string
|
|
||||||
TLSConfig *tls.Config
|
|
||||||
|
|
||||||
SmuxConf *smux.Config
|
SmuxConf *smux.Config
|
||||||
Storage Storage
|
Storage Storage
|
||||||
}
|
}
|
||||||
|
@ -83,7 +79,9 @@ type Storage interface {
|
||||||
|
|
||||||
// Server routes frontend HTTP traffic to backend TCP traffic.
|
// Server routes frontend HTTP traffic to backend TCP traffic.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
cfg *ServerConfig
|
cfg *ServerConfig
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
connlock sync.Mutex
|
connlock sync.Mutex
|
||||||
conns map[net.Conn]*Connection
|
conns map[net.Conn]*Connection
|
||||||
|
@ -100,146 +98,174 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||||
|
|
||||||
if cfg.SmuxConf == nil {
|
if cfg.SmuxConf == nil {
|
||||||
cfg.SmuxConf = smux.DefaultConfig()
|
cfg.SmuxConf = smux.DefaultConfig()
|
||||||
|
|
||||||
|
cfg.SmuxConf.KeepAliveInterval = time.Second
|
||||||
|
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.SmuxConf.KeepAliveInterval = time.Second
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second
|
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
|
||||||
conns: map[net.Conn]*Connection{},
|
conns: map[net.Conn]*Connection{},
|
||||||
domains: cmap.New(),
|
domains: cmap.New(),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go server.phiDetectionLoop(ctx)
|
||||||
|
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close stops the background tasks for this Server.
|
||||||
|
func (s *Server) Close() {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until the server context is cancelled.
|
||||||
|
func (s *Server) Wait() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Listen passes this Server a given net.Listener to accept backend connections.
|
// Listen passes this Server a given net.Listener to accept backend connections.
|
||||||
func (s *Server) Listen(l net.Listener, isKCP bool) {
|
func (s *Server) Listen(l net.Listener, isKCP bool) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
f := ln.F{
|
||||||
|
"listener_addr": l.Addr(),
|
||||||
|
"listener_network": l.Addr().Network(),
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ln.Error(ctx, err, ln.F{
|
ln.Error(ctx, err, f, ln.Action("accept connection"))
|
||||||
"addr": l.Addr().String(),
|
|
||||||
"network": l.Addr().Network(),
|
|
||||||
})
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ln.Log(ctx, ln.F{
|
ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{
|
||||||
"action": "new_client",
|
"conn_addr": conn.RemoteAddr(),
|
||||||
"network": conn.RemoteAddr().Network(),
|
"conn_network": conn.RemoteAddr().Network(),
|
||||||
"addr": conn.RemoteAddr(),
|
|
||||||
"list": conn.LocalAddr(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
go s.HandleConn(conn, isKCP)
|
go s.HandleConn(conn, isKCP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
|
// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector]
|
||||||
// traffic to and from them.
|
// for each of the backends connected to the Server. This is fairly experimental and
|
||||||
func (s *Server) ListenAndServe() error {
|
// may be removed.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
//
|
||||||
defer cancel()
|
// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784
|
||||||
|
func (s *Server) phiDetectionLoop(ctx context.Context) {
|
||||||
ln.Log(ctx, ln.F{
|
t := time.NewTicker(5 * time.Second)
|
||||||
"action": "listen_and_serve_called",
|
defer t.Stop()
|
||||||
})
|
|
||||||
|
|
||||||
if s.cfg.TCPAddr != "" {
|
|
||||||
go func() {
|
|
||||||
l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ln.Log(ctx, ln.F{
|
|
||||||
"action": "tcp+tls_listening",
|
|
||||||
"addr": l.Addr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ln.Log(ctx, ln.F{
|
|
||||||
"action": "new_client",
|
|
||||||
"kcp": false,
|
|
||||||
"addr": conn.RemoteAddr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
go s.HandleConn(conn, false)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.cfg.KCPAddr != "" {
|
|
||||||
go func() {
|
|
||||||
l, err := kcp.Listen(s.cfg.KCPAddr)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ln.Log(ctx, ln.F{
|
|
||||||
"action": "kcp+tls_listening",
|
|
||||||
"addr": l.Addr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()})
|
|
||||||
}
|
|
||||||
|
|
||||||
ln.Log(ctx, ln.F{
|
|
||||||
"action": "new_client",
|
|
||||||
"kcp": true,
|
|
||||||
"addr": conn.RemoteAddr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
tc := tls.Server(conn, s.cfg.TLSConfig)
|
|
||||||
|
|
||||||
go s.HandleConn(tc, true)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// XXX experimental, might get rid of this inside this process
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
s.connlock.Lock()
|
s.connlock.Lock()
|
||||||
for _, c := range s.conns {
|
for _, c := range s.conns {
|
||||||
failureChance := c.detector.Phi(now)
|
failureChance := c.detector.Phi(now)
|
||||||
|
const thresh = 0.9 // the threshold for phi failure detection causing logs
|
||||||
|
|
||||||
if failureChance > 0.8 {
|
if failureChance > thresh {
|
||||||
ln.Log(ctx, c.F(), ln.F{
|
ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{
|
||||||
"action": "phi_failure_detection",
|
"value": failureChance,
|
||||||
"value": failureChance,
|
"threshold": thresh,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.connlock.Unlock()
|
s.connlock.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
// backendAuthv1 runs a simple backend authentication check. It expects the
|
||||||
|
// client to write a json-encoded instance of Auth. This is then checked
|
||||||
|
// for token validity and domain matching.
|
||||||
|
//
|
||||||
|
// This returns the user that was authenticated and the domain they identified
|
||||||
|
// with.
|
||||||
|
func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) {
|
||||||
|
f := ln.F{
|
||||||
|
"action": "backend authentication",
|
||||||
|
"backend_auth_version": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
f["stage"] = "json decoding"
|
||||||
|
ln.Log(ctx, f)
|
||||||
|
|
||||||
|
d := json.NewDecoder(st)
|
||||||
|
var auth Auth
|
||||||
|
err := d.Decode(&auth)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f["stage"] = "checking domain"
|
||||||
|
ln.Log(ctx, f)
|
||||||
|
|
||||||
|
pretty.Println(s.cfg.Storage)
|
||||||
|
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f["route_user"] = routeUser
|
||||||
|
f["stage"] = "checking token"
|
||||||
|
ln.Log(ctx, f)
|
||||||
|
|
||||||
|
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f["token_user"] = tokenUser
|
||||||
|
f["stage"] = "checking token scopes"
|
||||||
|
ln.Log(ctx, f)
|
||||||
|
|
||||||
|
ok := false
|
||||||
|
for _, sc := range scopes {
|
||||||
|
if sc == "connect" {
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
ln.Error(ctx, ErrAuthMismatch, f)
|
||||||
|
return "", nil, ErrAuthMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
f["stage"] = "user verification"
|
||||||
|
ln.Log(ctx, f)
|
||||||
|
|
||||||
|
if routeUser != tokenUser {
|
||||||
|
ln.Error(ctx, ErrAuthMismatch, f)
|
||||||
|
return "", nil, ErrAuthMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
return routeUser, &auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
|
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
|
||||||
// the currently connected backend.
|
// the currently connected backend.
|
||||||
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
// XXX TODO clean this up it's really ugly.
|
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -258,8 +284,6 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
f["stage"] = "smux_setup"
|
|
||||||
|
|
||||||
controlStream, err := session.OpenStream()
|
controlStream, err := session.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ln.Error(ctx, err, f, ln.Action("opening control stream"))
|
ln.Error(ctx, err, f, ln.Action("opening control stream"))
|
||||||
|
@ -268,58 +292,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
}
|
}
|
||||||
defer controlStream.Close()
|
defer controlStream.Close()
|
||||||
|
|
||||||
f["stage"] = "control_stream_open"
|
user, auth, err := s.backendAuthv1(ctx, controlStream)
|
||||||
|
|
||||||
csd := json.NewDecoder(controlStream)
|
|
||||||
auth := &Auth{}
|
|
||||||
err = csd.Decode(auth)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ln.Error(ctx, err, f, ln.Action("decode control stream authenication message"))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f["stage"] = "checking_domain"
|
|
||||||
|
|
||||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
|
||||||
if err != nil {
|
|
||||||
ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth"))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f["route_user"] = routeUser
|
|
||||||
f["stage"] = "checking_token"
|
|
||||||
|
|
||||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
|
||||||
if err != nil {
|
|
||||||
ln.Error(ctx, err, f, ln.Action("no such token exists or other token error"))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f["token_user"] = tokenUser
|
|
||||||
f["stage"] = "checking_token_scopes"
|
|
||||||
|
|
||||||
ok := false
|
|
||||||
for _, sc := range scopes {
|
|
||||||
if sc == "connect" {
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect"))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f["stage"] = "user_verification"
|
|
||||||
|
|
||||||
if routeUser != tokenUser {
|
|
||||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch"))
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,7 +302,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
conn: c,
|
conn: c,
|
||||||
isKCP: isKCP,
|
isKCP: isKCP,
|
||||||
session: session,
|
session: session,
|
||||||
user: tokenUser,
|
user: user,
|
||||||
domain: auth.Domain,
|
domain: auth.Domain,
|
||||||
cf: cancel,
|
cf: cancel,
|
||||||
detector: failure.New(15, 1),
|
detector: failure.New(15, 1),
|
||||||
|
@ -343,26 +317,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
|
|
||||||
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
|
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
|
||||||
|
|
||||||
// TODO: put these lines in a function?
|
s.addConn(ctx, connection)
|
||||||
s.connlock.Lock()
|
|
||||||
s.conns[c] = connection
|
|
||||||
s.connlock.Unlock()
|
|
||||||
|
|
||||||
var conns []*Connection
|
|
||||||
|
|
||||||
val, ok := s.domains.Get(auth.Domain)
|
|
||||||
if ok {
|
|
||||||
conns, ok = val.([]*Connection)
|
|
||||||
if !ok {
|
|
||||||
conns = nil
|
|
||||||
|
|
||||||
s.domains.Remove(auth.Domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conns = append(conns, connection)
|
|
||||||
|
|
||||||
s.domains.Set(auth.Domain, conns)
|
|
||||||
connection.usable = true // XXX set this to true once health checks pass?
|
connection.usable = true // XXX set this to true once health checks pass?
|
||||||
|
|
||||||
ticker := time.NewTicker(5 * time.Second)
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
@ -375,8 +331,13 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connection.cancel()
|
connection.cancel()
|
||||||
}
|
}
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.removeConn(ctx, connection)
|
||||||
|
connection.Close()
|
||||||
|
|
||||||
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
s.RemoveConn(ctx, connection)
|
s.removeConn(ctx, connection)
|
||||||
connection.Close()
|
connection.Close()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -384,8 +345,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveConn removes a connection.
|
// addConn adds a connection to the pool of backend connections.
|
||||||
func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
func (s *Server) addConn(ctx context.Context, connection *Connection) {
|
||||||
|
s.connlock.Lock()
|
||||||
|
s.conns[connection.conn] = connection
|
||||||
|
s.connlock.Unlock()
|
||||||
|
|
||||||
|
var conns []*Connection
|
||||||
|
|
||||||
|
val, ok := s.domains.Get(connection.domain)
|
||||||
|
if ok {
|
||||||
|
conns, ok = val.([]*Connection)
|
||||||
|
if !ok {
|
||||||
|
conns = nil
|
||||||
|
|
||||||
|
s.domains.Remove(connection.domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conns = append(conns, connection)
|
||||||
|
|
||||||
|
s.domains.Set(connection.domain, conns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeConn removes a connection from pool of backend connections.
|
||||||
|
func (s *Server) removeConn(ctx context.Context, connection *Connection) {
|
||||||
s.connlock.Lock()
|
s.connlock.Lock()
|
||||||
delete(s.conns, connection.conn)
|
delete(s.conns, connection.conn)
|
||||||
s.connlock.Unlock()
|
s.connlock.Unlock()
|
||||||
|
@ -416,8 +400,6 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
||||||
} else {
|
} else {
|
||||||
s.domains.Remove(auth.Domain)
|
s.domains.Remove(auth.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
ln.Log(ctx, connection, ln.Action("backend disconnect"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a HTTP request to a backend and then returns its response.
|
// RoundTrip sends a HTTP request to a backend and then returns its response.
|
||||||
|
|
|
@ -1,16 +1,31 @@
|
||||||
package tun2
|
package tun2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Xe/uuid"
|
"github.com/Xe/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testing constants
|
||||||
|
const (
|
||||||
|
user = "shachi"
|
||||||
|
token = "orcaz r kewl"
|
||||||
|
noPermToken = "aw heck"
|
||||||
|
otherUserToken = "even more heck"
|
||||||
|
domain = "cetacean.club"
|
||||||
|
)
|
||||||
|
|
||||||
func TestNewServerNullConfig(t *testing.T) {
|
func TestNewServerNullConfig(t *testing.T) {
|
||||||
_, err := NewServer(nil)
|
_, err := NewServer(nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -51,3 +66,164 @@ func TestGen502Page(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackendAuthV1(t *testing.T) {
|
||||||
|
st := MockStorage()
|
||||||
|
|
||||||
|
s, err := NewServer(&ServerConfig{
|
||||||
|
Storage: st,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
st.AddRoute(domain, user)
|
||||||
|
st.AddToken(token, user, []string{"connect"})
|
||||||
|
st.AddToken(noPermToken, user, nil)
|
||||||
|
st.AddToken(otherUserToken, "cadey", []string{"connect"})
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth Auth
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic everything should work",
|
||||||
|
auth: Auth{
|
||||||
|
Token: token,
|
||||||
|
Domain: domain,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid domain",
|
||||||
|
auth: Auth{
|
||||||
|
Token: token,
|
||||||
|
Domain: "aw.heck",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
auth: Auth{
|
||||||
|
Token: "asdfwtweg",
|
||||||
|
Domain: domain,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token scopes",
|
||||||
|
auth: Auth{
|
||||||
|
Token: noPermToken,
|
||||||
|
Domain: domain,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user token doesn't match domain owner",
|
||||||
|
auth: Auth{
|
||||||
|
Token: otherUserToken,
|
||||||
|
Domain: domain,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cs := range cases {
|
||||||
|
t.Run(cs.name, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
data, err := json.Marshal(cs.auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data))
|
||||||
|
|
||||||
|
if cs.wantErr && err == nil {
|
||||||
|
t.Fatalf("auth did not err as expected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cs.wantErr && err != nil {
|
||||||
|
t.Fatalf("unexpected auth err: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendRouting(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
st := MockStorage()
|
||||||
|
|
||||||
|
st.AddRoute(domain, user)
|
||||||
|
st.AddToken(token, user, []string{"connect"})
|
||||||
|
|
||||||
|
s, err := NewServer(&ServerConfig{
|
||||||
|
Storage: st,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.Listen(l, false)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
should200 bool
|
||||||
|
handler http.HandlerFunc
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "200 everything's okay",
|
||||||
|
should200: true,
|
||||||
|
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK)
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cs := range cases {
|
||||||
|
t.Run(cs.name, func(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(cs.handler)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
cc := &ClientConfig{
|
||||||
|
ConnType: "tcp",
|
||||||
|
ServerAddr: l.Addr().String(),
|
||||||
|
Token: token,
|
||||||
|
BackendURL: ts.URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := NewClient(cc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.Connect(ctx) //
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", "http://cetacean.club/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in doing round trip: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cs.should200 && resp.StatusCode != http.StatusOK {
|
||||||
|
resp.Write(os.Stdout)
|
||||||
|
t.Fatalf("got status %d instead of StatusOK", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
package tun2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MockStorage() *mockStorage {
|
||||||
|
return &mockStorage{
|
||||||
|
tokens: make(map[string]mockToken),
|
||||||
|
domains: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockToken struct {
|
||||||
|
user string
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockStorage is a simple mock of the Storage interface suitable for testing.
|
||||||
|
type mockStorage struct {
|
||||||
|
sync.Mutex
|
||||||
|
tokens map[string]mockToken
|
||||||
|
domains map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *mockStorage) AddToken(token, user string, scopes []string) {
|
||||||
|
ms.Lock()
|
||||||
|
defer ms.Unlock()
|
||||||
|
|
||||||
|
ms.tokens[token] = mockToken{user: user, scopes: scopes}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *mockStorage) AddRoute(domain, user string) {
|
||||||
|
ms.Lock()
|
||||||
|
defer ms.Unlock()
|
||||||
|
|
||||||
|
ms.domains[domain] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *mockStorage) HasToken(token string) (string, []string, error) {
|
||||||
|
ms.Lock()
|
||||||
|
defer ms.Unlock()
|
||||||
|
|
||||||
|
tok, ok := ms.tokens[token]
|
||||||
|
if !ok {
|
||||||
|
return "", nil, errors.New("no such token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tok.user, tok.scopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *mockStorage) HasRoute(domain string) (string, error) {
|
||||||
|
ms.Lock()
|
||||||
|
defer ms.Unlock()
|
||||||
|
|
||||||
|
user, ok := ms.domains[domain]
|
||||||
|
if !ok {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockStorage(t *testing.T) {
|
||||||
|
ms := MockStorage()
|
||||||
|
|
||||||
|
t.Run("token", func(t *testing.T) {
|
||||||
|
ms.AddToken(token, user, []string{"connect"})
|
||||||
|
|
||||||
|
us, sc, err := ms.HasToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if us != user {
|
||||||
|
t.Fatalf("username was %q, expected %q", us, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc[0] != "connect" {
|
||||||
|
t.Fatalf("token expected to only have one scope, connect")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("domain", func(t *testing.T) {
|
||||||
|
ms.AddRoute(domain, user)
|
||||||
|
|
||||||
|
us, err := ms.HasRoute(domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if us != user {
|
||||||
|
t.Fatalf("username was %q, expected %q", us, user)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue