From a02553e0efad986316280095c6939b3b2ada913f Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Sun, 26 Mar 2017 21:56:54 -0700 Subject: [PATCH] tun2: use a better hashmap --- lib/tun2/server.go | 81 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 17 deletions(-) diff --git a/lib/tun2/server.go b/lib/tun2/server.go index ed71cc4..b4ae7ff 100644 --- a/lib/tun2/server.go +++ b/lib/tun2/server.go @@ -7,15 +7,18 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" "math/rand" "net" "net/http" + "strings" "sync" "time" "git.xeserv.us/xena/route/database" "github.com/Xe/ln" "github.com/mtneug/pkg/ulid" + cmap "github.com/streamrail/concurrent-map" kcp "github.com/xtaci/kcp-go" "github.com/xtaci/smux" ) @@ -40,8 +43,7 @@ type Server struct { connlock sync.Mutex conns map[net.Conn]*Connection - domainlock sync.Mutex - domains map[string][]*Connection + domains cmap.ConcurrentMap } type Connection struct { @@ -79,7 +81,7 @@ func NewServer(cfg *ServerConfig) (*Server, error) { cfg: cfg, conns: map[net.Conn]*Connection{}, - domains: map[string][]*Connection{}, + domains: cmap.New(), } return server, nil @@ -297,9 +299,21 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { s.conns[c] = connection s.connlock.Unlock() - s.domainlock.Lock() - s.domains[auth.Domain] = append(s.domains[auth.Domain], connection) - s.domainlock.Unlock() + var conns []*Connection + + val, ok := s.domains.Get(auth.Domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + conns = nil + + s.domains.Remove(auth.Domain) + } + } + + conns = append(conns, connection) + + s.domains.Set(auth.Domain, conns) select { case <-ctx.Done(): @@ -307,16 +321,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { delete(s.conns, c) s.connlock.Unlock() - s.domainlock.Lock() + var conns []*Connection - for i, cntn := range s.domains[auth.Domain] { - if cntn.id == connection.id { - s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1] - s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1] + val, ok := s.domains.Get(auth.Domain) + if ok { + conns, ok = val.([]*Connection) + if !ok { + ln.Error(err, connection.F(), ln.F{ + "action": "looking_up_for_disconnect_removal", + }) + return } } - s.domainlock.Unlock() + for i, cntn := range conns { + if cntn.id == connection.id { + conns[i] = conns[len(conns)-1] + conns = conns[:len(conns)-1] + } + } + + if len(conns) != 0 { + s.domains.Set(auth.Domain, conns) + } else { + s.domains.Remove(auth.Domain) + } ln.Log(connection.F(), ln.F{ "action": "client_disconnecting", @@ -331,11 +360,29 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) { } func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { - s.domainlock.Lock() - conns, ok := s.domains[req.Host] - s.domainlock.Unlock() - if !ok { - return nil, errors.New("domain not found") + var conns []*Connection + + val, ok := s.domains.Get(req.Host) + if ok { + conns, ok = val.([]*Connection) + if !ok { + ln.Error(errors.New("no backend connected"), ln.F{ + "action": "no_backend_connected", + "remote": req.RemoteAddr, + "host": req.Host, + "uri": req.RequestURI, + }) + + resp := &http.Response{ + StatusCode: http.StatusBadGateway, + Body: ioutil.NopCloser(strings.NewReader("no such domain")), + ContentLength: 14, + Close: true, + Request: req, + } + + return resp, nil + } } c := conns[rand.Intn(len(conns))]