diff --git a/lib/tun2/server.go b/lib/tun2/server.go index fa0fd2f..57ef639 100644 --- a/lib/tun2/server.go +++ b/lib/tun2/server.go @@ -244,6 +244,26 @@ func (c *Connection) OpenStream() (net.Conn, error) { return stream, c.conn.SetDeadline(time.Time{}) } +// Close destroys resouces specific to the connection. +func (c *Connection) Close() error { + err := c.controlStream.Close() + if err != nil { + return err + } + + err = c.session.Close() + if err != nil { + return err + } + + err = c.conn.Close() + if err != nil { + return err + } + + return nil +} + func (s *Server) HandleConn(c net.Conn, isKCP bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -369,49 +389,50 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { cancel() } case <-ctx.Done(): - s.connlock.Lock() - delete(s.conns, c) - s.connlock.Unlock() - - var conns []*Connection - - val, ok := s.domains.Get(auth.Domain) - if ok { - conns, ok = val.([]*Connection) - if !ok { - ln.Error(err, connection.F(), ln.F{ - "action": "looking_up_for_disconnect_removal", - }) - return - } - } - - for i, cntn := range conns { - if cntn.id == connection.id { - conns[i] = conns[len(conns)-1] - conns = conns[:len(conns)-1] - } - } - - if len(conns) != 0 { - s.domains.Set(auth.Domain, conns) - } else { - s.domains.Remove(auth.Domain) - } - - ln.Log(connection.F(), ln.F{ - "action": "client_disconnecting", - }) - - controlStream.Close() - session.Close() - c.Close() + s.RemoveConn(auth, connection) + connection.Close() return } } } +func (s *Server) RemoveConn(auth *Auth, connection *Connection) { + s.connlock.Lock() + delete(s.conns, connection.conn) + s.connlock.Unlock() + + var conns []*Connection + + val, ok := s.domains.Get(auth.Domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + ln.Error(errors.New("fundamental assertion is not met"), connection.F(), ln.F{ + "action": "looking_up_for_disconnect_removal", + }) + return + } + } + + for i, cntn := range conns { + if cntn.id == connection.id { + conns[i] = conns[len(conns)-1] + conns = conns[:len(conns)-1] + } + } + + if len(conns) != 0 { + s.domains.Set(auth.Domain, conns) + } else { + s.domains.Remove(auth.Domain) + } + + ln.Log(connection.F(), ln.F{ + "action": "client_disconnecting", + }) +} + func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { var conns []*Connection