package tun2 import ( "bufio" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io/ioutil" "math/rand" "net" "net/http" "strings" "sync" "time" "git.xeserv.us/xena/route/database" "github.com/Xe/ln" "github.com/mtneug/pkg/ulid" cmap "github.com/streamrail/concurrent-map" kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) type ServerConfig struct { TCPAddr string KCPAddr string TLSConfig *tls.Config SmuxConf *smux.Config Storage Storage } type Storage interface { GetRouteForHost(name string) (*database.Route, error) //ValidateToken(token string) (username string, ok bool, err error) // XXX RIP implement when users are implemented } type Server struct { cfg *ServerConfig connlock sync.Mutex conns map[net.Conn]*Connection domains cmap.ConcurrentMap } type Connection struct { id string conn net.Conn isKCP bool session *smux.Session controlStream *smux.Stream user string domain string cancel context.CancelFunc } func (c *Connection) F() ln.F { return map[string]interface{}{ "id": c.id, "remote": c.conn.RemoteAddr(), "local": c.conn.LocalAddr(), "isKCP": c.isKCP, "user": c.user, "domain": c.domain, } } func NewServer(cfg *ServerConfig) (*Server, error) { if cfg == nil { return nil, errors.New("tun2: config must be specified") } if cfg.SmuxConf == nil { cfg.SmuxConf = smux.DefaultConfig() } server := &Server{ cfg: cfg, conns: map[net.Conn]*Connection{}, domains: cmap.New(), } return server, nil } func (s *Server) ListenAndServe() error { ln.Log(ln.F{ "action": "listen_and_serve_called", }) if s.cfg.TCPAddr != "" { go func() { l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) if err != nil { panic(err) } ln.Log(ln.F{ "action": "tcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) continue } ln.Log(ln.F{ "action": "new_client", "kcp": false, "addr": conn.RemoteAddr(), }) go s.HandleConn(conn, false) } }() } if s.cfg.KCPAddr != "" { go func() { l, err := kcp.Listen(s.cfg.KCPAddr) if err != nil { panic(err) } ln.Log(ln.F{ "action": "kcp+tls_listening", "addr": l.Addr(), }) for { conn, err := l.Accept() if err != nil { ln.Error(err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) } ln.Log(ln.F{ "action": "new_client", "kcp": true, "addr": conn.RemoteAddr(), }) tc := tls.Server(conn, s.cfg.TLSConfig) go s.HandleConn(tc, true) } }() } go func() { for { time.Sleep(5 * time.Second) s.pingLoopInner() } }() return nil } func (s *Server) pingLoopInner() { s.connlock.Lock() defer s.connlock.Unlock() for _, c := range s.conns { req, err := http.NewRequest("GET", "http://backend/health", nil) if err != nil { panic(err) } c.conn.SetDeadline(time.Now().Add(time.Second)) stream, err := c.session.OpenStream() if err != nil { ln.Error(err, c.F()) c.cancel() return } c.conn.SetDeadline(time.Time{}) stream.SetWriteDeadline(time.Now().Add(time.Second)) err = req.Write(stream) if err != nil { ln.Error(err, c.F()) c.cancel() return } stream.SetReadDeadline(time.Now().Add(time.Second)) _, err = stream.Read(make([]byte, 30)) if err != nil { ln.Error(err, c.F()) c.cancel() return } stream.Close() /* ln.Log(ln.F{ "action": "ping_health_is_good", }, c.F()) */ } } func (s *Server) HandleConn(c net.Conn, isKCP bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { ln.Error(err, ln.F{ "action": "session_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) c.Close() return } controlStream, err := session.OpenStream() if err != nil { ln.Error(err, ln.F{ "action": "control_stream_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) session.Close() c.Close() return } csd := json.NewDecoder(controlStream) auth := &Auth{} err = csd.Decode(auth) if err != nil { ln.Error(err, ln.F{ "action": "control_stream_auth_decoding_failure", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) controlStream.Close() session.Close() c.Close() return } route, err := s.cfg.Storage.GetRouteForHost(auth.Domain) if err != nil { ln.Error(err, ln.F{ "action": "nosuch_domain", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) controlStream.Close() session.Close() c.Close() return } if route.Token != auth.Token { ln.Error(err, ln.F{ "action": "bad_token", "local": c.LocalAddr().String(), "remote": c.RemoteAddr().String(), }) fmt.Fprintln(controlStream, "bad token") controlStream.Close() session.Close() c.Close() return } connection := &Connection{ id: ulid.New().String(), conn: c, isKCP: isKCP, session: session, user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented domain: auth.Domain, cancel: cancel, } ln.Log(ln.F{ "action": "backend_connected", }, connection.F()) s.connlock.Lock() s.conns[c] = connection s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { conns = nil s.domains.Remove(auth.Domain) } } conns = append(conns, connection) s.domains.Set(auth.Domain, conns) select { case <-ctx.Done(): s.connlock.Lock() delete(s.conns, c) s.connlock.Unlock() var conns []*Connection val, ok := s.domains.Get(auth.Domain) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(err, connection.F(), ln.F{ "action": "looking_up_for_disconnect_removal", }) return } } for i, cntn := range conns { if cntn.id == connection.id { conns[i] = conns[len(conns)-1] conns = conns[:len(conns)-1] } } if len(conns) != 0 { s.domains.Set(auth.Domain, conns) } else { s.domains.Remove(auth.Domain) } ln.Log(connection.F(), ln.F{ "action": "client_disconnecting", }) controlStream.Close() session.Close() c.Close() return } } func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection val, ok := s.domains.Get(req.Host) if ok { conns, ok = val.([]*Connection) if !ok { ln.Error(errors.New("no backend connected"), ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) resp := &http.Response{ StatusCode: http.StatusBadGateway, Body: ioutil.NopCloser(strings.NewReader("no such domain")), ContentLength: 14, Close: true, Request: req, } return resp, errors.New("no backend connected") } } if len(conns) == 0 { ln.Error(errors.New("no backend connected"), ln.F{ "action": "no_backend_connected", "remote": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return nil, errors.New("no backend connected") } c := conns[rand.Intn(len(conns))] c.conn.SetDeadline(time.Now().Add(time.Second)) stream, err := c.session.OpenStream() if err != nil { ln.Error(err, ln.F{ "action": "opening_session_stream", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return s.RoundTrip(req) } defer stream.Close() c.conn.SetDeadline(time.Time{}) err = req.Write(stream) if err != nil { ln.Error(err, ln.F{ "action": "request_writing", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return s.RoundTrip(req) } buf := bufio.NewReader(stream) resp, err := http.ReadResponse(buf, req) if err != nil { ln.Error(err, ln.F{ "action": "response_reading", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }, c.F()) c.cancel() return nil, err } ln.Log(c.F(), ln.F{ "action": "http_traffic", "remote_addr": req.RemoteAddr, "host": req.Host, "uri": req.RequestURI, }) return resp, nil } type Auth struct { Token string `json:"token"` Domain string `json:"domain"` } const defaultUser = "Cadey"