tun2: documentation and unit tests

This commit is contained in:
Cadey Ratio 2017-10-03 13:20:23 -07:00
parent 0a7f7a4652
commit 5afb3715cc
5 changed files with 241 additions and 138 deletions

View File

@ -1,5 +1,7 @@
package tun2
import "time"
// Backend is the public state of an individual Connection.
type Backend struct {
ID string
@ -10,3 +12,72 @@ type Backend struct {
Host string
Usable bool
}
type backendMatcher func(*Connection) bool
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
s.connlock.Lock()
defer s.connlock.Unlock()
var result []Backend
for _, c := range s.conns {
if !bm(c) {
continue
}
protocol := "tcp"
if c.isKCP {
protocol = "kcp"
}
result = append(result, Backend{
ID: c.id,
Proto: protocol,
User: c.user,
Domain: c.domain,
Phi: float32(c.detector.Phi(time.Now())),
Host: c.conn.RemoteAddr().String(),
Usable: c.usable,
})
}
return result
}
// KillBackend forcibly disconnects a given backend but doesn't offer a way to
// "ban" it from reconnecting.
func (s *Server) KillBackend(id string) error {
s.connlock.Lock()
defer s.connlock.Unlock()
for _, c := range s.conns {
if c.id == id {
c.cancel()
return nil
}
}
return ErrNoSuchBackend
}
// GetBackendsForDomain fetches all backends connected to this server associated
// to a single public domain name.
func (s *Server) GetBackendsForDomain(domain string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.domain == domain
})
}
// GetBackendsForUser fetches all backends connected to this server owned by a
// given user by username.
func (s *Server) GetBackendsForUser(uname string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.user == uname
})
}
// GetAllBackends fetches every backend connected to this server.
func (s *Server) GetAllBackends() []Backend {
return s.getBackendsForMatcher(func(*Connection) bool { return true })
}

View File

