From 5afb3715ccbfa40d39e51c9b2bbbe743ce269cf5 Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Tue, 3 Oct 2017 13:20:23 -0700 Subject: [PATCH] tun2: documentation and unit tests --- internal/tun2/backend.go | 71 ++++++++++++ internal/tun2/client.go | 15 ++- internal/tun2/client_test.go | 21 ++++ internal/tun2/server.go | 219 ++++++++++++++--------------------- internal/tun2/server_test.go | 53 +++++++++ 5 files changed, 241 insertions(+), 138 deletions(-) create mode 100644 internal/tun2/client_test.go create mode 100644 internal/tun2/server_test.go diff --git a/internal/tun2/backend.go b/internal/tun2/backend.go index 37af1aa..30fd2e1 100644 --- a/internal/tun2/backend.go +++ b/internal/tun2/backend.go @@ -1,5 +1,7 @@ package tun2 +import "time" + // Backend is the public state of an individual Connection. type Backend struct { ID string @@ -10,3 +12,72 @@ type Backend struct { Host string Usable bool } + +type backendMatcher func(*Connection) bool + +func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend { + s.connlock.Lock() + defer s.connlock.Unlock() + + var result []Backend + + for _, c := range s.conns { + if !bm(c) { + continue + } + + protocol := "tcp" + if c.isKCP { + protocol = "kcp" + } + + result = append(result, Backend{ + ID: c.id, + Proto: protocol, + User: c.user, + Domain: c.domain, + Phi: float32(c.detector.Phi(time.Now())), + Host: c.conn.RemoteAddr().String(), + Usable: c.usable, + }) + } + + return result +} + +// KillBackend forcibly disconnects a given backend but doesn't offer a way to +// "ban" it from reconnecting. +func (s *Server) KillBackend(id string) error { + s.connlock.Lock() + defer s.connlock.Unlock() + + for _, c := range s.conns { + if c.id == id { + c.cancel() + return nil + } + } + + return ErrNoSuchBackend +} + +// GetBackendsForDomain fetches all backends connected to this server associated +// to a single public domain name. +func (s *Server) GetBackendsForDomain(domain string) []Backend { + return s.getBackendsForMatcher(func(c *Connection) bool { + return c.domain == domain + }) +} + +// GetBackendsForUser fetches all backends connected to this server owned by a +// given user by username. +func (s *Server) GetBackendsForUser(uname string) []Backend { + return s.getBackendsForMatcher(func(c *Connection) bool { + return c.user == uname + }) +} + +// GetAllBackends fetches every backend connected to this server. +func (s *Server) GetAllBackends() []Backend { + return s.getBackendsForMatcher(func(*Connection) bool { return true }) +} diff --git a/internal/tun2/client.go b/internal/tun2/client.go index adb1c8b..8c38c5c 100644 --- a/internal/tun2/client.go +++ b/internal/tun2/client.go @@ -14,10 +14,14 @@ import ( "github.com/xtaci/smux" ) +// Client connects to a remote tun2 server and sets up authentication before routing +// individual HTTP requests to discrete streams that are reverse proxied to the eventual +// backend. type Client struct { cfg *ClientConfig } +// ClientConfig configures client with settings that the user provides. type ClientConfig struct { TLSConfig *tls.Config ConnType string @@ -27,6 +31,7 @@ type ClientConfig struct { BackendURL string } +// NewClient constructs an instance of Client with a given ClientConfig. func NewClient(cfg *ClientConfig) (*Client, error) { if cfg == nil { return nil, errors.New("tun2: client config needed") @@ -39,6 +44,11 @@ func NewClient(cfg *ClientConfig) (*Client, error) { return c, nil } +// Connect dials the remote server and negotiates a client session with its +// configured server address. This will then continuously proxy incoming HTTP +// requests to the backend HTTP server. +// +// This is a blocking function. func (c *Client) Connect() error { return c.connect(c.cfg.ServerAddr) } @@ -117,15 +127,12 @@ func (c *Client) connect(serverAddr string) error { return nil } +// smuxListener wraps a smux session as a net.Listener. type smuxListener struct { conn net.Conn session *smux.Session } -var ( - _ net.Listener = &smuxListener{} // interface check -) - func (sl *smuxListener) Accept() (net.Conn, error) { return sl.session.AcceptStream() } diff --git a/internal/tun2/client_test.go b/internal/tun2/client_test.go new file mode 100644 index 0000000..d3127a7 --- /dev/null +++ b/internal/tun2/client_test.go @@ -0,0 +1,21 @@ +package tun2 + +import ( + "net" + "testing" +) + +func TestNewClientNullConfig(t *testing.T) { + _, err := NewClient(nil) + if err == nil { + t.Fatalf("expected NewClient(nil) to fail, got non-failure") + } +} + +func TestSmuxListenerIsNetListener(t *testing.T) { + var sl interface{} = &smuxListener{} + _, ok := sl.(net.Listener) + if !ok { + t.Fatalf("smuxListener does not implement net.Listener") + } +} diff --git a/internal/tun2/server.go b/internal/tun2/server.go index 783d952..332ab4b 100644 --- a/internal/tun2/server.go +++ b/internal/tun2/server.go @@ -30,6 +30,40 @@ var ( ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it") ) +// gen502Page creates the page that is shown when a backend is not connected to a given route. +func gen502Page(req *http.Request) *http.Response { + template := `no backends connected

