diff --git a/internal/tun2/backend.go b/internal/tun2/backend.go
index 37af1aa..30fd2e1 100644
--- a/internal/tun2/backend.go
+++ b/internal/tun2/backend.go
@@ -1,5 +1,7 @@
package tun2
+import "time"
+
// Backend is the public state of an individual Connection.
type Backend struct {
ID string
@@ -10,3 +12,72 @@ type Backend struct {
Host string
Usable bool
}
+
+type backendMatcher func(*Connection) bool
+
+func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
+ s.connlock.Lock()
+ defer s.connlock.Unlock()
+
+ var result []Backend
+
+ for _, c := range s.conns {
+ if !bm(c) {
+ continue
+ }
+
+ protocol := "tcp"
+ if c.isKCP {
+ protocol = "kcp"
+ }
+
+ result = append(result, Backend{
+ ID: c.id,
+ Proto: protocol,
+ User: c.user,
+ Domain: c.domain,
+ Phi: float32(c.detector.Phi(time.Now())),
+ Host: c.conn.RemoteAddr().String(),
+ Usable: c.usable,
+ })
+ }
+
+ return result
+}
+
+// KillBackend forcibly disconnects a given backend but doesn't offer a way to
+// "ban" it from reconnecting.
+func (s *Server) KillBackend(id string) error {
+ s.connlock.Lock()
+ defer s.connlock.Unlock()
+
+ for _, c := range s.conns {
+ if c.id == id {
+ c.cancel()
+ return nil
+ }
+ }
+
+ return ErrNoSuchBackend
+}
+
+// GetBackendsForDomain fetches all backends connected to this server associated
+// to a single public domain name.
+func (s *Server) GetBackendsForDomain(domain string) []Backend {
+ return s.getBackendsForMatcher(func(c *Connection) bool {
+ return c.domain == domain
+ })
+}
+
+// GetBackendsForUser fetches all backends connected to this server owned by a
+// given user by username.
+func (s *Server) GetBackendsForUser(uname string) []Backend {
+ return s.getBackendsForMatcher(func(c *Connection) bool {
+ return c.user == uname
+ })
+}
+
+// GetAllBackends fetches every backend connected to this server.
+func (s *Server) GetAllBackends() []Backend {
+ return s.getBackendsForMatcher(func(*Connection) bool { return true })
+}
diff --git a/internal/tun2/client.go b/internal/tun2/client.go
index adb1c8b..8c38c5c 100644
--- a/internal/tun2/client.go
+++ b/internal/tun2/client.go
@@ -14,10 +14,14 @@ import (
"github.com/xtaci/smux"
)
+// Client connects to a remote tun2 server and sets up authentication before routing
+// individual HTTP requests to discrete streams that are reverse proxied to the eventual
+// backend.
type Client struct {
cfg *ClientConfig
}
+// ClientConfig configures client with settings that the user provides.
type ClientConfig struct {
TLSConfig *tls.Config
ConnType string
@@ -27,6 +31,7 @@ type ClientConfig struct {
BackendURL string
}
+// NewClient constructs an instance of Client with a given ClientConfig.
func NewClient(cfg *ClientConfig) (*Client, error) {
if cfg == nil {
return nil, errors.New("tun2: client config needed")
@@ -39,6 +44,11 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
return c, nil
}
+// Connect dials the remote server and negotiates a client session with its
+// configured server address. This will then continuously proxy incoming HTTP
+// requests to the backend HTTP server.
+//
+// This is a blocking function.
func (c *Client) Connect() error {
return c.connect(c.cfg.ServerAddr)
}
@@ -117,15 +127,12 @@ func (c *Client) connect(serverAddr string) error {
return nil
}
+// smuxListener wraps a smux session as a net.Listener.
type smuxListener struct {
conn net.Conn
session *smux.Session
}
-var (
- _ net.Listener = &smuxListener{} // interface check
-)
-
func (sl *smuxListener) Accept() (net.Conn, error) {
return sl.session.AcceptStream()
}
diff --git a/internal/tun2/client_test.go b/internal/tun2/client_test.go
new file mode 100644
index 0000000..d3127a7
--- /dev/null
+++ b/internal/tun2/client_test.go
@@ -0,0 +1,21 @@
+package tun2
+
+import (
+ "net"
+ "testing"
+)
+
+func TestNewClientNullConfig(t *testing.T) {
+ _, err := NewClient(nil)
+ if err == nil {
+ t.Fatalf("expected NewClient(nil) to fail, got non-failure")
+ }
+}
+
+func TestSmuxListenerIsNetListener(t *testing.T) {
+ var sl interface{} = &smuxListener{}
+ _, ok := sl.(net.Listener)
+ if !ok {
+ t.Fatalf("smuxListener does not implement net.Listener")
+ }
+}
diff --git a/internal/tun2/server.go b/internal/tun2/server.go
index 783d952..332ab4b 100644
--- a/internal/tun2/server.go
+++ b/internal/tun2/server.go
@@ -30,6 +30,40 @@ var (
ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
)
+// gen502Page creates the page that is shown when a backend is not connected to a given route.
+func gen502Page(req *http.Request) *http.Response {
+ template := `
no backends connectedno backends connected
Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.
`
+
+ resbody := []byte(os.Expand(template, func(in string) string {
+ switch in {
+ case "HOST":
+ return req.Host
+ case "REQ_ID":
+ return req.Header.Get("X-Request-Id")
+ }
+
+ return ""
+ }))
+ reshdr := req.Header
+ reshdr.Set("Content-Type", "text/html; charset=utf-8")
+
+ resp := &http.Response{
+ Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
+ StatusCode: http.StatusBadGateway,
+ Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
+
+ Proto: req.Proto,
+ ProtoMajor: req.ProtoMajor,
+ ProtoMinor: req.ProtoMinor,
+ Header: reshdr,
+ ContentLength: int64(len(resbody)),
+ Close: true,
+ Request: req,
+ }
+
+ return resp
+}
+
// ServerConfig ...
type ServerConfig struct {
TCPAddr string
@@ -81,66 +115,29 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
return server, nil
}
-type backendMatcher func(*Connection) bool
+// Listen passes this Server a given net.Listener to accept backend connections.
+func (s *Server) Listen(l net.Listener, isKCP bool) {
+ ctx := context.Background()
-func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
- s.connlock.Lock()
- defer s.connlock.Unlock()
-
- var result []Backend
-
- for _, c := range s.conns {
- if !bm(c) {
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ ln.Error(ctx, err, ln.F{
+ "addr": l.Addr().String(),
+ "network": l.Addr().Network(),
+ })
continue
}
- protocol := "tcp"
- if c.isKCP {
- protocol = "kcp"
- }
-
- result = append(result, Backend{
- ID: c.id,
- Proto: protocol,
- User: c.user,
- Domain: c.domain,
- Phi: float32(c.detector.Phi(time.Now())),
- Host: c.conn.RemoteAddr().String(),
- Usable: c.usable,
+ ln.Log(ctx, ln.F{
+ "action": "new_client",
+ "network": conn.RemoteAddr().Network(),
+ "addr": conn.RemoteAddr(),
+ "list": conn.LocalAddr(),
})
+
+ go s.HandleConn(conn, isKCP)
}
-
- return result
-}
-
-func (s *Server) KillBackend(id string) error {
- s.connlock.Lock()
- defer s.connlock.Unlock()
-
- for _, c := range s.conns {
- if c.id == id {
- c.cancel()
- return nil
- }
- }
-
- return ErrNoSuchBackend
-}
-
-func (s *Server) GetBackendsForDomain(domain string) []Backend {
- return s.getBackendsForMatcher(func(c *Connection) bool {
- return c.domain == domain
- })
-}
-
-func (s *Server) GetBackendsForUser(uname string) []Backend {
- return s.getBackendsForMatcher(func(c *Connection) bool {
- return c.user == uname
- })
-}
-
-func (s *Server) GetAllBackends() []Backend {
- return s.getBackendsForMatcher(func(*Connection) bool { return true })
}
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
@@ -248,67 +245,62 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ f := ln.F{
+ "local": c.LocalAddr().String(),
+ "remote": c.RemoteAddr().String(),
+ }
+
session, err := smux.Server(c, s.cfg.SmuxConf)
if err != nil {
- ln.Error(ctx, err, ln.F{
- "action": "session_failure",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
-
- c.Close()
+ ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
return
}
defer session.Close()
+ f["stage"] = "smux_setup"
+
controlStream, err := session.OpenStream()
if err != nil {
- ln.Error(ctx, err, ln.F{
- "action": "control_stream_failure",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, err, f, ln.Action("opening control stream"))
return
}
defer controlStream.Close()
+ f["stage"] = "control_stream_open"
+
csd := json.NewDecoder(controlStream)
auth := &Auth{}
err = csd.Decode(auth)
if err != nil {
- ln.Error(ctx, err, ln.F{
- "action": "control_stream_auth_decoding_failure",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, err, f, ln.Action("decode control stream authenication message"))
return
}
+ f["stage"] = "checking_domain"
+
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
if err != nil {
- ln.Error(ctx, err, ln.F{
- "action": "nosuch_domain",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth"))
return
}
+ f["route_user"] = routeUser
+ f["stage"] = "checking_token"
+
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
if err != nil {
- ln.Error(ctx, err, ln.F{
- "action": "nosuch_token",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, err, f, ln.Action("no such token exists or other token error"))
return
}
+ f["token_user"] = tokenUser
+ f["stage"] = "checking_token_scopes"
+
ok := false
for _, sc := range scopes {
if sc == "connect" {
@@ -318,19 +310,15 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
}
if !ok {
- ln.Error(ctx, ErrAuthMismatch, ln.F{
- "action": "token_not_authorized",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect"))
+
+ return
}
+ f["stage"] = "user_verification"
+
if routeUser != tokenUser {
- ln.Error(ctx, ErrAuthMismatch, ln.F{
- "action": "auth_mismatch",
- "local": c.LocalAddr().String(),
- "remote": c.RemoteAddr().String(),
- })
+ ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch"))
return
}
@@ -353,10 +341,9 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
}
}()
- ln.Log(ctx, ln.F{
- "action": "backend_connected",
- }, connection.F())
+ ln.Log(ctx, connection, ln.Action("backend successfully connected"))
+ // TODO: put these lines in a function?
s.connlock.Lock()
s.conns[c] = connection
s.connlock.Unlock()
@@ -376,7 +363,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
conns = append(conns, connection)
s.domains.Set(auth.Domain, conns)
- connection.usable = true
+ connection.usable = true // XXX set this to true once health checks pass?
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@@ -411,9 +398,8 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
if ok {
conns, ok = val.([]*Connection)
if !ok {
- ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
- "action": "looking_up_for_disconnect_removal",
- })
+ ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
+
return
}
}
@@ -431,42 +417,7 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
s.domains.Remove(auth.Domain)
}
- ln.Log(ctx, connection.F(), ln.F{
- "action": "client_disconnecting",
- })
-}
-
-func gen502Page(req *http.Request) *http.Response {
- template := `no backends connectedno backends connected
Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.
`
-
- resbody := []byte(os.Expand(template, func(in string) string {
- switch in {
- case "HOST":
- return req.Host
- case "REQ_ID":
- return req.Header.Get("X-Request-Id")
- }
-
- return ""
- }))
- reshdr := req.Header
- reshdr.Set("Content-Type", "text/html; charset=utf-8")
-
- resp := &http.Response{
- Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
- StatusCode: http.StatusBadGateway,
- Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
-
- Proto: req.Proto,
- ProtoMajor: req.ProtoMajor,
- ProtoMinor: req.ProtoMinor,
- Header: reshdr,
- ContentLength: int64(len(resbody)),
- Close: true,
- Request: req,
- }
-
- return resp
+ ln.Log(ctx, connection, ln.Action("backend disconnect"))
}
// RoundTrip sends a HTTP request to a backend and then returns its response.
diff --git a/internal/tun2/server_test.go b/internal/tun2/server_test.go
new file mode 100644
index 0000000..3f37c4e
--- /dev/null
+++ b/internal/tun2/server_test.go
@@ -0,0 +1,53 @@
+package tun2
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/Xe/uuid"
+)
+
+func TestNewServerNullConfig(t *testing.T) {
+ _, err := NewServer(nil)
+ if err == nil {
+ t.Fatalf("expected NewServer(nil) to fail, got non-failure")
+ }
+}
+
+func TestGen502Page(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ req, err := http.NewRequest("GET", "http://cetacean.club", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ substring := uuid.New()
+
+ req = req.WithContext(ctx)
+ req.Header.Add("X-Request-Id", substring)
+ req.Host = "cetacean.club"
+
+ resp := gen502Page(req)
+ if resp == nil {
+ t.Fatalf("expected response to be non-nil")
+ }
+
+ if resp.Body != nil {
+ defer resp.Body.Close()
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !strings.Contains(string(data), substring) {
+ fmt.Println(string(data))
+ t.Fatalf("502 page did not contain needed substring %q", substring)
+ }
+ }
+}