package tun2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"expvar"
	"fmt"
	"html"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/Xe/ln"
	failure "github.com/dgryski/go-failure"
	"github.com/mtneug/pkg/ulid"
	cmap "github.com/streamrail/concurrent-map"
	"github.com/xtaci/smux"
)
// Error values
var (
	// ErrNoSuchBackend is logged by RoundTrip when no usable backend is
	// registered for the requested host.
	ErrNoSuchBackend = errors.New("tun2: there is no such backend")

	// ErrAuthMismatch is returned by backend authentication when the token
	// lacks the "connect" scope or belongs to a different user than the route.
	ErrAuthMismatch = errors.New("tun2: authentication doesn't match database records")

	// ErrCantRemoveWhatDoesntExist is logged when a connection that is being
	// removed cannot be found in the domain table.
	ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
)
// gen502Page creates the page that is shown when a backend is not connected to a given route.
//
// The returned response has status 502, a text/html body naming the requested
// host and the request's X-Request-Id, and Close set so the client connection
// is not reused.
//
// The response carries its own header map: only Content-Type and (when
// present) X-Request-Id are set. The request's headers are neither modified
// nor echoed back, so values such as Cookie or Authorization never leak into
// the error page response, and a request with a nil Header does not panic.
func gen502Page(req *http.Request) *http.Response {
	template := `
no backends connectedno backends connected
Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.
`

	requestID := req.Header.Get("X-Request-Id")

	// Host and X-Request-Id are client-controlled and end up in an HTML
	// document, so they must be escaped before substitution.
	resbody := []byte(os.Expand(template, func(in string) string {
		switch in {
		case "HOST":
			return html.EscapeString(req.Host)
		case "REQ_ID":
			return html.EscapeString(requestID)
		}
		return ""
	}))

	reshdr := make(http.Header)
	reshdr.Set("Content-Type", "text/html; charset=utf-8")
	if requestID != "" {
		// Keep the request ID on the response so it can still be correlated.
		reshdr.Set("X-Request-Id", requestID)
	}

	return &http.Response{
		Status:        fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)),
		StatusCode:    http.StatusBadGateway,
		Body:          ioutil.NopCloser(bytes.NewReader(resbody)),
		Proto:         req.Proto,
		ProtoMajor:    req.ProtoMajor,
		ProtoMinor:    req.ProtoMinor,
		Header:        reshdr,
		ContentLength: int64(len(resbody)),
		Close:         true,
		Request:       req,
	}
}
// ServerConfig holds the settings a Server is created with.
type ServerConfig struct {
// SmuxConf is the smux configuration used for every backend session.
// When nil, NewServer replaces it with smux.DefaultConfig() adjusted to a
// 1 second keepalive interval and a 15 second keepalive timeout.
SmuxConf *smux.Config
// Storage is queried during backend authentication to resolve routes and
// tokens.
Storage Storage
}
// Storage is the minimal subset of features that tun2's Server needs out of a
// persistence layer.
type Storage interface {
// HasToken looks up token and returns the user it belongs to and the
// scopes attached to it. Backend authentication requires the "connect"
// scope to be present.
HasToken(token string) (user string, scopes []string, err error)
// HasRoute looks up domain and returns the user associated with that
// route. Backend authentication requires it to equal the token's user.
HasRoute(domain string) (user string, err error)
}
// Server routes frontend HTTP traffic to backend TCP traffic.
//
// Backends connect through Listen/HandleConn and are registered per domain;
// RoundTrip picks one of the usable backends registered for a request's Host.
type Server struct {
// cfg is the configuration handed to NewServer.
cfg *ServerConfig
// ctx is the server-wide context; cancel (invoked by Close) cancels it.
ctx context.Context
cancel context.CancelFunc
// connlock guards conns.
connlock sync.Mutex
// conns holds every authenticated backend connection, keyed by its
// underlying net.Conn.
conns map[net.Conn]*Connection
// domains maps a domain name to the []*Connection registered for it.
domains cmap.ConcurrentMap
}
// NewServer builds a Server from cfg and starts its background phi failure
// detection loop, which runs until Close is called.
//
// A nil cfg is rejected with an error. When cfg.SmuxConf is nil it is filled
// in (on the caller's cfg) with smux defaults using a 1 second keepalive
// interval and a 15 second keepalive timeout.
func NewServer(cfg *ServerConfig) (*Server, error) {
	if cfg == nil {
		return nil, errors.New("tun2: config must be specified")
	}

	if cfg.SmuxConf == nil {
		smuxConf := smux.DefaultConfig()
		smuxConf.KeepAliveInterval = time.Second
		smuxConf.KeepAliveTimeout = 15 * time.Second
		cfg.SmuxConf = smuxConf
	}

	srv := new(Server)
	srv.cfg = cfg
	srv.conns = make(map[net.Conn]*Connection)
	srv.domains = cmap.New()
	srv.ctx, srv.cancel = context.WithCancel(context.Background())

	go srv.phiDetectionLoop(srv.ctx)

	return srv, nil
}
// Close stops the background tasks for this Server by cancelling its context.
//
// This ends the phi detection loop, releases callers blocked in Wait, and
// makes Listen return the next time it checks the context (that is, once its
// pending Accept call has returned). It does not close any listener, and
// backend connections handled by HandleConn run on their own context, so
// they are not torn down by Close.
func (s *Server) Close() {
s.cancel()
}
// Wait blocks until the server context is cancelled, which happens when Close
// is called.
func (s *Server) Wait() {
	// A plain receive is all that is needed: Done's channel is closed on
	// cancellation and there is nothing else to select on.
	<-s.ctx.Done()
}
// Listen passes this Server a given net.Listener to accept backend connections.
//
// Every accepted connection is handed to HandleConn on its own goroutine;
// isKCP is forwarded unchanged. Listen blocks until the server context is
// cancelled (see Close). Cancellation is only noticed between Accept calls,
// so callers that need a prompt shutdown should also close l.
//
// Accept errors are logged and retried. Retries are delayed with an
// exponential backoff (5ms doubling up to 1s) so that a persistently failing
// listener, for example one that has been closed, cannot spin the loop and
// flood the log. The backoff resets after the next successful Accept.
func (s *Server) Listen(l net.Listener, isKCP bool) {
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)

	ctx := s.ctx

	f := ln.F{
		"listener_addr":    l.Addr(),
		"listener_network": l.Addr().Network(),
	}

	var backoff time.Duration

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		conn, err := l.Accept()
		if err != nil {
			ln.Error(ctx, err, f, ln.Action("accept connection"))

			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}

			// Wait before retrying, but give up immediately on shutdown.
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		ln.Log(ctx, f, ln.Action("new backend client connected"), ln.F{
			"conn_addr":    conn.RemoteAddr(),
			"conn_network": conn.RemoteAddr().Network(),
		})

		go s.HandleConn(conn, isKCP)
	}
}
// phiDetectionLoop periodically runs the [phi accrual failure detector] of
// every backend connected to the Server and logs the backends whose phi value
// exceeds a fixed threshold. It only logs; it never disconnects a backend.
// The loop wakes every 5 seconds and exits when ctx is cancelled. This is
// fairly experimental and may be removed.
//
// [phi accrual failure detector]: https://dspace.jaist.ac.jp/dspace/handle/10119/4784
func (s *Server) phiDetectionLoop(ctx context.Context) {
	// phiLogThreshold is the phi value above which a backend is logged.
	const phiLogThreshold = 0.9

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		checkedAt := time.Now()

		// conns is shared with addConn/removeConn, so scan it under the lock.
		s.connlock.Lock()
		for _, backend := range s.conns {
			phi := backend.detector.Phi(checkedAt)
			if phi > phiLogThreshold {
				ln.Log(ctx, backend, ln.Action("phi failure detection"), ln.F{
					"value":     phi,
					"threshold": phiLogThreshold,
				})
			}
		}
		s.connlock.Unlock()
	}
}
// backendAuthv1 performs version 1 of the backend authentication handshake.
//
// It reads one JSON-encoded Auth value from st and then verifies, in order:
// that the requested domain is a known route, that the token is known, that
// the token carries the "connect" scope, and that the route and the token
// belong to the same user. Storage errors are returned as-is; a missing scope
// or a user mismatch yields ErrAuthMismatch. Every failure is logged here
// together with the stage it occurred in.
//
// On success it returns the authenticated user and the decoded Auth.
func (s *Server) backendAuthv1(ctx context.Context, st io.Reader) (string, *Auth, error) {
	fields := ln.F{
		"action":               "backend authentication",
		"backend_auth_version": 1,
	}

	// reject logs err with the fields collected so far and produces the
	// failure return values.
	reject := func(err error) (string, *Auth, error) {
		ln.Error(ctx, err, fields)
		return "", nil, err
	}

	fields["stage"] = "json decoding"
	var auth Auth
	if err := json.NewDecoder(st).Decode(&auth); err != nil {
		return reject(err)
	}
	fields["auth_domain"] = auth.Domain

	fields["stage"] = "checking domain"
	routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
	if err != nil {
		return reject(err)
	}
	fields["route_user"] = routeUser

	fields["stage"] = "checking token"
	tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
	if err != nil {
		return reject(err)
	}
	fields["token_user"] = tokenUser

	fields["stage"] = "checking token scopes"
	canConnect := false
	for i := 0; i < len(scopes) && !canConnect; i++ {
		canConnect = scopes[i] == "connect"
	}
	if !canConnect {
		return reject(ErrAuthMismatch)
	}

	fields["stage"] = "user verification"
	if tokenUser != routeUser {
		return reject(ErrAuthMismatch)
	}

	return routeUser, &auth, nil
}
// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
// the currently connected backend.
//
// It wraps c in the server side of a smux session, opens a control stream,
// authenticates the backend over it (backendAuthv1), registers the resulting
// Connection so RoundTrip can use it, and then pings the backend every
// 5 seconds. A failed ping cancels the connection. HandleConn returns once
// the connection's context is cancelled; c, the session and the control
// stream are closed on the way out.
//
// The connection is marked usable before it is published through addConn so
// that RoundTrip never observes that field being written concurrently.
// Deregistration is deferred, so the connection is removed from the pool even
// when the ping loop panics rather than lingering as a dead backend.
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
	defer c.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	f := ln.F{
		"local":  c.LocalAddr().String(),
		"remote": c.RemoteAddr().String(),
	}

	session, err := smux.Server(c, s.cfg.SmuxConf)
	if err != nil {
		ln.Error(ctx, err, f, ln.Action("establish server side of smux"))
		return
	}
	defer session.Close()

	controlStream, err := session.OpenStream()
	if err != nil {
		ln.Error(ctx, err, f, ln.Action("opening control stream"))
		return
	}
	defer controlStream.Close()

	// backendAuthv1 logs its own failures, so there is nothing to add here.
	user, auth, err := s.backendAuthv1(ctx, controlStream)
	if err != nil {
		return
	}

	connection := &Connection{
		id:       ulid.New().String(),
		conn:     c,
		isKCP:    isKCP,
		session:  session,
		user:     user,
		domain:   auth.Domain,
		cf:       cancel,
		detector: failure.New(15, 1),
		Auth:     auth,
	}
	connection.counter = expvar.NewInt("http.backend." + connection.id + ".hits")

	// Registered before the cleanup below so that it runs after it and also
	// catches a panic raised while cleaning up.
	defer func() {
		if r := recover(); r != nil {
			ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
		}
	}()

	ln.Log(ctx, connection, ln.Action("backend successfully connected"))

	// Set before addConn publishes the connection to RoundTrip.
	connection.usable = true // XXX set this to true once health checks pass?
	s.addConn(ctx, connection)
	defer func() {
		s.removeConn(ctx, connection)
		connection.Close()
	}()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := connection.Ping(); err != nil {
				connection.cancel()
			}
		// case <-s.ctx.Done():
		// 	ln.Log(ctx, connection, ln.Action("server context finished"))
		// 	s.removeConn(ctx, connection)
		// 	connection.Close()
		// 	return
		case <-ctx.Done():
			ln.Log(ctx, connection, ln.Action("client context finished"))
			return
		}
	}
}
// addConn adds a connection to the pool of backend connections.
//
// The connection is recorded in conns (keyed by its net.Conn) and appended to
// the list of connections registered for connection.domain. The whole update
// runs under connlock so that concurrent addConn calls for the same domain
// cannot overwrite each other's read-modify-write of that list.
//
// The domain list is replaced with a fresh slice rather than appended to in
// place, so a slice previously handed out by domains.Get is never modified
// while another goroutine is reading it.
func (s *Server) addConn(ctx context.Context, connection *Connection) {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	s.conns[connection.conn] = connection

	var existing []*Connection
	if val, ok := s.domains.Get(connection.domain); ok {
		// A value of an unexpected type is treated as an empty list and
		// overwritten by the Set below.
		existing, _ = val.([]*Connection)
	}

	updated := make([]*Connection, 0, len(existing)+1)
	updated = append(updated, existing...)
	updated = append(updated, connection)

	s.domains.Set(connection.domain, updated)
}
// removeConn removes a connection from pool of backend connections.
//
// The connection is deleted from conns and filtered out of the list
// registered for connection.domain, the same key addConn files it under. When
// no connection remains for the domain, the domain entry is removed entirely.
// If the domain entry holds a value of an unexpected type,
// ErrCantRemoveWhatDoesntExist is logged and the entry is left untouched.
//
// The whole update runs under connlock so that concurrent removeConn calls
// for the same domain cannot lose each other's changes. The remaining
// connections are copied into a new slice instead of being swapped around in
// place, so a slice previously handed out by domains.Get is never modified
// while another goroutine is reading it.
func (s *Server) removeConn(ctx context.Context, connection *Connection) {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	delete(s.conns, connection.conn)

	domain := connection.domain

	val, ok := s.domains.Get(domain)
	if !ok {
		// Nothing is registered for this domain, so there is nothing to do.
		return
	}

	conns, ok := val.([]*Connection)
	if !ok {
		ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal"))
		return
	}

	remaining := make([]*Connection, 0, len(conns))
	for _, cntn := range conns {
		if cntn.id != connection.id {
			remaining = append(remaining, cntn)
		}
	}

	if len(remaining) == 0 {
		s.domains.Remove(domain)
		return
	}
	s.domains.Set(domain, remaining)
}
// RoundTrip sends a HTTP request to a backend and then returns its response.
//
// The backend is chosen at random among the usable connections registered for
// req.Host. When there is none, a generated 502 page is returned with a nil
// error. When the chosen backend fails the round trip, that connection is
// cancelled and the error is returned to the caller.
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()

	fields := ln.F{
		"req_remote":         req.RemoteAddr,
		"req_host":           req.Host,
		"req_uri":            req.RequestURI,
		"req_method":         req.Method,
		"req_content_length": req.ContentLength,
	}

	var backends []*Connection
	if entry, found := s.domains.Get(req.Host); found {
		typed, isList := entry.([]*Connection)
		if !isList {
			ln.Error(ctx, ErrNoSuchBackend, fields, ln.Action("no backend available"))
			return gen502Page(req), nil
		}
		backends = typed
	}

	// Only connections currently marked usable are candidates.
	var candidates []*Connection
	for _, backend := range backends {
		if backend.usable {
			candidates = append(candidates, backend)
		}
	}

	if len(candidates) == 0 {
		ln.Error(ctx, ErrNoSuchBackend, fields, ln.Action("no good backends available"))
		return gen502Page(req), nil
	}

	chosen := candidates[rand.Intn(len(candidates))]

	resp, err := chosen.RoundTrip(req)
	if err != nil {
		ln.Error(ctx, err, chosen, fields, ln.Action("connection roundtrip"))
		chosen.cancel()
		return nil, err
	}

	ln.Log(ctx, chosen, ln.Action("http traffic"), fields, ln.F{
		"resp_status_code":    resp.StatusCode,
		"resp_content_length": resp.ContentLength,
	})

	return resp, nil
}
// Auth is the authentication info the client passes to the server.
// The backend sends it JSON-encoded on the control stream, where
// backendAuthv1 decodes and verifies it.
type Auth struct {
// Token is the credential checked through Storage.HasToken.
Token string `json:"token"`
// Domain is the route the backend wants to serve, checked through
// Storage.HasRoute.
Domain string `json:"domain"`
}