diff --git a/lib/tun2/server.go b/lib/tun2/server.go index 30b48b2..726dd86 100644 --- a/lib/tun2/server.go +++ b/lib/tun2/server.go @@ -2,6 +2,7 @@ package tun2 import ( "bufio" + "context" "crypto/tls" "encoding/json" "errors" @@ -50,6 +51,7 @@ type Connection struct { controlStream *smux.Stream user string domain string + cancel context.CancelFunc } func NewServer(cfg *ServerConfig) (*Server, error) { @@ -141,6 +143,9 @@ func (s *Server) ListenAndServe() error { } func (s *Server) HandleConn(c net.Conn, isKCP bool) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + session, err := smux.Server(c, s.cfg.SmuxConf) if err != nil { ln.Error(err, ln.F{ @@ -223,6 +228,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { session: session, user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented domain: auth.Domain, + cancel: cancel, } ln.Log(ln.F{ @@ -231,6 +237,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { "kcp": isKCP, "domain": auth.Domain, "user": connection.user, + "id": connection.id, }) s.connlock.Lock() @@ -240,6 +247,33 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { s.domainlock.Lock() s.domains[auth.Domain] = append(s.domains[auth.Domain], connection) s.domainlock.Unlock() + + select { + case <-ctx.Done(): + s.connlock.Lock() + delete(s.conns, c) + s.connlock.Unlock() + + s.domainlock.Lock() + + for i, cntn := range s.domains[auth.Domain] { + if cntn.id == connection.id { + s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1] + s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1] + } + } + + s.domainlock.Unlock() + + ln.Log(ln.F{ + "action": "client_disconnecting", + "remote": c.RemoteAddr(), + "domain": auth.Domain, + "id": connection.id, + }) + + return + } } func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {