package tun2 import ( "context" "crypto/tls" "encoding/json" "errors" "math/rand" "net" "net/http" "sync" "time" "github.com/Xe/ln" failure "github.com/dgryski/go-failure" "github.com/mtneug/pkg/ulid" cmap "github.com/streamrail/concurrent-map" kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) // Error values var ( ErrNoSuchBackend = errors.New("tun2: there is no such backend") ErrAuthMismatch = errors.New("tun2: authenication doesn't match database records") ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it") ) // ServerConfig ... type ServerConfig struct { TCPAddr string KCPAddr string TLSConfig *tls.Config SmuxConf *smux.Config Storage Storage } // Storage is the minimal subset of features that tun2's Server needs out of a // persistence layer. type Storage interface { HasToken(token string) (user string, scopes []string, err error) HasRoute(domain string) (user string, err error) } // Server routes frontend HTTP traffic to backend TCP traffic. type Server struct { cfg *ServerConfig connlock sync.Mutex conns map[net.Conn]*Connection domains cmap.ConcurrentMap } // NewServer creates a new Server instance with a given config, acquiring all // relevant resources. func NewServer(cfg *ServerConfig) (*Server, error) { if cfg == nil { return nil, errors.New("tun2: config must be specified") } if cfg.SmuxConf == nil { cfg.SmuxConf = smux.DefaultConfig() } cfg.SmuxConf.KeepAliveInterval = time.Second cfg.SmuxConf.KeepAliveTimeout = 15 * time.Second server := &Server{ cfg: cfg, conns: map[net.Conn]*Connection{}, domains: cmap.New(), } return server, nil } type backendMatcher func(*Connection) bool func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend { s.connlock.Lock() defer s.connlock.Unlock() var result []Backend for _, c := range s.conns { if !bm(c) { continue } protocol := "tcp" if c.isKCP { protocol = "kcp" } result = append(result, Backend{ ID: c.id, Proto: protocol, User: c.user, Domain: c.domain, Phi: float32(c.detector.Phi(time.Now())), Host: c.conn.RemoteAddr().String(), Usable: c.usable, }) } return result } func (s *Server) KillBackend(id string) error { s.connlock.Lock() defer s.connlock.Unlock() for _, c := range s.conns { if c.id == id { c.cancel() return nil } } return ErrNoSuchBackend } func (s *Server) GetBackendsForDomain(domain string) []Backend { return s.getBackendsForMatcher(func(c *Connection) bool { return c.domain == domain }) } func (s *Server) GetBackendsForUser(uname string) []Backend { return s.getBackendsForMatcher(func(c *Connection) bool { return c.user == uname }) } func (s *Server) GetAllBackends() []Backend { return s.getBackendsForMatcher(func(*Connection) bool { return true }) } // ListenAndServe starts the backend TCP/KCP listeners and relays backend // traffic to and from them. func (s *Server) ListenAndServe() error { ln.Log(ln.F{ "action": "listen_and_serve_called", }) if s.cfg.TCPAddr != "" { go func() { l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) if err != nil { panic(err) } ln.Log(ln.F{ "action": "tcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) continue } ln.Log(ln.F{ "action": "new_client", "kcp": false, "addr": conn.RemoteAddr(), }) go s.HandleConn(conn, false) } }() } if s.cfg.KCPAddr != "" { go func() { l, err := kcp.Listen(s.cfg.KCPAddr) if err != nil { panic(err) } ln.Log(ln.F{ "action": "kcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) } ln.Log(ln.F{ "action": "new_client", "kcp": true, "addr": conn.RemoteAddr(), }) tc := tls.Server(conn, s.cfg.TLSConfig) go s.HandleConn(tc, true) } }() } // XXX experimental, might get rid of this inside this process go func() { for { time.Sleep(time.Second) now := time.Now() s.connlock.Lock() for _, c := range s.conns { failureChance := c.detector.Phi(now) if failureChance > 0.8 { ln.Log(c.F(), ln.F{ "action": "phi_failure_detection", "value": failureChance, }) } } s.connlock.Unlock() } }() return nil } // HandleConn starts up the needed mechanisms to relay HTTP traffic to/from // the currently connected backend. func (s *Server) HandleConn(c net.Conn, isKCP bool) { // XXX TODO clean this up it's really ugly. defer c.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { ln.Error(err, ln.F{ "action": "session_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) c.Close() return } defer session.Close() controlStream, err := session.OpenStream() if err != nil { ln.Error(err, ln.F{ "action": "control_stream_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) return } defer controlStream.Close() csd := json.NewDecoder(controlStream) auth := &Auth{} err = csd.Decode(auth) if err != nil { ln.Error(err, ln.F{ "action": "control_stream_auth_decoding_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) return } routeUser, err := s.cfg.Storage.HasRoute(auth.Domain) if err != nil { ln.Error(err, ln.F{ "action": "nosuch_domain", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) return } tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token) if err != nil { ln.Error(err, ln.F{ "action": "nosuch_token", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) return } ok := false for _, sc := range scopes { if sc == "connect" { ok = true break } } if !ok { ln.Error(ErrAuthMismatch, ln.F{ "action": "token_not_authorized", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) } if routeUser != tokenUser { ln.Error(ErrAuthMismatch, ln.F{ "action": "auth_mismatch", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) return } connection := &Connection{ id: ulid.New().String(), conn: c, isKCP: isKCP, session: session, user: tokenUser, domain: auth.Domain, cf: cancel, detector: failure.New(15, 1), Auth: auth, } ln.Log(ln.F{ "action": "backend_connected", }, connection.F()) s.connlock.Lock() s.conns[c] = connection s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { conns = nil s.domains.Remove(auth.Domain) } } conns = append(conns, connection) s.domains.Set(auth.Domain, conns) connection.usable = true ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: err := connection.Ping() if err != nil { connection.cancel() } case <-ctx.Done(): s.RemoveConn(connection) connection.Close() return } } } // RemoveConn removes a connection. func (s *Server) RemoveConn(connection *Connection) { s.connlock.Lock() delete(s.conns, connection.conn) s.connlock.Unlock() auth := connection.Auth var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{ "action": "looking_up_for_disconnect_removal", }) return } } for i, cntn := range conns { if cntn.id == connection.id { conns[i] = conns[len(conns)-1] conns = conns[:len(conns)-1] } } if len(conns) != 0 { s.domains.Set(auth.Domain, conns) } else { s.domains.Remove(auth.Domain) } ln.Log(connection.F(), ln.F{ "action": "client_disconnecting", }) } // RoundTrip sends a HTTP request to a backend and then returns its response. func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection val, ok := s.domains.Get(req.Host) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(ErrNoSuchBackend, ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return nil, ErrNoSuchBackend } } var goodConns []*Connection for _, conn := range conns { if conn.usable { goodConns = append(goodConns, conn) } } if len(goodConns) == 0 { ln.Error(ErrNoSuchBackend, ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return nil, ErrNoSuchBackend } c := goodConns[rand.Intn(len(goodConns))] resp, err := c.RoundTrip(req) if err != nil { ln.Error(err, c.F(), ln.F{ "action": "connection_roundtrip", }) defer c.cancel() return nil, err } ln.Log(c.F(), ln.F{ "action": "http_traffic", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return resp, nil } // Auth is the authentication info the client passes to the server. type Auth struct { Token string `json:"token"` Domain string `json:"domain"` } const defaultUser = "Cadey"