no backends connected

Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.

` + + resbody := []byte(os.Expand(template, func(in string) string { + switch in { + case "HOST": + return req.Host + case "REQ_ID": + return req.Header.Get("X-Request-Id") + } + + return "" + })) + reshdr := req.Header + reshdr.Set("Content-Type", "text/html; charset=utf-8") + + resp := &http.Response{ + Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway), + StatusCode: http.StatusBadGateway, + Body: ioutil.NopCloser(bytes.NewBuffer(resbody)), + + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: reshdr, + ContentLength: int64(len(resbody)), + Close: true, + Request: req, + } + + return resp +} + // ServerConfig ... type ServerConfig struct { TCPAddr string @@ -81,66 +115,29 @@ func NewServer(cfg *ServerConfig) (*Server, error) { return server, nil } -type backendMatcher func(*Connection) bool +// Listen passes this Server a given net.Listener to accept backend connections. +func (s *Server) Listen(l net.Listener, isKCP bool) { + ctx := context.Background() -func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend { - s.connlock.Lock() - defer s.connlock.Unlock() - - var result []Backend - - for _, c := range s.conns { - if !bm(c) { + for { + conn, err := l.Accept() + if err != nil { + ln.Error(ctx, err, ln.F{ + "addr": l.Addr().String(), + "network": l.Addr().Network(), + }) continue } - protocol := "tcp" - if c.isKCP { - protocol = "kcp" - } - - result = append(result, Backend{ - ID: c.id, - Proto: protocol, - User: c.user, - Domain: c.domain, - Phi: float32(c.detector.Phi(time.Now())), - Host: c.conn.RemoteAddr().String(), - Usable: c.usable, + ln.Log(ctx, ln.F{ + "action": "new_client", + "network": conn.RemoteAddr().Network(), + "addr": conn.RemoteAddr(), + "list": conn.LocalAddr(), }) + + go s.HandleConn(conn, isKCP) } - - return result -} - -func (s *Server) KillBackend(id string) error { - s.connlock.Lock() - defer s.connlock.Unlock() - - for _, c := range s.conns { - if c.id == id { - c.cancel() - return nil - } - } - - return ErrNoSuchBackend -} - -func (s *Server) GetBackendsForDomain(domain string) []Backend { - return s.getBackendsForMatcher(func(c *Connection) bool { - return c.domain == domain - }) -} - -func (s *Server) GetBackendsForUser(uname string) []Backend { - return s.getBackendsForMatcher(func(c *Connection) bool { - return c.user == uname - }) -} - -func (s *Server) GetAllBackends() []Backend { - return s.getBackendsForMatcher(func(*Connection) bool { return true }) } // ListenAndServe starts the backend TCP/KCP listeners and relays backend @@ -248,67 +245,62 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + f := ln.F{ + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + } + session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "session_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) - - c.Close() + ln.Error(ctx, err, f, ln.Action("establish server side of smux")) return } defer session.Close() + f["stage"] = "smux_setup" + controlStream, err := session.OpenStream() if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "control_stream_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, err, f, ln.Action("opening control stream")) return } defer controlStream.Close() + f["stage"] = "control_stream_open" + csd := json.NewDecoder(controlStream) auth := &Auth{} err = csd.Decode(auth) if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "control_stream_auth_decoding_failure", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, err, f, ln.Action("decode control stream authenication message")) return } + f["stage"] = "checking_domain" + routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "nosuch_domain", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth")) return } + f["route_user"] = routeUser + f["stage"] = "checking_token" + tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) if err != nil { - ln.Error(ctx, err, ln.F{ - "action": "nosuch_token", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, err, f, ln.Action("no such token exists or other token error")) return } + f["token_user"] = tokenUser + f["stage"] = "checking_token_scopes" + ok := false for _, sc := range scopes { if sc == "connect" { @@ -318,19 +310,15 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } if !ok { - ln.Error(ctx, ErrAuthMismatch, ln.F{ - "action": "token_not_authorized", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect")) + + return } + f["stage"] = "user_verification" + if routeUser != tokenUser { - ln.Error(ctx, ErrAuthMismatch, ln.F{ - "action": "auth_mismatch", - "local": c.LocalAddr().String(), - "remote": c.RemoteAddr().String(), - }) + ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch")) return } @@ -353,10 +341,9 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } }() - ln.Log(ctx, ln.F{ - "action": "backend_connected", - }, connection.F()) + ln.Log(ctx, connection, ln.Action("backend successfully connected")) + // TODO: put these lines in a function? s.connlock.Lock() s.conns[c] = connection s.connlock.Unlock() @@ -376,7 +363,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { conns = append(conns, connection) s.domains.Set(auth.Domain, conns) - connection.usable = true + connection.usable = true // XXX set this to true once health checks pass? ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -411,9 +398,8 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { if ok { conns, ok = val.([]*Connection) if !ok { - ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{ - "action": "looking_up_for_disconnect_removal", - }) + ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal")) + return } } @@ -431,42 +417,7 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { s.domains.Remove(auth.Domain) } - ln.Log(ctx, connection.F(), ln.F{ - "action": "client_disconnecting", - }) -} - -func gen502Page(req *http.Request) *http.Response { - template := `no backends connected

