tun2: on disconnection, explode gently

This commit is contained in:
Cadey Ratio 2017-03-26 13:38:05 -07:00
parent 89943398fc
commit 70191a0b88
1 changed files with 34 additions and 0 deletions

View File

@ -2,6 +2,7 @@ package tun2
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
@ -50,6 +51,7 @@ type Connection struct {
controlStream *smux.Stream controlStream *smux.Stream
user string user string
domain string domain string
cancel context.CancelFunc
} }
func NewServer(cfg *ServerConfig) (*Server, error) { func NewServer(cfg *ServerConfig) (*Server, error) {
@ -141,6 +143,9 @@ func (s *Server) ListenAndServe() error {
} }
func (s *Server) HandleConn(c net.Conn, isKCP bool) { func (s *Server) HandleConn(c net.Conn, isKCP bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
session, err := smux.Server(c, s.cfg.SmuxConf) session, err := smux.Server(c, s.cfg.SmuxConf)
if err != nil { if err != nil {
ln.Error(err, ln.F{ ln.Error(err, ln.F{
@ -223,6 +228,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
session: session, session: session,
user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented
domain: auth.Domain, domain: auth.Domain,
cancel: cancel,
} }
ln.Log(ln.F{ ln.Log(ln.F{
@ -231,6 +237,7 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
"kcp": isKCP, "kcp": isKCP,
"domain": auth.Domain, "domain": auth.Domain,
"user": connection.user, "user": connection.user,
"id": connection.id,
}) })
s.connlock.Lock() s.connlock.Lock()
@ -240,6 +247,33 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
s.domainlock.Lock() s.domainlock.Lock()
s.domains[auth.Domain] = append(s.domains[auth.Domain], connection) s.domains[auth.Domain] = append(s.domains[auth.Domain], connection)
s.domainlock.Unlock() s.domainlock.Unlock()
select {
case <-ctx.Done():
s.connlock.Lock()
delete(s.conns, c)
s.connlock.Unlock()
s.domainlock.Lock()
for i, cntn := range s.domains[auth.Domain] {
if cntn.id == connection.id {
s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1]
s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1]
}
}
s.domainlock.Unlock()
ln.Log(ln.F{
"action": "client_disconnecting",
"remote": c.RemoteAddr(),
"domain": auth.Domain,
"id": connection.id,
})
return
}
} }
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {