package tun2

import (
	"bufio"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"git.xeserv.us/xena/route/database"
	"github.com/Xe/ln"
	failure "github.com/dgryski/go-failure"
	"github.com/mtneug/pkg/ulid"
	cmap "github.com/streamrail/concurrent-map"
	kcp "github.com/xtaci/kcp-go"
	"github.com/xtaci/smux"
)

// ServerConfig holds the settings used by Server.
type ServerConfig struct {
	// TCPAddr is the address of the TLS-over-TCP listener. An empty string
	// disables that listener.
	TCPAddr string
	// KCPAddr is the address of the KCP listener. An empty string disables
	// that listener.
	KCPAddr string
	// TLSConfig is used for both the TCP and the KCP listener.
	TLSConfig *tls.Config
	// SmuxConf configures stream multiplexing. NewServer fills in
	// smux.DefaultConfig() when it is nil.
	SmuxConf *smux.Config
	// Storage is consulted to find the route for a backend's domain.
	Storage Storage
}

// Storage is the lookup interface the server needs from its data store.
type Storage interface {
	GetRouteForHost(name string) (*database.Route, error)
	//ValidateToken(token string) (username string, ok bool, err error) // XXX RIP implement when users are implemented
}

// Server accepts backend connections and keeps track of them per domain.
type Server struct {
	cfg *ServerConfig

	// connlock guards conns.
	connlock sync.Mutex
	conns    map[net.Conn]*Connection

	// domains maps a domain name to the []*Connection serving it.
	domains cmap.ConcurrentMap
}

// Connection is a single backend connection tracked by the server.
type Connection struct {
	id            string
	conn          net.Conn
	isKCP         bool
	session       *smux.Session
	controlStream *smux.Stream // closed by Close before the session
	user          string
	domain        string
	cancel        context.CancelFunc
	detector      *failure.Detector // fed by Ping, read by the loop started in ListenAndServe
}

// F returns the logging fields that describe this connection.
func (c *Connection) F() ln.F {
	return ln.F{
		"id":     c.id,
		"remote": c.conn.RemoteAddr(),
		"local":  c.conn.LocalAddr(),
		"isKCP":  c.isKCP,
		"user":   c.user,
		"domain": c.domain,
	}
}

// NewServer creates a Server from cfg. It returns an error when cfg is nil and
// sets cfg.SmuxConf to smux.DefaultConfig() when it is unset.
func NewServer(cfg *ServerConfig) (*Server, error) {
	if cfg == nil {
		return nil, errors.New("tun2: config must be specified")
	}

	if cfg.SmuxConf == nil {
		cfg.SmuxConf = smux.DefaultConfig()
	}

	server := &Server{
		cfg:     cfg,
		conns:   map[net.Conn]*Connection{},
		domains: cmap.New(),
	}

	return server, nil
}

// ListenAndServe creates the configured listeners, starts one accept loop per
// listener plus a background failure-detection loop, and then returns without
// blocking.
//
// It returns an error when a listener cannot be created. In that case any
// listener that was already opened by this call is closed again.
func (s *Server) ListenAndServe() error {
	ln.Log(ln.F{
		"action": "listen_and_serve_called",
	})

	var tcpListener, kcpListener net.Listener

	if s.cfg.TCPAddr != "" {
		l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig)
		if err != nil {
			// Logged as well as returned so the failure stays visible even
			// to callers that ignore the return value.
			ln.Error(err, ln.F{"action": "tcp+tls_listen_failed", "addr": s.cfg.TCPAddr})
			return err
		}

		ln.Log(ln.F{
			"action": "tcp+tls_listening",
			"addr":   l.Addr(),
		})
		tcpListener = l
	}

	if s.cfg.KCPAddr != "" {
		l, err := kcp.Listen(s.cfg.KCPAddr)
		if err != nil {
			ln.Error(err, ln.F{"action": "kcp+tls_listen_failed", "addr": s.cfg.KCPAddr})
			if tcpListener != nil {
				tcpListener.Close()
			}
			return err
		}

		ln.Log(ln.F{
			"action": "kcp+tls_listening",
			"addr":   l.Addr(),
		})
		kcpListener = l
	}

	if tcpListener != nil {
		go s.acceptLoop(tcpListener, false)
	}
	if kcpListener != nil {
		go s.acceptLoop(kcpListener, true)
	}

	// Periodically report the failure detector's value for every connection.
	go func() {
		for {
			time.Sleep(5 * time.Second)

			now := time.Now()

			s.connlock.Lock()
			for _, c := range s.conns {
				failureChance := c.detector.Phi(now)

				if failureChance != 0 {
					ln.Log(c.F(), ln.F{
						"action": "phi_failure_detection",
						"value":  failureChance,
					})
				}
			}
			s.connlock.Unlock()
		}
	}()

	return nil
}

// acceptLoop accepts connections from l forever and hands each one to
// HandleConn in its own goroutine. KCP connections are wrapped in a TLS server
// connection first; connections from the TCP listener come from tls.Listen and
// need no wrapping.
func (s *Server) acceptLoop(l net.Listener, isKCP bool) {
	kind := "tcp"
	if isKCP {
		kind = "kcp"
	}

	for {
		conn, err := l.Accept()
		if err != nil {
			ln.Error(err, ln.F{"kind": kind, "addr": l.Addr().String()})
			// conn is not usable after a failed Accept; skip to the next one.
			continue
		}

		ln.Log(ln.F{
			"action": "new_client",
			"kcp":    isKCP,
			"addr":   conn.RemoteAddr(),
		})

		if isKCP {
			conn = tls.Server(conn, s.cfg.TLSConfig)
		}

		go s.HandleConn(conn, isKCP)
	}
}

