tun2: documentation and unit tests
This commit is contained in:
parent
0a7f7a4652
commit
5afb3715cc
|
@ -1,5 +1,7 @@
|
|||
package tun2
|
||||
|
||||
import "time"
|
||||
|
||||
// Backend is the public state of an individual Connection.
|
||||
type Backend struct {
|
||||
ID string
|
||||
|
@ -10,3 +12,72 @@ type Backend struct {
|
|||
Host string
|
||||
Usable bool
|
||||
}
|
||||
|
||||
type backendMatcher func(*Connection) bool
|
||||
|
||||
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
var result []Backend
|
||||
|
||||
for _, c := range s.conns {
|
||||
if !bm(c) {
|
||||
continue
|
||||
}
|
||||
|
||||
protocol := "tcp"
|
||||
if c.isKCP {
|
||||
protocol = "kcp"
|
||||
}
|
||||
|
||||
result = append(result, Backend{
|
||||
ID: c.id,
|
||||
Proto: protocol,
|
||||
User: c.user,
|
||||
Domain: c.domain,
|
||||
Phi: float32(c.detector.Phi(time.Now())),
|
||||
Host: c.conn.RemoteAddr().String(),
|
||||
Usable: c.usable,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// KillBackend forcibly disconnects a given backend but doesn't offer a way to
|
||||
// "ban" it from reconnecting.
|
||||
func (s *Server) KillBackend(id string) error {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
for _, c := range s.conns {
|
||||
if c.id == id {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrNoSuchBackend
|
||||
}
|
||||
|
||||
// GetBackendsForDomain fetches all backends connected to this server associated
|
||||
// to a single public domain name.
|
||||
func (s *Server) GetBackendsForDomain(domain string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.domain == domain
|
||||
})
|
||||
}
|
||||
|
||||
// GetBackendsForUser fetches all backends connected to this server owned by a
|
||||
// given user by username.
|
||||
func (s *Server) GetBackendsForUser(uname string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.user == uname
|
||||
})
|
||||
}
|
||||
|
||||
// GetAllBackends fetches every backend connected to this server.
|
||||
func (s *Server) GetAllBackends() []Backend {
|
||||
return s.getBackendsForMatcher(func(*Connection) bool { return true })
|
||||
}
|
||||
|
|
|
@ -14,10 +14,14 @@ import (
|
|||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
// Client connects to a remote tun2 server and sets up authentication before routing
|
||||
// individual HTTP requests to discrete streams that are reverse proxied to the eventual
|
||||
// backend.
|
||||
type Client struct {
|
||||
cfg *ClientConfig
|
||||
}
|
||||
|
||||
// ClientConfig configures client with settings that the user provides.
|
||||
type ClientConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
ConnType string
|
||||
|
@ -27,6 +31,7 @@ type ClientConfig struct {
|
|||
BackendURL string
|
||||
}
|
||||
|
||||
// NewClient constructs an instance of Client with a given ClientConfig.
|
||||
func NewClient(cfg *ClientConfig) (*Client, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("tun2: client config needed")
|
||||
|
@ -39,6 +44,11 @@ func NewClient(cfg *ClientConfig) (*Client, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
// Connect dials the remote server and negotiates a client session with its
|
||||
// configured server address. This will then continuously proxy incoming HTTP
|
||||
// requests to the backend HTTP server.
|
||||
//
|
||||
// This is a blocking function.
|
||||
func (c *Client) Connect() error {
|
||||
return c.connect(c.cfg.ServerAddr)
|
||||
}
|
||||
|
@ -117,15 +127,12 @@ func (c *Client) connect(serverAddr string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// smuxListener wraps a smux session as a net.Listener.
|
||||
type smuxListener struct {
|
||||
conn net.Conn
|
||||
session *smux.Session
|
||||
}
|
||||
|
||||
var (
|
||||
_ net.Listener = &smuxListener{} // interface check
|
||||
)
|
||||
|
||||
func (sl *smuxListener) Accept() (net.Conn, error) {
|
||||
return sl.session.AcceptStream()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewClientNullConfig(t *testing.T) {
|
||||
_, err := NewClient(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected NewClient(nil) to fail, got non-failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmuxListenerIsNetListener(t *testing.T) {
|
||||
var sl interface{} = &smuxListener{}
|
||||
_, ok := sl.(net.Listener)
|
||||
if !ok {
|
||||
t.Fatalf("smuxListener does not implement net.Listener")
|
||||
}
|
||||
}
|
|
@ -30,6 +30,40 @@ var (
|
|||
ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
|
||||
)
|
||||
|
||||
// gen502Page creates the page that is shown when a backend is not connected to a given route.
|
||||
func gen502Page(req *http.Request) *http.Response {
|
||||
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
|
||||
|
||||
resbody := []byte(os.Expand(template, func(in string) string {
|
||||
switch in {
|
||||
case "HOST":
|
||||
return req.Host
|
||||
case "REQ_ID":
|
||||
return req.Header.Get("X-Request-Id")
|
||||
}
|
||||
|
||||
return "<unknown>"
|
||||
}))
|
||||
reshdr := req.Header
|
||||
reshdr.Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
resp := &http.Response{
|
||||
Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
|
||||
|
||||
Proto: req.Proto,
|
||||
ProtoMajor: req.ProtoMajor,
|
||||
ProtoMinor: req.ProtoMinor,
|
||||
Header: reshdr,
|
||||
ContentLength: int64(len(resbody)),
|
||||
Close: true,
|
||||
Request: req,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ServerConfig ...
|
||||
type ServerConfig struct {
|
||||
TCPAddr string
|
||||
|
@ -81,66 +115,29 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
|
|||
return server, nil
|
||||
}
|
||||
|
||||
type backendMatcher func(*Connection) bool
|
||||
// Listen passes this Server a given net.Listener to accept backend connections.
|
||||
func (s *Server) Listen(l net.Listener, isKCP bool) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
var result []Backend
|
||||
|
||||
for _, c := range s.conns {
|
||||
if !bm(c) {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"addr": l.Addr().String(),
|
||||
"network": l.Addr().Network(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
protocol := "tcp"
|
||||
if c.isKCP {
|
||||
protocol = "kcp"
|
||||
}
|
||||
|
||||
result = append(result, Backend{
|
||||
ID: c.id,
|
||||
Proto: protocol,
|
||||
User: c.user,
|
||||
Domain: c.domain,
|
||||
Phi: float32(c.detector.Phi(time.Now())),
|
||||
Host: c.conn.RemoteAddr().String(),
|
||||
Usable: c.usable,
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "new_client",
|
||||
"network": conn.RemoteAddr().Network(),
|
||||
"addr": conn.RemoteAddr(),
|
||||
"list": conn.LocalAddr(),
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
go s.HandleConn(conn, isKCP)
|
||||
}
|
||||
|
||||
func (s *Server) KillBackend(id string) error {
|
||||
s.connlock.Lock()
|
||||
defer s.connlock.Unlock()
|
||||
|
||||
for _, c := range s.conns {
|
||||
if c.id == id {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrNoSuchBackend
|
||||
}
|
||||
|
||||
func (s *Server) GetBackendsForDomain(domain string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.domain == domain
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) GetBackendsForUser(uname string) []Backend {
|
||||
return s.getBackendsForMatcher(func(c *Connection) bool {
|
||||
return c.user == uname
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) GetAllBackends() []Backend {
|
||||
return s.getBackendsForMatcher(func(*Connection) bool { return true })
|
||||
}
|
||||
|
||||
// ListenAndServe starts the backend TCP/KCP listeners and relays backend
|
||||
|
@ -248,67 +245,62 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session, err := smux.Server(c, s.cfg.SmuxConf)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "session_failure",
|
||||
f := ln.F{
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
}
|
||||
|
||||
c.Close()
|
||||
session, err := smux.Server(c, s.cfg.SmuxConf)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
|
||||
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
f["stage"] = "smux_setup"
|
||||
|
||||
controlStream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "control_stream_failure",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, err, f, ln.Action("opening control stream"))
|
||||
|
||||
return
|
||||
}
|
||||
defer controlStream.Close()
|
||||
|
||||
f["stage"] = "control_stream_open"
|
||||
|
||||
csd := json.NewDecoder(controlStream)
|
||||
auth := &Auth{}
|
||||
err = csd.Decode(auth)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "control_stream_auth_decoding_failure",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, err, f, ln.Action("decode control stream authenication message"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["stage"] = "checking_domain"
|
||||
|
||||
routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "nosuch_domain",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["route_user"] = routeUser
|
||||
f["stage"] = "checking_token"
|
||||
|
||||
tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{
|
||||
"action": "nosuch_token",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, err, f, ln.Action("no such token exists or other token error"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["token_user"] = tokenUser
|
||||
f["stage"] = "checking_token_scopes"
|
||||
|
||||
ok := false
|
||||
for _, sc := range scopes {
|
||||
if sc == "connect" {
|
||||
|
@ -318,19 +310,15 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
}
|
||||
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrAuthMismatch, ln.F{
|
||||
"action": "token_not_authorized",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f["stage"] = "user_verification"
|
||||
|
||||
if routeUser != tokenUser {
|
||||
ln.Error(ctx, ErrAuthMismatch, ln.F{
|
||||
"action": "auth_mismatch",
|
||||
"local": c.LocalAddr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch"))
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -353,10 +341,9 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
}
|
||||
}()
|
||||
|
||||
ln.Log(ctx, ln.F{
|
||||
"action": "backend_connected",
|
||||
}, connection.F())
|
||||
ln.Log(ctx, connection, ln.Action("backend successfully connected"))
|
||||
|
||||
// TODO: put these lines in a function?
|
||||
s.connlock.Lock()
|
||||
s.conns[c] = connection
|
||||
s.connlock.Unlock()
|
||||
|
@ -376,7 +363,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
conns = append(conns, connection)
|
||||
|
||||
s.domains.Set(auth.Domain, conns)
|
||||
connection.usable = true
|
||||
connection.usable = true // XXX set this to true once health checks pass?
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
@ -411,9 +398,8 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
|||
if ok {
|
||||
conns, ok = val.([]*Connection)
|
||||
if !ok {
|
||||
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
|
||||
"action": "looking_up_for_disconnect_removal",
|
||||
})
|
||||
ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -431,42 +417,7 @@ func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
|
|||
s.domains.Remove(auth.Domain)
|
||||
}
|
||||
|
||||
ln.Log(ctx, connection.F(), ln.F{
|
||||
"action": "client_disconnecting",
|
||||
})
|
||||
}
|
||||
|
||||
func gen502Page(req *http.Request) *http.Response {
|
||||
template := `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`
|
||||
|
||||
resbody := []byte(os.Expand(template, func(in string) string {
|
||||
switch in {
|
||||
case "HOST":
|
||||
return req.Host
|
||||
case "REQ_ID":
|
||||
return req.Header.Get("X-Request-Id")
|
||||
}
|
||||
|
||||
return "<unknown>"
|
||||
}))
|
||||
reshdr := req.Header
|
||||
reshdr.Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
resp := &http.Response{
|
||||
Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer(resbody)),
|
||||
|
||||
Proto: req.Proto,
|
||||
ProtoMajor: req.ProtoMajor,
|
||||
ProtoMinor: req.ProtoMinor,
|
||||
Header: reshdr,
|
||||
ContentLength: int64(len(resbody)),
|
||||
Close: true,
|
||||
Request: req,
|
||||
}
|
||||
|
||||
return resp
|
||||
ln.Log(ctx, connection, ln.Action("backend disconnect"))
|
||||
}
|
||||
|
||||
// RoundTrip sends a HTTP request to a backend and then returns its response.
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
package tun2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Xe/uuid"
|
||||
)
|
||||
|
||||
func TestNewServerNullConfig(t *testing.T) {
|
||||
_, err := NewServer(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected NewServer(nil) to fail, got non-failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGen502Page(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cetacean.club", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
substring := uuid.New()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Add("X-Request-Id", substring)
|
||||
req.Host = "cetacean.club"
|
||||
|
||||
resp := gen502Page(req)
|
||||
if resp == nil {
|
||||
t.Fatalf("expected response to be non-nil")
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), substring) {
|
||||
fmt.Println(string(data))
|
||||
t.Fatalf("502 page did not contain needed substring %q", substring)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue