diff --git a/cmd/route-cli/.gitignore b/cmd/route-cli/.gitignore index 6f36346..41fbb63 100644 --- a/cmd/route-cli/.gitignore +++ b/cmd/route-cli/.gitignore @@ -1 +1 @@ -route +route-cli diff --git a/cmd/route-httpagent/.gitignore b/cmd/route-httpagent/.gitignore index 0380c39..f22d7b4 100644 --- a/cmd/route-httpagent/.gitignore +++ b/cmd/route-httpagent/.gitignore @@ -1 +1 @@ -route-httpagent \ No newline at end of file +route-httpagent diff --git a/cmd/route-httpagent/main.go b/cmd/route-httpagent/main.go index 12feafa..f6acac3 100644 --- a/cmd/route-httpagent/main.go +++ b/cmd/route-httpagent/main.go @@ -35,7 +35,7 @@ func main() { client, _ := tun2.NewClient(cfg) for { - err := client.Connect() + err := client.Connect(context.Background()) if err != nil { ln.Error(context.Background(), err, ln.Action("client connection failed")) } diff --git a/internal/elfs/elfs_test.go b/internal/elfs/elfs_test.go new file mode 100644 index 0000000..c433477 --- /dev/null +++ b/internal/elfs/elfs_test.go @@ -0,0 +1,10 @@ +package elfs + +import "testing" + +func TestMakeName(t *testing.T) { + n := MakeName() + if len(n) == 0 { + t.Fatalf("MakeName had a zero output") + } +} diff --git a/internal/middleware/trace_test.go b/internal/middleware/trace_test.go new file mode 100644 index 0000000..cc4461a --- /dev/null +++ b/internal/middleware/trace_test.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestTrace(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var executed bool + var handler http.Handler = Trace(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executed = true + w.WriteHeader(http.StatusOK) + })) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("error when creating request: %v", err) + } + req = req.WithContext(ctx) + + rw := httptest.NewRecorder() + + handler.ServeHTTP(rw, req) + + if !executed { + t.Fatal("middleware Trace doesn't pass through to underlying handler") + } +} diff --git a/internal/routecrypto/rsa_test.go b/internal/routecrypto/rsa_test.go new file mode 100644 index 0000000..08e742e --- /dev/null +++ b/internal/routecrypto/rsa_test.go @@ -0,0 +1,41 @@ +package routecrypto + +import "testing" + +var ( + rsaPrivKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC6C94euSI3GAbszcTVvuBI4ejM/fugqe/uUyXz2bUIGemkADBh +OOkNWXFi/gnYylHRrFKOH06wxhzZWpsBMacmwx6tD7a7nKktcw7HsVFL8is0PPnp +syhWfW+DF6vMDZxkgI3iKrr9/WY/3/qUg7ga17s1JXb3SmQ2sMDTh5I6DQIET4Bo +LwKBgCBG2EmsLiVPCXwN+Mk8IGck7BHKhVpcm955VDDiuKNMuFK4F9ak3tbsKOza +UDC+JhqhB1U7/J8zABM+qVqHBwse1sJMZUEXPuGbIuw4vmEHFA+scAuwkpmRx4gA +/Ghi9eWr1rDlrRFMEF5vs18GObY7Z07GxTx/nZPx7FZ+6FqZAkEA24zob4NMKGUj +efHggZ4DFiIGDEbfbRS6a/w7VicJwI41pwhbGj7KCPZEwXYhnXR3H9UXSrowsm14 +D0Wbsw4gRwJBANjvAbFVBAW8TWxLCgKx7uyHehygEBl5NY2in/8QHMjJpE7fQX5U +qutOL68A6+8P0lrtoz4VJZSnAxwkaifM8QsCQA37iRRm+Qd64OetQrHj+FhiZlrJ +LAT0CUWmADJ5KYX49B2lfNXDrXOsUG9sZ4tHKRGDt51KC/0KjMgq9BGx41MCQF0y +FxOL0s2EtXz/33V4QA9twe9xUBDY4CMts4Eyq3xlscbBBe4IjwrcKuntJ3POkGPS +Xotb9TDONmrANIqlmbECQCD8Uo0bgt8kR5bShqkbW1e5qVNz5w4+tM7Uh+oQMIGB +bC3xLJD4u2NPTwTdqKxxkeicFMKpuiGvX200M/CcoVc= +-----END RSA PRIVATE KEY-----`) +) + +func TestRSA(t *testing.T) { + pk, err := PemToRSAPrivateKey(rsaPrivKey) + if err != nil { + t.Fatalf("can't parse key: %v", err) + } + + pkd := RSAPrivateKeyToPem(pk) + + pk2, err := PemToRSAPrivateKey(pkd) + if err != nil { + t.Fatalf("can't parse key: %v", err) + } + + pkd2 := RSAPrivateKeyToPem(pk2) + + if string(pkd) != string(pkd2) { + t.Fatalf("functions are not 1:1") + } +} diff --git a/internal/routecrypto/secretbox_test.go b/internal/routecrypto/secretbox_test.go new file mode 100644 index 0000000..05e3d52 --- /dev/null +++ b/internal/routecrypto/secretbox_test.go @@ -0,0 +1,40 @@ +package routecrypto + +import "testing" + +func TestSecretBox(t *testing.T) { + var ( + key *[32]byte + sk string + ) + + t.Run("generate key", func(t *testing.T) { + var err error + key, err = GenerateKey() + if err != nil { + t.Fatalf("can't generate key: %v", err) + } + }) + + if key == nil { + t.Fatal("can't continue") + } + + t.Run("show key", func(t *testing.T) { + sk = ShowKey(key) + if len(sk) == 0 { + t.Fatal("expected output to be a nonzero length string") + } + }) + + t.Run("read key", func(t *testing.T) { + readKey, err := ParseKey(sk) + if err != nil { + t.Fatal(err) + } + + if *key != *readKey { + t.Fatal("key did not parse out correctly") + } + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 6cbfbd4..313e9c6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,8 +11,12 @@ import ( "git.xeserv.us/xena/route/internal/database" "git.xeserv.us/xena/route/internal/tun2" proto "git.xeserv.us/xena/route/proto" + "github.com/Xe/ln" "github.com/mtneug/pkg/ulid" + "github.com/oxtoacart/bpool" + kcp "github.com/xtaci/kcp-go" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -46,6 +50,56 @@ type Config struct { CertKey *[32]byte } +func (s *Server) listenTCP(ctx context.Context, addr string, tcfg *tls.Config) { + l, err := tls.Listen("tcp", addr, tcfg) + if err != nil { + panic(err) + } + + ln.Log(ctx, ln.Action("tcp+tls listening"), ln.F{"addr": l.Addr()}) + + for { + conn, err := l.Accept() + if err != nil { + ln.Error(ctx, err, ln.Action("accept backend client socket")) + } + + ln.Log(ctx, ln.F{ + "action": "new backend client", + "addr": conn.RemoteAddr(), + "network": conn.RemoteAddr().Network(), + }) + + go s.ts.HandleConn(conn, false) + } +} + +func (s *Server) listenKCP(ctx context.Context, addr string, tcfg *tls.Config) { + l, err := kcp.Listen(addr) + if err != nil { + panic(err) + } + + ln.Log(ctx, ln.Action("kcp+tls listening"), ln.F{"addr": l.Addr()}) + + for { + conn, err := l.Accept() + if err != nil { + ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) + } + + ln.Log(ctx, ln.F{ + "action": "new_client", + "network": conn.RemoteAddr().Network(), + "addr": conn.RemoteAddr(), + }) + + tc := tls.Server(conn, tcfg) + + go s.ts.HandleConn(tc, true) + } +} + // New creates a new Server func New(cfg Config) (*Server, error) { if cfg.CertKey == nil { @@ -65,11 +119,6 @@ func New(cfg Config) (*Server, error) { } tcfg := &tun2.ServerConfig{ - TCPAddr: cfg.BackendTCPAddr, - KCPAddr: cfg.BackendKCPAddr, - TLSConfig: &tls.Config{ - GetCertificate: m.GetCertificate, - }, Storage: &storageWrapper{ Storage: db, }, @@ -79,6 +128,7 @@ func New(cfg Config) (*Server, error) { if err != nil { return nil, err } + s := &Server{ cfg: &cfg, db: db, @@ -87,13 +137,15 @@ func New(cfg Config) (*Server, error) { Manager: m, } - s.ts = ts - go ts.ListenAndServe() + tc := &tls.Config{ + GetCertificate: m.GetCertificate, + } - gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(&tls.Config{ - GetCertificate: m.GetCertificate, - InsecureSkipVerify: true, - }))) + go s.listenKCP(context.Background(), cfg.BackendKCPAddr, tc) + go s.listenTCP(context.Background(), cfg.BackendTCPAddr, tc) + + // gRPC setup + gs := grpc.NewServer(grpc.Creds(credentials.NewTLS(tc))) proto.RegisterBackendsServer(gs, &Backend{Server: s}) proto.RegisterRoutesServer(gs, &Route{Server: s}) @@ -140,7 +192,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { Director: s.Director, Transport: s.ts, FlushInterval: 1 * time.Second, - //BufferPool: bpool.NewBytePool(256, 4096), + BufferPool: bpool.NewBytePool(256, 4096), } rp.ServeHTTP(w, r) diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..506fdb3 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,29 @@ +package server + +import ( + "net/http" + "testing" +) + +func TestDirector(t *testing.T) { + s := &Server{} + + req, err := http.NewRequest("GET", "https://cetacean.club/", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Add("X-Forwarded-For", "Rick-James") + req.Header.Add("X-Client-Ip", "56.32.51.84") + + s.Director(req) + + for _, header := range []string{"X-Forwarded-For", "X-Client-Ip"} { + t.Run(header, func(t *testing.T) { + val := req.Header.Get(header) + if val != "" { + t.Fatalf("expected header %q to have no value, got: %v", header, val) + } + }) + } +} diff --git a/internal/tun2/backend.go b/internal/tun2/backend.go index 37af1aa..30fd2e1 100644 --- a/internal/tun2/backend.go +++ b/internal/tun2/backend.go @@ -1,5 +1,7 @@ package tun2 +import "time" + // Backend is the public state of an individual Connection. type Backend struct { ID string @@ -10,3 +12,72 @@ type Backend struct { Host string Usable bool } + +type backendMatcher func(*Connection) bool + +func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend { + s.connlock.Lock() + defer s.connlock.Unlock() + + var result []Backend + + for _, c := range s.conns { + if !bm(c) { + continue + } + + protocol := "tcp" + if c.isKCP { + protocol = "kcp" + } + + result = append(result, Backend{ + ID: c.id, + Proto: protocol, + User: c.user, + Domain: c.domain, + Phi: float32(c.detector.Phi(time.Now())), + Host: c.conn.RemoteAddr().String(), + Usable: c.usable, + }) + } + + return result +} + +// KillBackend forcibly disconnects a given backend but doesn't offer a way to +// "ban" it from reconnecting. +func (s *Server) KillBackend(id string) error { + s.connlock.Lock() + defer s.connlock.Unlock() + + for _, c := range s.conns { + if c.id == id { + c.cancel() + return nil + } + } + + return ErrNoSuchBackend +} + +// GetBackendsForDomain fetches all backends connected to this server associated +// to a single public domain name. +func (s *Server) GetBackendsForDomain(domain string) []Backend { + return s.getBackendsForMatcher(func(c *Connection) bool { + return c.domain == domain + }) +} + +// GetBackendsForUser fetches all backends connected to this server owned by a +// given user by username. +func (s *Server) GetBackendsForUser(uname string) []Backend { + return s.getBackendsForMatcher(func(c *Connection) bool { + return c.user == uname + }) +} + +// GetAllBackends fetches every backend connected to this server. +func (s *Server) GetAllBackends() []Backend { + return s.getBackendsForMatcher(func(*Connection) bool { return true }) +} diff --git a/internal/tun2/client.go b/internal/tun2/client.go index adb1c8b..635a770 100644 --- a/internal/tun2/client.go +++ b/internal/tun2/client.go @@ -1,6 +1,7 @@ package tun2 import ( + "context" "crypto/tls" "encoding/json" "errors" @@ -14,10 +15,14 @@ import ( "github.com/xtaci/smux" ) +// Client connects to a remote tun2 server and sets up authentication before routing +// individual HTTP requests to discrete streams that are reverse proxied to the eventual +// backend. type Client struct { cfg *ClientConfig } +// ClientConfig configures client with settings that the user provides. type ClientConfig struct { TLSConfig *tls.Config ConnType string @@ -25,8 +30,12 @@ type ClientConfig struct { Token string Domain string BackendURL string + + // internal use only + forceTCPClear bool } +// NewClient constructs an instance of Client with a given ClientConfig. func NewClient(cfg *ClientConfig) (*Client, error) { if cfg == nil { return nil, errors.New("tun2: client config needed") @@ -39,7 +48,12 @@ func NewClient(cfg *ClientConfig) (*Client, error) { return c, nil } -func (c *Client) Connect() error { +// Connect dials the remote server and negotiates a client session with its +// configured server address. This will then continuously proxy incoming HTTP +// requests to the backend HTTP server. +// +// This is a blocking function. +func (c *Client) Connect(ctx context.Context) error { return c.connect(c.cfg.ServerAddr) } @@ -57,7 +71,12 @@ func (c *Client) connect(serverAddr string) error { switch c.cfg.ConnType { case "tcp": - conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig) + if c.cfg.forceTCPClear { + conn, err = net.Dial("tcp", serverAddr) + } else { + conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig) + } + if err != nil { return err } @@ -117,15 +136,12 @@ func (c *Client) connect(serverAddr string) error { return nil } +// smuxListener wraps a smux session as a net.Listener. type smuxListener struct { conn net.Conn session *smux.Session } -var ( - _ net.Listener = &smuxListener{} // interface check -) - func (sl *smuxListener) Accept() (net.Conn, error) { return sl.session.AcceptStream() } diff --git a/internal/tun2/client_test.go b/internal/tun2/client_test.go new file mode 100644 index 0000000..d3127a7 --- /dev/null +++ b/internal/tun2/client_test.go @@ -0,0 +1,21 @@ +package tun2 + +import ( + "net" + "testing" +) + +func TestNewClientNullConfig(t *testing.T) { + _, err := NewClient(nil) + if err == nil { + t.Fatalf("expected NewClient(nil) to fail, got non-failure") + } +} + +func TestSmuxListenerIsNetListener(t *testing.T) { + var sl interface{} = &smuxListener{} + _, ok := sl.(net.Listener) + if !ok { + t.Fatalf("smuxListener does not implement net.Listener") + } +} diff --git a/internal/tun2/connection.go b/internal/tun2/connection.go index 9db2195..40f7c4d 100644 --- a/internal/tun2/connection.go +++ b/internal/tun2/connection.go @@ -2,11 +2,10 @@ package tun2 import ( "bufio" - "bytes" "context" - "io/ioutil" "net" "net/http" + "strconv" "time" "github.com/Xe/ln" @@ -127,7 +126,12 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, errors.Wrap(err, ErrCantOpenSessionStream.Error()) } - defer stream.Close() + + go func() { + time.Sleep(30 * time.Minute) + + stream.Close() + }() err = req.Write(stream) if err != nil { @@ -142,13 +146,13 @@ func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + cl := resp.Header.Get("Content-Length") + asInt, err := strconv.Atoi(cl) if err != nil { - return nil, errors.Wrap(err, "can't read response body") + return nil, err } - resp.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - resp.ContentLength = int64(len(body)) + resp.ContentLength = int64(asInt) return resp, nil } diff --git a/internal/tun2/server.go b/internal/tun2/server.go index 783d952..54d5dbf 100644 --- a/internal/tun2/server.go +++ b/internal/tun2/server.go @@ -3,10 +3,10 @@ package tun2 import ( "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -19,7 +19,6 @@ import ( failure "github.com/dgryski/go-failure" "github.com/mtneug/pkg/ulid" cmap "github.com/streamrail/concurrent-map" - kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) @@ -30,412 +29,7 @@ var ( ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it") ) -// ServerConfig ... -type ServerConfig struct { - TCPAddr string - KCPAddr string - TLSConfig *tls.Config - - SmuxConf *smux.Config - Storage Storage -} - -// Storage is the minimal subset of features that tun2's Server needs out of a -// persistence layer. -type Storage interface { - HasToken(token string) (user string, scopes []string, err error) - HasRoute(domain string) (user string, err error) -} - -// Server routes frontend HTTP traffic to backend TCP traffic. -type Server struct { - cfg *ServerConfig - - connlock sync.Mutex - conns map[net.Conn]*Connection - - domains cmap.ConcurrentMap -} - -// NewServer creates a new Server instance with a given config, acquiring all -// relevant resources. -func NewServer(cfg *ServerConfig) (*Server, error) { - if cfg == nil { - return nil, errors.New("tun2: config must be specified") - } - - if cfg.SmuxConf == nil { - cfg.SmuxConf = smux.DefaultConfig() - } - - cfg.SmuxConf.KeepAliveInterval = time.Second - cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second - - server := &Server{ - cfg: cfg, - - conns: map[net.Conn]*Connection{}, - domains: cmap.New(), - } - - return server, nil -} - -type backendMatcher func(*Connection) bool - -func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend { - s.connlock.Lock() - defer s.connlock.Unlock() - - var result []Backend - - for _, c := range s.conns { - if !bm(c) { - continue - } - - protocol := "tcp" - if c.isKCP { - protocol = "kcp" - } - - result = append(result, Backend{ - ID: c.id, - Proto: protocol, - User: c.user, - Domain: c.domain, - Phi: float32(c.detector.Phi(time.Now())), - Host: c.conn.RemoteAddr().String(), - Usable: c.usable, - }) - } - - return result -} - -func (s *Server) KillBackend(id string) error { - s.connlock.Lock() - defer s.connlock.Unlock() - - for _, c := range s.conns { - if c.id == id { - c.cancel() - return nil - } - } - - return ErrNoSuchBackend -} - -func (s *Server) GetBackendsForDomain(domain string) []Backend { - return s.getBackendsForMatcher(func(c *Connection) bool { - return c.domain == domain - }) -} - -func (s *Server) GetBackendsForUser(uname string) []Backend { - return s.getBackendsForMatcher(func(c *Connection) bool { - return c.user == uname - }) -} - -func (s *Server) GetAllBackends() []Backend { - return s.getBackendsForMatcher(func(*Connection) bool { return true }) -} - -// ListenAndServe starts the backend TCP/KCP listeners and relays backend -// traffic to and from them. -func (s *Server) ListenAndServe() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln.Log(ctx, ln.F{ - "action": "listen_and_serve_called", - }) - - if s.cfg.TCPAddr != "" { - go func() { - l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) - if err != nil { - panic(err) - } - - ln.Log(ctx, ln.F{ - "action": "tcp+tls_listening", - "addr": l.Addr(), - }) - - for { - conn, err := l.Accept() - if err != nil { - ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) - continue - } - - ln.Log(ctx, ln.F{ - "action": "new_client", - "kcp": false, - "addr": conn.RemoteAddr(), - }) - - go s.HandleConn(conn, false) - } - }() - } - - if s.cfg.KCPAddr != "" { - go func() { - l, err := kcp.Listen(s.cfg.KCPAddr) - if err != nil { - panic(err) - } - - ln.Log(ctx, ln.F{ - "action": "kcp+tls_listening", - "addr": l.Addr(), - }) - - for { - conn, err := l.Accept() - if err != nil { - ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) - } - - ln.Log(ctx, ln.F{ - "action": "new_client", - "kcp": true, - "addr": conn.RemoteAddr(), - }) - - tc := tls.Server(conn, s.cfg.TLSConfig) - - go s.HandleConn(tc, true) - } - }() - } - - // XXX experimental, might get rid of this inside this process - go func() { - for { - time.Sleep(time.Second) - - now := time.Now() - - s.connlock.Lock() - for _, c := range s.conns { - failureChance := c.detector.Phi(now) - - if failureChance > 0.8 { - ln.Log(ctx, c.F(), ln.F{ - "action": "phi_failure_detection", - "value": failureChance, - }) - } - } - s.connlock.Unlock() - } - }() - - return nil -} - -// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from -// the currently connected backend. -func (s *Server) HandleConn(c net.Conn, isKCP bool) { - // XXX TODO clean this up it's really ugly. - defer c.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - session, err := smux.Server(c, s.cfg.SmuxConf) - if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "session_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - c.Close() - - return - } - defer session.Close() - - controlStream, err := session.OpenStream() - if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "control_stream_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - return - } - defer controlStream.Close() - - csd := json.NewDecoder(controlStream) - auth := &Auth{} - err = csd.Decode(auth) - if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "control_stream_auth_decoding_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - return - } - - routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) - if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "nosuch_domain", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - return - } - - tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) - if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "nosuch_token", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - return - } - - ok := false - for _, sc := range scopes { - if sc == "connect" { - ok = true - break - } - } - - if !ok { - ln.Error(ctx, ErrAuthMismatch, ln.F{ - "action": "token_not_authorized", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - } - - if routeUser != tokenUser { - ln.Error(ctx, ErrAuthMismatch, ln.F{ - "action": "auth_mismatch", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - return - } - - connection := &Connection{ - id: ulid.New().String(), - conn: c, - isKCP: isKCP, - session: session, - user: tokenUser, - domain: auth.Domain, - cf: cancel, - detector: failure.New(15, 1), - Auth: auth, - } - - defer func() { - if r := recover(); r != nil { - ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r}) - } - }() - - ln.Log(ctx, ln.F{ - "action": "backend_connected", - }, connection.F()) - - s.connlock.Lock() - s.conns[c] = connection - s.connlock.Unlock() - - var conns []*Connection - - val, ok := s.domains.Get(auth.Domain) - if ok { - conns, ok = val.([]*Connection) - if !ok { - conns = nil - - s.domains.Remove(auth.Domain) - } - } - - conns = append(conns, connection) - - s.domains.Set(auth.Domain, conns) - connection.usable = true - - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - err := connection.Ping() - if err != nil { - connection.cancel() - } - case <-ctx.Done(): - s.RemoveConn(ctx, connection) - connection.Close() - - return - } - } -} - -// RemoveConn removes a connection. -func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { - s.connlock.Lock() - delete(s.conns, connection.conn) - s.connlock.Unlock() - - auth := connection.Auth - - var conns []*Connection - - val, ok := s.domains.Get(auth.Domain) - if ok { - conns, ok = val.([]*Connection) - if !ok { - ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{ - "action": "looking_up_for_disconnect_removal", - }) - return - } - } - - for i, cntn := range conns { - if cntn.id == connection.id { - conns[i] = conns[len(conns)-1] - conns = conns[:len(conns)-1] - } - } - - if len(conns) != 0 { - s.domains.Set(auth.Domain, conns) - } else { - s.domains.Remove(auth.Domain) - } - - ln.Log(ctx, connection.F(), ln.F{ - "action": "client_disconnecting", - }) -} - +// gen502Page creates the page that is shown when a backend is not connected to a given route. func gen502Page(req *http.Request) *http.Response { template := `no backends connected

no backends connected

Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.

` @@ -469,6 +63,347 @@ func gen502Page(req *http.Request) *http.Response { return resp } +// ServerConfig ... +type ServerConfig struct { + SmuxConf *smux.Config + Storage Storage +} + +// Storage is the minimal subset of features that tun2's Server needs out of a +// persistence layer. +type Storage interface { + HasToken(token string) (user string, scopes []string, err error) + HasRoute(domain string) (user string, err error) +} + +// Server routes frontend HTTP traffic to backend TCP traffic. +type Server struct { + cfg *ServerConfig + ctx context.Context + cancel context.CancelFunc + + connlock sync.Mutex + conns map[net.Conn]*Connection + + domains cmap.ConcurrentMap +} + +// NewServer creates a new Server instance with a given config, acquiring all +// relevant resources. +func NewServer(cfg *ServerConfig) (*Server, error) { + if cfg == nil { + return nil, errors.New("tun2: config must be specified") + } + + if cfg.SmuxConf == nil { + cfg.SmuxConf = smux.DefaultConfig() + + cfg.SmuxConf.KeepAliveInterval = time.Second + cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second + } + + ctx, cancel := context.WithCancel(context.Background()) + + server := &Server{ + cfg: cfg, + + conns: map[net.Conn]*Connection{}, + domains: cmap.New(), + ctx: ctx, + cancel: cancel, + } + + go server.phiDetectionLoop(ctx) + + return server, nil +} + +// Close stops the background tasks for this Server. +func (s *Server) Close() { + s.cancel() +} + +// Wait blocks until the server context is cancelled. +func (s *Server) Wait() { + for { + select { + case <-s.ctx.Done(): + return + } + } +} + +// Listen passes this Server a given net.Listener to accept backend connections. +func (s *Server) Listen(l net.Listener, isKCP bool) { + ctx := s.ctx + + f := ln.F{ + "listener_addr": l.Addr(), + "listener_network": l.Addr().Network(), + } + + for { + select { + case <-ctx.Done(): + return + default: + } + + conn, err := l.Accept() + if err != nil { + ln.Error(ctx, err, f, ln.Action("accept connection")) + continue + } + + ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{ + "conn_addr": conn.RemoteAddr(), + "conn_network": conn.RemoteAddr().Network(), + }) + + go s.HandleConn(conn, isKCP) + } +} + +// phiDetectionLoop is an infinite loop that will run the [phi accrual failure detector] +// for each of the backends connected to the Server. This is fairly experimental and +// may be removed. +// +// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784 +func (s *Server) phiDetectionLoop(ctx context.Context) { + t := time.NewTicker(5 * time.Second) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + now := time.Now() + + s.connlock.Lock() + for _, c := range s.conns { + failureChance := c.detector.Phi(now) + const thresh = 0.9 // the threshold for phi failure detection causing logs + + if failureChance > thresh { + ln.Log(ctx, c, ln.Action("phi failure detection"), ln.F{ + "value": failureChance, + "threshold": thresh, + }) + } + } + s.connlock.Unlock() + } + } +} + +// backendAuthv1 runs a simple backend authentication check. It expects the +// client to write a json-encoded instance of Auth. This is then checked +// for token validity and domain matching. +// +// This returns the user that was authenticated and the domain they identified +// with. +func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) { + f := ln.F{ + "action": "backend authentication", + "backend_auth_version": 1, + } + + f["stage"] = "json decoding" + + d := json.NewDecoder(st) + var auth Auth + err := d.Decode(&auth) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["auth_domain"] = auth.Domain + f["stage"] = "checking domain" + + routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["route_user"] = routeUser + f["stage"] = "checking token" + + tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) + if err != nil { + ln.Error(ctx, err, f) + return "", nil, err + } + + f["token_user"] = tokenUser + f["stage"] = "checking token scopes" + + ok := false + for _, sc := range scopes { + if sc == "connect" { + ok = true + break + } + } + + if !ok { + ln.Error(ctx, ErrAuthMismatch, f) + return "", nil, ErrAuthMismatch + } + + f["stage"] = "user verification" + + if routeUser != tokenUser { + ln.Error(ctx, ErrAuthMismatch, f) + return "", nil, ErrAuthMismatch + } + + return routeUser, &auth, nil +} + +// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from +// the currently connected backend. +func (s *Server) HandleConn(c net.Conn, isKCP bool) { + defer c.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f := ln.F{ + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + } + + session, err := smux.Server(c, s.cfg.SmuxConf) + if err != nil { + ln.Error(ctx, err, f, ln.Action("establish server side of smux")) + + return + } + defer session.Close() + + controlStream, err := session.OpenStream() + if err != nil { + ln.Error(ctx, err, f, ln.Action("opening control stream")) + + return + } + defer controlStream.Close() + + user, auth, err := s.backendAuthv1(ctx, controlStream) + if err != nil { + return + } + + connection := &Connection{ + id: ulid.New().String(), + conn: c, + isKCP: isKCP, + session: session, + user: user, + domain: auth.Domain, + cf: cancel, + detector: failure.New(15, 1), + Auth: auth, + } + + defer func() { + if r := recover(); r != nil { + ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r}) + } + }() + + ln.Log(ctx, connection, ln.Action("backend successfully connected")) + + s.addConn(ctx, connection) + + connection.usable = true // XXX set this to true once health checks pass? + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := connection.Ping() + if err != nil { + connection.cancel() + } + // case <-s.ctx.Done(): + // ln.Log(ctx, connection, ln.Action("server context finished")) + // s.removeConn(ctx, connection) + // connection.Close() + + // return + case <-ctx.Done(): + ln.Log(ctx, connection, ln.Action("client context finished")) + s.removeConn(ctx, connection) + connection.Close() + + return + } + } +} + +// addConn adds a connection to the pool of backend connections. +func (s *Server) addConn(ctx context.Context, connection *Connection) { + s.connlock.Lock() + s.conns[connection.conn] = connection + s.connlock.Unlock() + + var conns []*Connection + + val, ok := s.domains.Get(connection.domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + conns = nil + + s.domains.Remove(connection.domain) + } + } + + conns = append(conns, connection) + + s.domains.Set(connection.domain, conns) +} + +// removeConn removes a connection from pool of backend connections. +func (s *Server) removeConn(ctx context.Context, connection *Connection) { + s.connlock.Lock() + delete(s.conns, connection.conn) + s.connlock.Unlock() + + auth := connection.Auth + + var conns []*Connection + + val, ok := s.domains.Get(auth.Domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal")) + + return + } + } + + for i, cntn := range conns { + if cntn.id == connection.id { + conns[i] = conns[len(conns)-1] + conns = conns[:len(conns)-1] + } + } + + if len(conns) != 0 { + s.domains.Set(auth.Domain, conns) + } else { + s.domains.Remove(auth.Domain) + } +} + // RoundTrip sends a HTTP request to a backend and then returns its response. func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection diff --git a/internal/tun2/server_test.go b/internal/tun2/server_test.go new file mode 100644 index 0000000..00a3464 --- /dev/null +++ b/internal/tun2/server_test.go @@ -0,0 +1,324 @@ +package tun2 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/Xe/uuid" +) + +// testing constants +const ( + user = "shachi" + token = "orcaz r kewl" + noPermToken = "aw heck" + otherUserToken = "even more heck" + domain = "cetacean.club" +) + +func TestNewServerNullConfig(t *testing.T) { + _, err := NewServer(nil) + if err == nil { + t.Fatalf("expected NewServer(nil) to fail, got non-failure") + } +} + +func TestGen502Page(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequest("GET", "http://cetacean.club", nil) + if err != nil { + t.Fatal(err) + } + + substring := uuid.New() + + req = req.WithContext(ctx) + req.Header.Add("X-Request-Id", substring) + req.Host = "cetacean.club" + + resp := gen502Page(req) + if resp == nil { + t.Fatalf("expected response to be non-nil") + } + + if resp.Body != nil { + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(string(data), substring) { + fmt.Println(string(data)) + t.Fatalf("502 page did not contain needed substring %q", substring) + } + } +} + +func TestBackendAuthV1(t *testing.T) { + st := MockStorage() + + s, err := NewServer(&ServerConfig{ + Storage: st, + }) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + st.AddRoute(domain, user) + st.AddToken(token, user, []string{"connect"}) + st.AddToken(noPermToken, user, nil) + st.AddToken(otherUserToken, "cadey", []string{"connect"}) + + cases := []struct { + name string + auth Auth + wantErr bool + }{ + { + name: "basic everything should work", + auth: Auth{ + Token: token, + Domain: domain, + }, + wantErr: false, + }, + { + name: "invalid domain", + auth: Auth{ + Token: token, + Domain: "aw.heck", + }, + wantErr: true, + }, + { + name: "invalid token", + auth: Auth{ + Token: "asdfwtweg", + Domain: domain, + }, + wantErr: true, + }, + { + name: "invalid token scopes", + auth: Auth{ + Token: noPermToken, + Domain: domain, + }, + wantErr: true, + }, + { + name: "user token doesn't match domain owner", + auth: Auth{ + Token: otherUserToken, + Domain: domain, + }, + wantErr: true, + }, + } + + for _, cs := range cases { + t.Run(cs.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + data, err := json.Marshal(cs.auth) + if err != nil { + t.Fatal(err) + } + + _, _, err = s.backendAuthv1(ctx, bytes.NewBuffer(data)) + + if cs.wantErr && err == nil { + t.Fatalf("auth did not err as expected") + } + + if !cs.wantErr && err != nil { + t.Fatalf("unexpected auth err: %v", err) + } + }) + } +} + +func TestBackendRouting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + st := MockStorage() + + st.AddRoute(domain, user) + st.AddToken(token, user, []string{"connect"}) + + s, err := NewServer(&ServerConfig{ + Storage: st, + }) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + go s.Listen(l, false) + + cases := []struct { + name string + should200 bool + handler http.HandlerFunc + }{ + { + name: "200 everything's okay", + should200: true, + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "HTTP 200, everything is okay :)", http.StatusOK) + }), + }, + } + + for _, cs := range cases { + t.Run(cs.name, func(t *testing.T) { + ts := httptest.NewServer(cs.handler) + defer ts.Close() + + cc := &ClientConfig{ + ConnType: "tcp", + ServerAddr: l.Addr().String(), + Token: token, + BackendURL: ts.URL, + Domain: domain, + + forceTCPClear: true, + } + + c, err := NewClient(cc) + if err != nil { + t.Fatal(err) + } + + go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up + + time.Sleep(time.Second) + + req, err := http.NewRequest("GET", "http://cetacean.club/", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := s.RoundTrip(req) + if err != nil { + t.Fatalf("error in doing round trip: %v", err) + } + + if cs.should200 && resp.StatusCode != http.StatusOK { + resp.Write(os.Stdout) + t.Fatalf("got status %d instead of StatusOK", resp.StatusCode) + } + }) + } +} + +func setupTestServer() (*Server, *mockStorage, net.Listener, error) { + st := MockStorage() + + st.AddRoute(domain, user) + st.AddToken(token, user, []string{"connect"}) + + s, err := NewServer(&ServerConfig{ + Storage: st, + }) + if err != nil { + return nil, nil, nil, err + } + defer s.Close() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, nil, err + } + + go s.Listen(l, false) + + return s, st, l, nil +} + +func BenchmarkHTTP200(b *testing.B) { + b.Skip("this benchmark doesn't work yet") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, _, l, err := setupTestServer() + if err != nil { + b.Fatal(err) + } + defer s.Close() + defer l.Close() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + + cc := &ClientConfig{ + ConnType: "tcp", + ServerAddr: l.Addr().String(), + Token: token, + BackendURL: ts.URL, + Domain: domain, + + forceTCPClear: true, + } + + c, err := NewClient(cc) + if err != nil { + b.Fatal(err) + } + + go c.Connect(ctx) // TODO: fix the client library so this ends up actually getting cleaned up + + for { + r := s.GetBackendsForDomain(domain) + if len(r) == 0 { + time.Sleep(125 * time.Millisecond) + continue + } + + break + } + + req, err := http.NewRequest("GET", "http://cetacean.club/", nil) + if err != nil { + b.Fatal(err) + } + + _, err = s.RoundTrip(req) + if err != nil { + b.Fatalf("got error on initial request exchange: %v", err) + } + + for n := 0; n < b.N; n++ { + resp, err := s.RoundTrip(req) + if err != nil { + b.Fatalf("got error on %d: %v", n, err) + } + + if resp.StatusCode != http.StatusOK { + b.Fail() + b.Logf("got %d instead of 200", resp.StatusCode) + } + } +} diff --git a/internal/tun2/storage_test.go b/internal/tun2/storage_test.go new file mode 100644 index 0000000..ed50fb8 --- /dev/null +++ b/internal/tun2/storage_test.go @@ -0,0 +1,99 @@ +package tun2 + +import ( + "errors" + "sync" + "testing" +) + +func MockStorage() *mockStorage { + return &mockStorage{ + tokens: make(map[string]mockToken), + domains: make(map[string]string), + } +} + +type mockToken struct { + user string + scopes []string +} + +// mockStorage is a simple mock of the Storage interface suitable for testing. +type mockStorage struct { + sync.Mutex + tokens map[string]mockToken + domains map[string]string +} + +func (ms *mockStorage) AddToken(token, user string, scopes []string) { + ms.Lock() + defer ms.Unlock() + + ms.tokens[token] = mockToken{user: user, scopes: scopes} +} + +func (ms *mockStorage) AddRoute(domain, user string) { + ms.Lock() + defer ms.Unlock() + + ms.domains[domain] = user +} + +func (ms *mockStorage) HasToken(token string) (string, []string, error) { + ms.Lock() + defer ms.Unlock() + + tok, ok := ms.tokens[token] + if !ok { + return "", nil, errors.New("no such token") + } + + return tok.user, tok.scopes, nil +} + +func (ms *mockStorage) HasRoute(domain string) (string, error) { + ms.Lock() + defer ms.Unlock() + + user, ok := ms.domains[domain] + if !ok { + return "", errors.New("no such route") + } + + return user, nil +} + +func TestMockStorage(t *testing.T) { + ms := MockStorage() + + t.Run("token", func(t *testing.T) { + ms.AddToken(token, user, []string{"connect"}) + + us, sc, err := ms.HasToken(token) + if err != nil { + t.Fatal(err) + } + + if us != user { + t.Fatalf("username was %q, expected %q", us, user) + } + + if sc[0] != "connect" { + t.Fatalf("token expected to only have one scope, connect") + } + }) + + t.Run("domain", func(t *testing.T) { + ms.AddRoute(domain, user) + + us, err := ms.HasRoute(domain) + if err != nil { + t.Fatal(err) + } + + if us != user { + t.Fatalf("username was %q, expected %q", us, user) + } + }) + +} diff --git a/mage.go b/mage.go index 1f7a0ae..cdddede 100644 --- a/mage.go +++ b/mage.go @@ -187,8 +187,17 @@ func Package() { } } +// Version is the version as git reports. func Version() { ver, err := gitTag() qod.ANE(err) qod.Printlnf("route-%s", ver) } + +// Test runs all of the functional and unit tests for the project. +func Test() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + shouldWork(ctx, nil, wd, "go", "test", "-v", "./...") +} diff --git a/plugins/autohttpagent/main.go b/plugins/autohttpagent/main.go index d2d8de2..c8cf89d 100644 --- a/plugins/autohttpagent/main.go +++ b/plugins/autohttpagent/main.go @@ -32,7 +32,7 @@ func mustEnv(key string, def string) string { return val } -func doHttpAgent() { +func doHTTPAgent() { go func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -47,17 +47,17 @@ func doHttpAgent() { } client, _ := tun2.NewClient(cfg) - err := client.Connect() + err := client.Connect(ctx) if err != nil { ln.Error(ctx, err, ln.Action("client connection error, restarting")) time.Sleep(5 * time.Second) - doHttpAgent() + doHTTPAgent() } }() } func init() { - doHttpAgent() + doHTTPAgent() } diff --git a/plugins/autohttpagent/main_test.go b/plugins/autohttpagent/main_test.go new file mode 100644 index 0000000..38dd16d --- /dev/null +++ b/plugins/autohttpagent/main_test.go @@ -0,0 +1,3 @@ +package main + +func main() {} diff --git a/proto/client/.#client.go b/proto/client/.#client.go deleted file mode 120000 index 81f4155..0000000 --- a/proto/client/.#client.go +++ /dev/null @@ -1 +0,0 @@ -xena@greedo.xeserv.us.17867:1486865539 \ No newline at end of file