// Ping sends a "ping" to the client. If the client doesn't respond or the connection
// dies, then the connection needs to be cleaned up.
//
// The ping is a GET request for http://backend/health written to a fresh
// stream. On any failure the connection's cancel function is called and the
// error is returned. On success the failure detector is updated.
func (c *Connection) Ping() error {
	req, err := http.NewRequest("GET", "http://backend/health", nil)
	if err != nil {
		// The method and URL are constants, so this can only be a programming error.
		panic(err)
	}

	stream, err := c.OpenStream()
	if err != nil {
		ln.Error(err, c.F())
		defer c.cancel()
		return err
	}
	defer stream.Close()

	stream.SetWriteDeadline(time.Now().Add(time.Second))
	err = req.Write(stream)
	if err != nil {
		ln.Error(err, c.F())
		defer c.cancel()
		return err
	}

	// Any bytes coming back count as a sign of life; the reply is not parsed.
	stream.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, err = stream.Read(make([]byte, 30))
	if err != nil {
		ln.Error(err, c.F())
		defer c.cancel()
		return err
	}

	c.detector.Ping(time.Now())

	return nil
}

// OpenStream creates a new stream (connection) to the backend server.
//
// A one second deadline is placed on the underlying connection while the
// stream is opened and is cleared again afterwards, whether or not opening
// succeeded. When an error is returned the stream is always nil.
func (c *Connection) OpenStream() (net.Conn, error) {
	err := c.conn.SetDeadline(time.Now().Add(time.Second))
	if err != nil {
		ln.Error(err, c.F())
		return nil, err
	}

	stream, err := c.session.OpenStream()

	// Clear the deadline in every case so a failed open does not leave a
	// stale deadline on the shared connection.
	clearErr := c.conn.SetDeadline(time.Time{})

	if err != nil {
		ln.Error(err, c.F())
		return nil, err
	}

	if clearErr != nil {
		ln.Error(clearErr, c.F())
		// Callers drop the stream when they see an error, so release it here.
		stream.Close()
		return nil, clearErr
	}

	return stream, nil
}

// Close destroys resources specific to the connection.
func (c *Connection) Close() error { err := c.controlStream.Close() if err != nil { return err } err = c.session.Close() if err != nil { return err } err = c.conn.Close() if err != nil { return err } return nil } func (s *Server) HandleConn(c net.Conn, isKCP bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { ln.Error(err, ln.F{ "action": "session_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) c.Close() return } controlStream, err := session.OpenStream() if err != nil { ln.Error(err, ln.F{ "action": "control_stream_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) session.Close() c.Close() return } csd := json.NewDecoder(controlStream) auth := &Auth{} err = csd.Decode(auth) if err != nil { ln.Error(err, ln.F{ "action": "control_stream_auth_decoding_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) controlStream.Close() session.Close() c.Close() return } route, err := s.cfg.Storage.GetRouteForHost(auth.Domain) if err != nil { ln.Error(err, ln.F{ "action": "nosuch_domain", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) controlStream.Close() session.Close() c.Close() return } if route.Token != auth.Token { ln.Error(err, ln.F{ "action": "bad_token", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) fmt.Fprintln(controlStream, "bad token") controlStream.Close() session.Close() c.Close() return } connection := &Connection{ id: ulid.New().String(), conn: c, isKCP: isKCP, session: session, user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented domain: auth.Domain, cancel: cancel, detector: failure.New(15, 1), } ln.Log(ln.F{ "action": "backend_connected", }, connection.F()) s.connlock.Lock() s.conns[c] = connection s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { 
conns, ok = val.([]*Connection) if !ok { conns = nil s.domains.Remove(auth.Domain) } } conns = append(conns, connection) s.domains.Set(auth.Domain, conns) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: err := connection.Ping() if err != nil { cancel() } case <-ctx.Done(): s.RemoveConn(auth, connection) connection.Close() return } } } func (s *Server) RemoveConn(auth *Auth, connection *Connection) { s.connlock.Lock() delete(s.conns, connection.conn) s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(errors.New("fundamental assertion is not met"), connection.F(), ln.F{ "action": "looking_up_for_disconnect_removal", }) return } } for i, cntn := range conns { if cntn.id == connection.id { conns[i] = conns[len(conns)-1] conns = conns[:len(conns)-1] } } if len(conns) != 0 { s.domains.Set(auth.Domain, conns) } else { s.domains.Remove(auth.Domain) } ln.Log(connection.F(), ln.F{ "action": "client_disconnecting", }) } func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection val, ok := s.domains.Get(req.Host) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(errors.New("no backend connected"), ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) resp := &http.Response{ StatusCode: http.StatusBadGateway, Body: ioutil.NopCloser(strings.NewReader("no such domain")), ContentLength: 14, Close: true, Request: req, } return resp, errors.New("no backend connected") } } if len(conns) == 0 { ln.Error(errors.New("no backend connected"), ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return nil, errors.New("no backend connected") } c := conns[rand.Intn(len(conns))] c.conn.SetDeadline(time.Now().Add(time.Second)) stream, err := c.session.OpenStream() if err != nil { ln.Error(err, ln.F{ 
"action": "opening_session_stream", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return s.RoundTrip(req) } defer stream.Close() c.conn.SetDeadline(time.Time{}) err = req.Write(stream) if err != nil { ln.Error(err, ln.F{ "action": "request_writing", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return s.RoundTrip(req) } buf := bufio.NewReader(stream) resp, err := http.ReadResponse(buf, req) if err != nil { ln.Error(err, ln.F{ "action": "response_reading", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return nil, err } ln.Log(c.F(), ln.F{ "action": "http_traffic", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return resp, nil } type Auth struct { Token string `json:"token"` Domain string `json:"domain"` } const defaultUser = "Cadey"