From 59a3f45150d57bc22288079dc0d40d220e34b342 Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Tue, 3 Oct 2017 23:43:31 -0700 Subject: [PATCH] tun2: some experimenting on the core --- internal/tun2/client.go | 13 +- internal/tun2/server.go | 338 ++++++++++++++++------------------ internal/tun2/server_test.go | 176 ++++++++++++++++++ internal/tun2/storage_test.go | 99 ++++++++++ 4 files changed, 446 insertions(+), 180 deletions(-) create mode 100644 internal/tun2/storage_test.go diff --git a/internal/tun2/client.go b/internal/tun2/client.go index 8c38c5c..635a770 100644 --- a/internal/tun2/client.go +++ b/internal/tun2/client.go @@ -1,6 +1,7 @@ package tun2 import ( + "context" "crypto/tls" "encoding/json" "errors" @@ -29,6 +30,9 @@ type ClientConfig struct { Token string Domain string BackendURL string + + // internal use only + forceTCPClear bool } // NewClient constructs an instance of Client with a given ClientConfig. @@ -49,7 +53,7 @@ func NewClient(cfg *ClientConfig) (*Client, error) { // requests to the backend HTTP server. // // This is a blocking function. -func (c *Client) Connect() error { +func (c *Client) Connect(ctx context.Context) error { return c.connect(c.cfg.ServerAddr) } @@ -67,7 +71,12 @@ func (c *Client) connect(serverAddr string) error { switch c.cfg.ConnType { case "tcp": - conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig) + if c.cfg.forceTCPClear { + conn, err = net.Dial("tcp", serverAddr) + } else { + conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig) + } + if err != nil { return err } diff --git a/internal/tun2/server.go b/internal/tun2/server.go index 332ab4b..3b92180 100644 --- a/internal/tun2/server.go +++ b/internal/tun2/server.go @@ -3,10 +3,10 @@ package tun2 import ( "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -17,9 +17,9 @@ import ( "github.com/Xe/ln" failure "github.com/dgryski/go-failure" + "github.com/kr/pretty" "github.com/mtneug/pkg/ulid" cmap "github.com/streamrail/concurrent-map" - kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) @@ -66,10 +66,6 @@ func gen502Page(req *http.Request) *http.Response { // ServerConfig ... type ServerConfig struct { - TCPAddr string - KCPAddr string - TLSConfig *tls.Config - SmuxConf *smux.Config Storage Storage } @@ -83,7 +79,9 @@ type Storage interface { // Server routes frontend HTTP traffic to backend TCP traffic. type Server struct { - cfg *ServerConfig + cfg *ServerConfig + ctx context.Context + cancel context.CancelFunc connlock sync.Mutex conns map[net.Conn]*Connection @@ -100,146 +98,174 @@ func NewServer(cfg *ServerConfig) (*Server, error) { if cfg.SmuxConf == nil { cfg.SmuxConf = smux.DefaultConfig() + + cfg.SmuxConf.KeepAliveInterval = time.Second + cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second } - cfg.SmuxConf.KeepAliveInterval = time.Second - cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second + ctx, cancel := context.WithCancel(context.Background()) server := &Server{ cfg: cfg, conns: map[net.Conn]*Connection{}, domains: cmap.New(), + ctx: ctx, + cancel: cancel, } + go server.phiDetectionLoop(ctx) + return server, nil } +// Close stops the background tasks for this Server. +func (s *Server) Close() { + s.cancel() +} + +// Wait blocks until the server context is cancelled. +func (s *Server) Wait() { + for { + select { + case <-s.ctx.Done(): + return + } + } +} + // Listen passes this Server a given net.Listener to accept backend connections. func (s *Server) Listen(l net.Listener, isKCP bool) { ctx := context.Background() + f := ln.F{ + "listener_addr": l.Addr(), + "listener_network": l.Addr().Network(), + } + for { conn, err := l.Accept() if err != nil { - ln.Error(ctx, err, ln.F{ - "addr": l.Addr().String(), - "network": l.Addr().Network(), - }) + ln.Error(ctx, err, f, ln.Action("accept connection")) continue } - ln.Log(ctx, ln.F{ - "action": "new_client", - "network": conn.RemoteAddr().Network(), - "addr": conn.RemoteAddr(), - "list": conn.LocalAddr(), + ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{ + "conn_addr": conn.RemoteAddr(), + "conn_network": conn.RemoteAddr().Network(), }) go s.HandleConn(conn, isKCP) } } -// ListenAndServe starts the backend TCP/KCP listeners and relays backend -// traffic to and from them. -func (s *Server) ListenAndServe() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln.Log(ctx, ln.F{ - "action": "listen_and_serve_called", - }) - - if s.cfg.TCPAddr != "" { - go func() { - l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) - if err != nil { - panic(err) - } - - ln.Log(ctx, ln.F{ - "action": "tcp+tls_listening", - "addr": l.Addr(), - }) - - for { - conn, err := l.Accept() - if err != nil { - ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) - continue - } - - ln.Log(ctx, ln.F{ - "action": "new_client", - "kcp": false, - "addr": conn.RemoteAddr(), - }) - - go s.HandleConn(conn, false) - } - }() - } - - if s.cfg.KCPAddr != "" { - go func() { - l, err := kcp.Listen(s.cfg.KCPAddr) - if err != nil { - panic(err) - } - - ln.Log(ctx, ln.F{ - "action": "kcp+tls_listening", - "addr": l.Addr(), - }) - - for { - conn, err := l.Accept() - if err != nil { - ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) - } - - ln.Log(ctx, ln.F{ - "action": "new_client", - "kcp": true, - "addr": conn.RemoteAddr(), - }) - - tc := tls.Server(conn, s.cfg.TLSConfig) - - go s.HandleConn(tc, true) - } - }() - } - - // XXX experimental, might get rid of this inside this process - go func() { - for { - time.Sleep(time.Second) +// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector] +// for each of the backends connected to the Server. This is fairly experimental and +// may be removed. +// +// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784 +func (s *Server) phiDetectionLoop(ctx context.Context) { + t := time.NewTicker(5 * time.Second) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: now := time.Now() s.connlock.Lock() for _, c := range s.conns { failureChance := c.detector.Phi(now) + const thresh = 0.9 // the threshold for phi failure detection causing logs - if failureChance > 0.8 { - ln.Log(ctx, c.F(), ln.F{ - "action": "phi_failure_detection", - "value": failureChance, + if failureChance > thresh { + ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{ + "value": failureChance, + "threshold": thresh, }) } } s.connlock.Unlock() } - }() + } +} - return nil +// backendAuthv1 runs a simple backend authentication check. It expects the +// client to write a json-encoded instance of Auth. This is then checked +// for token validity and domain matching. +// +// This returns the user that was authenticated and the domain they identified +// with. +func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) { + f := ln.F{ + "action": "backend authentication", + "backend_auth_version": 1, + } + + f["stage"] = "json decoding" + ln.Log(ctx, f) + + d := json.NewDecoder(st) + var auth Auth + err := d.Decode(&auth) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["stage"] = "checking domain" + ln.Log(ctx, f) + + pretty.Println(s.cfg.Storage) + routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["route_user"] = routeUser + f["stage"] = "checking token" + ln.Log(ctx, f) + + tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["token_user"] = tokenUser + f["stage"] = "checking token scopes" + ln.Log(ctx, f) + + ok := false + for _, sc := range scopes { + if sc == "connect" { + ok = true + break + } + } + + if !ok { + ln.Error(ctx, ErrAuthMismatch, f) + return "", nil, ErrAuthMismatch + } + + f["stage"] = "user verification" + ln.Log(ctx, f) + + if routeUser != tokenUser { + ln.Error(ctx, ErrAuthMismatch, f) + return "", nil, ErrAuthMismatch + } + + return routeUser, &auth, nil } // HandleConn starts up the needed mechanisms to relay HTTP traffic to/from // the currently connected backend. func (s *Server) HandleConn(c net.Conn, isKCP bool) { - // XXX TODO clean this up it's really ugly. defer c.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -258,8 +284,6 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } defer session.Close() - f["stage"] = "smux_setup" - controlStream, err := session.OpenStream() if err != nil { ln.Error(ctx, err, f, ln.Action("opening control stream")) @@ -268,58 +292,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } defer controlStream.Close() - f["stage"] = "control_stream_open" - - csd := json.NewDecoder(controlStream) - auth := &Auth{} - err = csd.Decode(auth) + user, auth, err := s.backendAuthv1(ctx, controlStream) if err != nil { - ln.Error(ctx, err, f, ln.Action("decode control stream authenication message")) - - return - } - - f["stage"] = "checking_domain" - - routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) - if err != nil { - ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth")) - - return - } - - f["route_user"] = routeUser - f["stage"] = "checking_token" - - tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) - if err != nil { - ln.Error(ctx, err, f, ln.Action("no such token exists or other token error")) - - return - } - - f["token_user"] = tokenUser - f["stage"] = "checking_token_scopes" - - ok := false - for _, sc := range scopes { - if sc == "connect" { - ok = true - break - } - } - - if !ok { - ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect")) - - return - } - - f["stage"] = "user_verification" - - if routeUser != tokenUser { - ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch")) - return } @@ -328,7 +302,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { conn: c, isKCP: isKCP, session: session, - user: tokenUser, + user: user, domain: auth.Domain, cf: cancel, detector: failure.New(15, 1), @@ -343,26 +317,8 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { ln.Log(ctx, connection, ln.Action("backend successfully connected")) - // TODO: put these lines in a function? - s.connlock.Lock() - s.conns[c] = connection - s.connlock.Unlock() + s.addConn(ctx, connection) - var conns []*Connection - - val, ok := s.domains.Get(auth.Domain) - if ok { - conns, ok = val.([]*Connection) - if !ok { - conns = nil - - s.domains.Remove(auth.Domain) - } - } - - conns = append(conns, connection) - - s.domains.Set(auth.Domain, conns) connection.usable = true // XXX set this to true once health checks pass? ticker := time.NewTicker(5 * time.Second) @@ -375,8 +331,13 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { if err != nil { connection.cancel() } + case <-s.ctx.Done(): + s.removeConn(ctx, connection) + connection.Close() + + return case <-ctx.Done(): - s.RemoveConn(ctx, connection) + s.removeConn(ctx, connection) connection.Close() return @@ -384,8 +345,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } } -// RemoveConn removes a connection. -func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { +// addConn adds a connection to the pool of backend connections. +func (s *Server) addConn(ctx context.Context, connection *Connection) { + s.connlock.Lock() + s.conns[connection.conn] = connection + s.connlock.Unlock() + + var conns []*Connection + + val, ok := s.domains.Get(connection.domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + conns = nil + + s.domains.Remove(connection.domain) + } + } + + conns = append(conns, connection) + + s.domains.Set(connection.domain, conns) +} + +// removeConn removes a connection from pool of backend connections. +func (s *Server) removeConn(ctx context.Context, connection *Connection) { s.connlock.Lock() delete(s.conns, connection.conn) s.connlock.Unlock() @@ -416,8 +400,6 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { } else { s.domains.Remove(auth.Domain) } - - ln.Log(ctx, connection, ln.Action("backend disconnect")) } // RoundTrip sends a HTTP request to a backend and then returns its response. diff --git a/internal/tun2/server_test.go b/internal/tun2/server_test.go index 3f37c4e..0ead668 100644 --- a/internal/tun2/server_test.go +++ b/internal/tun2/server_test.go @@ -1,16 +1,31 @@ package tun2 import ( + "bytes" "context" + "encoding/json" "fmt" "io/ioutil" + "net" "net/http" + "net/http/httptest" + "os" "strings" "testing" + "time" "github.com/Xe/uuid" ) +// testing constants +const ( + user = "shachi" + token = "orcaz r kewl" + noPermToken = "aw heck" + otherUserToken = "even more heck" + domain = "cetacean.club" +) + func TestNewServerNullConfig(t *testing.T) { _, err := NewServer(nil) if err == nil { @@ -51,3 +66,164 @@ func TestGen502Page(t *testing.T) { } } } + +func TestBackendAuthV1(t *testing.T) { + st := MockStorage() + + s, err := NewServer(&ServerConfig{ + Storage: st, + }) + if err != nil { + t.Fatal(err) + } + + st.AddRoute(domain, user) + st.AddToken(token, user, []string{"connect"}) + st.AddToken(noPermToken, user, nil) + st.AddToken(otherUserToken, "cadey", []string{"connect"}) + + cases := []struct { + name string + auth Auth + wantErr bool + }{ + { + name: "basic everything should work", + auth: Auth{ + Token: token, + Domain: domain, + }, + wantErr: false, + }, + { + name: "invalid domain", + auth: Auth{ + Token: token, + Domain: "aw.heck", + }, + wantErr: true, + }, + { + name: "invalid token", + auth: Auth{ + Token: "asdfwtweg", + Domain: domain, + }, + wantErr: true, + }, + { + name: "invalid token scopes", + auth: Auth{ + Token: noPermToken, + Domain: domain, + }, + wantErr: true, + }, + { + name: "user token doesn't match domain owner", + auth: Auth{ + Token: otherUserToken, + Domain: domain, + }, + wantErr: true, + }, + } + + for _, cs := range cases { + t.Run(cs.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + data, err := json.Marshal(cs.auth) + if err != nil { + t.Fatal(err) + } + + _, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data)) + + if cs.wantErr && err == nil { + t.Fatalf("auth did not err as expected") + } + + if !cs.wantErr && err != nil { + t.Fatalf("unexpected auth err: %v", err) + } + }) + } +} + +func TestBackendRouting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + st := MockStorage() + + st.AddRoute(domain, user) + st.AddToken(token, user, []string{"connect"}) + + s, err := NewServer(&ServerConfig{ + Storage: st, + }) + if err != nil { + t.Fatal(err) + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + go s.Listen(l, false) + + cases := []struct { + name string + should200 bool + handler http.HandlerFunc + }{ + { + name: "200 everything's okay", + should200: true, + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK) + }), + }, + } + + for _, cs := range cases { + t.Run(cs.name, func(t *testing.T) { + ts := httptest.NewServer(cs.handler) + defer ts.Close() + + cc := &ClientConfig{ + ConnType: "tcp", + ServerAddr: l.Addr().String(), + Token: token, + BackendURL: ts.URL, + } + + c, err := NewClient(cc) + if err != nil { + t.Fatal(err) + } + + go c.Connect(ctx) // + + time.Sleep(5 * time.Second) + + req, err := http.NewRequest("GET", "http://cetacean.club/", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := s.RoundTrip(req) + if err != nil { + t.Fatalf("error in doing round trip: %v", err) + } + + if cs.should200 && resp.StatusCode != http.StatusOK { + resp.Write(os.Stdout) + t.Fatalf("got status %d instead of StatusOK", resp.StatusCode) + } + }) + } +} diff --git a/internal/tun2/storage_test.go b/internal/tun2/storage_test.go new file mode 100644 index 0000000..f7e068f --- /dev/null +++ b/internal/tun2/storage_test.go @@ -0,0 +1,99 @@ +package tun2 + +import ( + "errors" + "sync" + "testing" +) + +func MockStorage() *mockStorage { + return &mockStorage{ + tokens: make(map[string]mockToken), + domains: make(map[string]string), + } +} + +type mockToken struct { + user string + scopes []string +} + +// mockStorage is a simple mock of the Storage interface suitable for testing. +type mockStorage struct { + sync.Mutex + tokens map[string]mockToken + domains map[string]string +} + +func (ms *mockStorage) AddToken(token, user string, scopes []string) { + ms.Lock() + defer ms.Unlock() + + ms.tokens[token] = mockToken{user: user, scopes: scopes} +} + +func (ms *mockStorage) AddRoute(domain, user string) { + ms.Lock() + defer ms.Unlock() + + ms.domains[domain] = user +} + +func (ms *mockStorage) HasToken(token string) (string, []string, error) { + ms.Lock() + defer ms.Unlock() + + tok, ok := ms.tokens[token] + if !ok { + return "", nil, errors.New("no such token") + } + + return tok.user, tok.scopes, nil +} + +func (ms *mockStorage) HasRoute(domain string) (string, error) { + ms.Lock() + defer ms.Unlock() + + user, ok := ms.domains[domain] + if !ok { + return "", nil + } + + return user, nil +} + +func TestMockStorage(t *testing.T) { + ms := MockStorage() + + t.Run("token", func(t *testing.T) { + ms.AddToken(token, user, []string{"connect"}) + + us, sc, err := ms.HasToken(token) + if err != nil { + t.Fatal(err) + } + + if us != user { + t.Fatalf("username was %q, expected %q", us, user) + } + + if sc[0] != "connect" { + t.Fatalf("token expected to only have one scope, connect") + } + }) + + t.Run("domain", func(t *testing.T) { + ms.AddRoute(domain, user) + + us, err := ms.HasRoute(domain) + if err != nil { + t.Fatal(err) + } + + if us != user { + t.Fatalf("username was %q, expected %q", us, user) + } + }) + +}