@ -14,10 +14,14 @@ import (
"github.com/xtaci/smux"
)
// Client connects to a remote tun2 server and sets up authentication before routing
// individual HTTP requests to discrete streams that are reverse proxied to the eventual
// backend.
type Client struct {
cfg *ClientConfig
}
// ClientConfig configures client with settings that the user provides.
type ClientConfig struct {
TLSConfig *tls.Config
ConnType string
@ -27,6 +31,7 @@ type ClientConfig struct {
BackendURL string
}
// NewClient constructs an instance of Client with a given ClientConfig.
func NewClient(cfg *ClientConfig) (*Client, error) {
if cfg == nil {
return nil, errors.New("tun2: client config needed")
@ -39,6 +44,11 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
return c, nil
}
// Connect dials the remote server and negotiates a client session with its
// configured server address. This will then continuously proxy incoming HTTP
// requests to the backend HTTP server.
//
// This is a blocking function.
func (c *Client) Connect() error {
return c.connect(c.cfg.ServerAddr)
}
@ -117,15 +127,12 @@ func (c *Client) connect(serverAddr string) error {
return nil
}
// smuxListener wraps a smux session as a net.Listener.
type smuxListener struct {
conn net.Conn
session *smux.Session
}
var (
_ net.Listener = &smuxListener{} // interface check
)
func (sl *smuxListener) Accept() (net.Conn, error) {
return sl.session.AcceptStream()
}

View File

@ -0,0 +1,21 @@
package tun2
import (
"net"
"testing"
)
func TestNewClientNullConfig(t *testing.T) {
_, err := NewClient(nil)
if err == nil {
t.Fatalf("expected NewClient(nil) to fail, got non-failure")
}
}
func TestSmuxListenerIsNetListener(t *testing.T) {
var sl interface{} = &smuxListener{}
_, ok := sl.(net.Listener)
if !ok {
t.Fatalf("smuxListener does not implement net.Listener")
}
}

View File

@ -30,6 +30,40 @@ var (
ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
)
// gen502Page creates the page that is shown when a backend is not connected to a given route.
func gen502Page(req *http.Request) *http.Response {
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
resbody := []byte(os.Expand(template, func(in string) string {
switch in {
case "HOST":
return req.Host
case "REQ_ID":
return req.Header.Get("X-Request-Id")
}
return "<unknown>"
}))
reshdr := req.Header
reshdr.Set("Content-Type", "text/html; charset=utf-8")
resp := &http.Response{
Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: reshdr,
ContentLength: int64(len(resbody)),
Close: true,
Request: req,
}
return resp
}
// ServerConfig ...
type ServerConfig struct {
TCPAddr string
@ -81,66 +115,29 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
return server, nil
}
type backendMatcher func(*Connection) bool
// Listen passes this Server a given net.Listener to accept backend connections.
func (s *Server) Listen(l net.Listener, isKCP bool) {
ctx := context.Background()
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
s.connlock.Lock()
defer s.connlock.Unlock()
var result []Backend
for _, c := range s.conns {
if !bm(c) {
for {
conn, err := l.Accept()
if err != nil {
ln.Error(ctx, err, ln.F{
"addr": l.Addr().String(),
"network": l.Addr().Network(),
})
continue
}
protocol := "tcp"
if c.isKCP {
protocol = "kcp"
}
result = append(result, Backend{
ID: c.id,
Proto: protocol,
User: c.user,
Domain: c.domain,
Phi: float32(c.detector.Phi(time.Now())),
Host: c.conn.RemoteAddr().String(),
Usable: c.usable,
ln.Log(ctx, ln.F{
"action": "new_client",
"network": conn.RemoteAddr().Network(),
"addr": conn.RemoteAddr(),
"list": conn.LocalAddr(),
})
go s.HandleConn(conn, isKCP)
}
return result
}
func (s *Server) KillBackend(id string) error {
s.connlock.Lock()
defer s.connlock.Unlock()
for _, c := range s.conns {
if c.id == id {
c.cancel()
return nil
}
}
return ErrNoSuchBackend
}
func (s *Server) GetBackendsForDomain(domain string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.domain == domain
})
}
func (s *Server) GetBackendsForUser(uname string) []Backend {
return s.getBackendsForMatcher(func(c *Connection) bool {
return c.user == uname
})
}
func (s *Server) GetAllBackends() []Backend {
return s.getBackendsForMatcher(func(*Connection) bool { return true })
}
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
@ -248,67 +245,62 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
f := ln.F{
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
}
session, err := smux.Server(c, s.cfg.SmuxConf)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "session_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
c.Close()
ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
return
}
defer session.Close()
f["stage"] = "smux_setup"
controlStream, err := session.OpenStream()
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "control_stream_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, err, f, ln.Action("opening control stream"))
return
}
defer controlStream.Close()
f["stage"] = "control_stream_open"
csd := json.NewDecoder(controlStream)
auth := &Auth{}
err = csd.Decode(auth)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "control_stream_auth_decoding_failure",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, err, f, ln.Action("decode control stream authenication message"))
return
}
f["stage"] = "checking_domain"
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "nosuch_domain",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth"))
return
}
f["route_user"] = routeUser
f["stage"] = "checking_token"
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
if err != nil {
ln.Error(ctx, err, ln.F{
"action": "nosuch_token",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, err, f, ln.Action("no such token exists or other token error"))
return
}
f["token_user"] = tokenUser
f["stage"] = "checking_token_scopes"
ok := false
for _, sc := range scopes {
if sc == "connect" {
@ -318,19 +310,15 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
}
if !ok {
ln.Error(ctx, ErrAuthMismatch, ln.F{
"action": "token_not_authorized",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect"))
return
}
f["stage"] = "user_verification"
if routeUser != tokenUser {
ln.Error(ctx, ErrAuthMismatch, ln.F{
"action": "auth_mismatch",
"local": c.LocalAddr().String(),
"remote": c.RemoteAddr().String(),
})
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch"))
return
}
@ -353,10 +341,9 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
}
}()
ln.Log(ctx, ln.F{
"action": "backend_connected",
}, connection.F())
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
// TODO: put these lines in a function?
s.connlock.Lock()
s.conns[c] = connection
s.connlock.Unlock()
@ -376,7 +363,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
conns = append(conns, connection)
s.domains.Set(auth.Domain, conns)
connection.usable = true
connection.usable = true // XXX set this to true once health checks pass?
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@ -411,9 +398,8 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
if ok {
conns, ok = val.([]*Connection)
if !ok {
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
"action": "looking_up_for_disconnect_removal",
})
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
return
}
}
@ -431,42 +417,7 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
s.domains.Remove(auth.Domain)
}
ln.Log(ctx, connection.F(), ln.F{
"action": "client_disconnecting",
})
}
func gen502Page(req *http.Request) *http.Response {
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
resbody := []byte(os.Expand(template, func(in string) string {
switch in {
case "HOST":
return req.Host
case "REQ_ID":
return req.Header.Get("X-Request-Id")
}
return "<unknown>"
}))
reshdr := req.Header
reshdr.Set("Content-Type", "text/html; charset=utf-8")
resp := &http.Response{
Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: reshdr,
ContentLength: int64(len(resbody)),
Close: true,
Request: req,
}
return resp
ln.Log(ctx, connection, ln.Action("backend disconnect"))
}
// RoundTrip sends a HTTP request to a backend and then returns its response.

View File

@ -0,0 +1,53 @@
package tun2
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/Xe/uuid"
)
func TestNewServerNullConfig(t *testing.T) {
_, err := NewServer(nil)
if err == nil {
t.Fatalf("expected NewServer(nil) to fail, got non-failure")
}
}
func TestGen502Page(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req, err := http.NewRequest("GET", "http://cetacean.club", nil)
if err != nil {
t.Fatal(err)
}
substring := uuid.New()
req = req.WithContext(ctx)
req.Header.Add("X-Request-Id", substring)
req.Host = "cetacean.club"
resp := gen502Page(req)
if resp == nil {
t.Fatalf("expected response to be non-nil")
}
if resp.Body != nil {
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(data), substring) {
fmt.Println(string(data))
t.Fatalf("502 page did not contain needed substring %q", substring)
}
}
}