tun2: on disconnection, explode gently
This commit is contained in:
parent
89943398fc
commit
70191a0b88
|
@ -2,6 +2,7 @@ package tun2
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -50,6 +51,7 @@ type Connection struct {
|
|||
controlStream *smux.Stream
|
||||
user string
|
||||
domain string
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||
|
@ -141,6 +143,9 @@ func (s *Server) ListenAndServe() error {
|
|||
}
|
||||
|
||||
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session, err := smux.Server(c, s.cfg.SmuxConf)
|
||||
if err != nil {
|
||||
ln.Error(err, ln.F{
|
||||
|
@ -223,6 +228,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
session: session,
|
||||
user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented
|
||||
domain: auth.Domain,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
ln.Log(ln.F{
|
||||
|
@ -231,6 +237,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
"kcp": isKCP,
|
||||
"domain": auth.Domain,
|
||||
"user": connection.user,
|
||||
"id": connection.id,
|
||||
})
|
||||
|
||||
s.connlock.Lock()
|
||||
|
@ -240,6 +247,33 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
|||
s.domainlock.Lock()
|
||||
s.domains[auth.Domain] = append(s.domains[auth.Domain], connection)
|
||||
s.domainlock.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.connlock.Lock()
|
||||
delete(s.conns, c)
|
||||
s.connlock.Unlock()
|
||||
|
||||
s.domainlock.Lock()
|
||||
|
||||
for i, cntn := range s.domains[auth.Domain] {
|
||||
if cntn.id == connection.id {
|
||||
s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1]
|
||||
s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1]
|
||||
}
|
||||
}
|
||||
|
||||
s.domainlock.Unlock()
|
||||
|
||||
ln.Log(ln.F{
|
||||
"action": "client_disconnecting",
|
||||
"remote": c.RemoteAddr(),
|
||||
"domain": auth.Domain,
|
||||
"id": connection.id,
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
|
Loading…
Reference in New Issue