package tun2 import ( "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io/ioutil" "math/rand" "net" "net/http" "os" "sync" "time" "github.com/Xe/ln" failure "github.com/dgryski/go-failure" "github.com/mtneug/pkg/ulid" cmap "github.com/streamrail/concurrent-map" kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) // Error values var ( ErrNoSuchBackend = errors.New("tun2: there is no such backend") ErrAuthMismatch = errors.New("tun2: authenication doesn't match database records") ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it") ) // gen502Page creates the page that is shown when a backend is not connected to a given route. func gen502Page(req *http.Request) *http.Response { template := `no backends connected

no backends connected

Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.

` resbody := []byte(os.Expand(template, func(in string) string { switch in { case "HOST": return req.Host case "REQ_ID": return req.Header.Get("X-Request-Id") } return "" })) reshdr := req.Header reshdr.Set("Content-Type", "text/html; charset=utf-8") resp := &http.Response{ Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway), StatusCode: http.StatusBadGateway, Body: ioutil.NopCloser(bytes.NewBuffer(resbody)), Proto: req.Proto, ProtoMajor: req.ProtoMajor, ProtoMinor: req.ProtoMinor, Header: reshdr, ContentLength: int64(len(resbody)), Close: true, Request: req, } return resp } // ServerConfig ... type ServerConfig struct { TCPAddr string KCPAddr string TLSConfig *tls.Config SmuxConf *smux.Config Storage Storage } // Storage is the minimal subset of features that tun2's Server needs out of a // persistence layer. type Storage interface { HasToken(token string) (user string, scopes []string, err error) HasRoute(domain string) (user string, err error) } // Server routes frontend HTTP traffic to backend TCP traffic. type Server struct { cfg *ServerConfig connlock sync.Mutex conns map[net.Conn]*Connection domains cmap.ConcurrentMap } // NewServer creates a new Server instance with a given config, acquiring all // relevant resources. func NewServer(cfg *ServerConfig) (*Server, error) { if cfg == nil { return nil, errors.New("tun2: config must be specified") } if cfg.SmuxConf == nil { cfg.SmuxConf = smux.DefaultConfig() } cfg.SmuxConf.KeepAliveInterval = time.Second cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second server := &Server{ cfg: cfg, conns: map[net.Conn]*Connection{}, domains: cmap.New(), } return server, nil } // Listen passes this Server a given net.Listener to accept backend connections. func (s *Server) Listen(l net.Listener, isKCP bool) { ctx := context.Background() for { conn, err := l.Accept() if err != nil { ln.Error(ctx, err, ln.F{ "addr": l.Addr().String(), "network": l.Addr().Network(), }) continue } ln.Log(ctx, ln.F{ "action": "new_client", "network": conn.RemoteAddr().Network(), "addr": conn.RemoteAddr(), "list": conn.LocalAddr(), }) go s.HandleConn(conn, isKCP) } } // ListenAndServe starts the backend TCP/KCP listeners and relays backend // traffic to and from them. func (s *Server) ListenAndServe() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ln.Log(ctx, ln.F{ "action": "listen_and_serve_called", }) if s.cfg.TCPAddr != "" { go func() { l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) if err != nil { panic(err) } ln.Log(ctx, ln.F{ "action": "tcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(ctx, err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) continue } ln.Log(ctx, ln.F{ "action": "new_client", "kcp": false, "addr": conn.RemoteAddr(), }) go s.HandleConn(conn, false) } }() } if s.cfg.KCPAddr != "" { go func() { l, err := kcp.Listen(s.cfg.KCPAddr) if err != nil { panic(err) } ln.Log(ctx, ln.F{ "action": "kcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(ctx, err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) } ln.Log(ctx, ln.F{ "action": "new_client", "kcp": true, "addr": conn.RemoteAddr(), }) tc := tls.Server(conn, s.cfg.TLSConfig) go s.HandleConn(tc, true) } }() } // XXX experimental, might get rid of this inside this process go func() { for { time.Sleep(time.Second) now := time.Now() s.connlock.Lock() for _, c := range s.conns { failureChance := c.detector.Phi(now) if failureChance > 0.8 { ln.Log(ctx, c.F(), ln.F{ "action": "phi_failure_detection", "value": failureChance, }) } } s.connlock.Unlock() } }() return nil } // HandleConn starts up the needed mechanisms to relay HTTP traffic to/from // the currently connected backend. func (s *Server) HandleConn(c net.Conn, isKCP bool) { // XXX TODO clean this up it's really ugly. defer c.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() f := ln.F{ "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), } session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { ln.Error(ctx, err, f, ln.Action("establish server side of smux")) return } defer session.Close() f["stage"] = "smux_setup" controlStream, err := session.OpenStream() if err != nil { ln.Error(ctx, err, f, ln.Action("opening control stream")) return } defer controlStream.Close() f["stage"] = "control_stream_open" csd := json.NewDecoder(controlStream) auth := &Auth{} err = csd.Decode(auth) if err != nil { ln.Error(ctx, err, f, ln.Action("decode control stream authenication message")) return } f["stage"] = "checking_domain" routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) if err != nil { ln.Error(ctx, err, f, ln.Action("no such domain when checking client auth")) return } f["route_user"] = routeUser f["stage"] = "checking_token" tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) if err != nil { ln.Error(ctx, err, f, ln.Action("no such token exists or other token error")) return } f["token_user"] = tokenUser f["stage"] = "checking_token_scopes" ok := false for _, sc := range scopes { if sc == "connect" { ok = true break } } if !ok { ln.Error(ctx, ErrAuthMismatch, f, ln.Action("token not authorized to connect")) return } f["stage"] = "user_verification" if routeUser != tokenUser { ln.Error(ctx, ErrAuthMismatch, f, ln.Action("auth mismatch")) return } connection := &Connection{ id: ulid.New().String(), conn: c, isKCP: isKCP, session: session, user: tokenUser, domain: auth.Domain, cf: cancel, detector: failure.New(15, 1), Auth: auth, } defer func() { if r := recover(); r != nil { ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r}) } }() ln.Log(ctx, connection, ln.Action("backend successfully connected")) // TODO: put these lines in a function? s.connlock.Lock() s.conns[c] = connection s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { conns = nil s.domains.Remove(auth.Domain) } } conns = append(conns, connection) s.domains.Set(auth.Domain, conns) connection.usable = true // XXX set this to true once health checks pass? ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: err := connection.Ping() if err != nil { connection.cancel() } case <-ctx.Done(): s.RemoveConn(ctx, connection) connection.Close() return } } } // RemoveConn removes a connection. func (s *Server) RemoveConn(ctx context.Context, connection *Connection) { s.connlock.Lock() delete(s.conns, connection.conn) s.connlock.Unlock() auth := connection.Auth var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection, ln.Action("looking up for disconnect removal")) return } } for i, cntn := range conns { if cntn.id == connection.id { conns[i] = conns[len(conns)-1] conns = conns[:len(conns)-1] } } if len(conns) != 0 { s.domains.Set(auth.Domain, conns) } else { s.domains.Remove(auth.Domain) } ln.Log(ctx, connection, ln.Action("backend disconnect")) } // RoundTrip sends a HTTP request to a backend and then returns its response. func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection ctx := req.Context() val, ok := s.domains.Get(req.Host) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(ctx, ErrNoSuchBackend, ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return gen502Page(req), nil } } var goodConns []*Connection for _, conn := range conns { if conn.usable { goodConns = append(goodConns, conn) } } if len(goodConns) == 0 { ln.Error(ctx, ErrNoSuchBackend, ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return gen502Page(req), nil } c := goodConns[rand.Intn(len(goodConns))] resp, err := c.RoundTrip(req) if err != nil { ln.Error(ctx, err, c, ln.F{ "action": "connection_roundtrip", }) defer c.cancel() return nil, err } ln.Log(ctx, c, ln.F{ "action": "http traffic", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.URL.Path, "status": resp.Status, "status_code": resp.StatusCode, "content_length": resp.ContentLength, }) return resp, nil } // Auth is the authentication info the client passes to the server. type Auth struct { Token string `json:"token"` Domain string `json:"domain"` } const defaultUser = "Cadey"