no backends connected

Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.

` - - resbody := []byte(os.Expand(template, func(in string) string { - switch in { - case "HOST": - return req.Host - case "REQ_ID": - return req.Header.Get("X-Request-Id") - } - - return "" - })) - reshdr := req.Header - reshdr.Set("Content-Type", "text/html; charset=utf-8") - - resp := &http.Response{ - Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway), - StatusCode: http.StatusBadGateway, - Body: ioutil.NopCloser(bytes.NewBuffer(resbody)), - - Proto: req.Proto, - ProtoMajor: req.ProtoMajor, - ProtoMinor: req.ProtoMinor, - Header: reshdr, - ContentLength: int64(len(resbody)), - Close: true, - Request: req, - } - - return resp + ln.Log(ctx, connection, ln.Action("backend disconnect")) } // RoundTrip sends a HTTP request to a backend and then returns its response. diff --git a/internal/tun2/server_test.go b/internal/tun2/server_test.go new file mode 100644 index 0000000..3f37c4e --- /dev/null +++ b/internal/tun2/server_test.go @@ -0,0 +1,53 @@ +package tun2 + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Xe/uuid" +) + +func TestNewServerNullConfig(t *testing.T) { + _, err := NewServer(nil) + if err == nil { + t.Fatalf("expected NewServer(nil) to fail, got non-failure") + } +} + +func TestGen502Page(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequest("GET", "http://cetacean.club", nil) + if err != nil { + t.Fatal(err) + } + + substring := uuid.New() + + req = req.WithContext(ctx) + req.Header.Add("X-Request-Id", substring) + req.Host = "cetacean.club" + + resp := gen502Page(req) + if resp == nil { + t.Fatalf("expected response to be non-nil") + } + + if resp.Body != nil { + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(string(data), substring) { + fmt.Println(string(data)) + t.Fatalf("502 page did not contain needed substring %q", substring) + } + } +}