tun2: use a better hashmap

This commit is contained in:
Cadey Ratio 2017-03-26 21:56:54 -07:00
parent ecc31b1eb7
commit a02553e0ef
1 changed files with 64 additions and 17 deletions

View File

@ -7,15 +7,18 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"strings"
"sync"
"time"
"git.xeserv.us/xena/route/database"
"github.com/Xe/ln"
"github.com/mtneug/pkg/ulid"
cmap "github.com/streamrail/concurrent-map"
kcp "github.com/xtaci/kcp-go"
"github.com/xtaci/smux"
)
@ -40,8 +43,7 @@ type Server struct {
connlock sync.Mutex
conns map[net.Conn]*Connection
domainlock sync.Mutex
domains map[string][]*Connection
domains cmap.ConcurrentMap
}
type Connection struct {
@ -79,7 +81,7 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
cfg: cfg,
conns: map[net.Conn]*Connection{},
domains: map[string][]*Connection{},
domains: cmap.New(),
}
return server, nil
@ -297,9 +299,21 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
s.conns[c] = connection
s.connlock.Unlock()
s.domainlock.Lock()
s.domains[auth.Domain] = append(s.domains[auth.Domain], connection)
s.domainlock.Unlock()
var conns []*Connection
val, ok := s.domains.Get(auth.Domain)
if ok {
conns, ok = val.([]*Connection)
if !ok {
conns = nil
s.domains.Remove(auth.Domain)
}
}
conns = append(conns, connection)
s.domains.Set(auth.Domain, conns)
select {
case <-ctx.Done():
@ -307,16 +321,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
delete(s.conns, c)
s.connlock.Unlock()
s.domainlock.Lock()
var conns []*Connection
for i, cntn := range s.domains[auth.Domain] {
if cntn.id == connection.id {
s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1]
s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1]
val, ok := s.domains.Get(auth.Domain)
if ok {
conns, ok = val.([]*Connection)
if !ok {
ln.Error(err, connection.F(), ln.F{
"action": "looking_up_for_disconnect_removal",
})
return
}
}
s.domainlock.Unlock()
for i, cntn := range conns {
if cntn.id == connection.id {
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
}
}
if len(conns) != 0 {
s.domains.Set(auth.Domain, conns)
} else {
s.domains.Remove(auth.Domain)
}
ln.Log(connection.F(), ln.F{
"action": "client_disconnecting",
@ -331,11 +360,29 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
}
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
s.domainlock.Lock()
conns, ok := s.domains[req.Host]
s.domainlock.Unlock()
if !ok {
return nil, errors.New("domain not found")
var conns []*Connection
val, ok := s.domains.Get(req.Host)
if ok {
conns, ok = val.([]*Connection)
if !ok {
ln.Error(errors.New("no backend connected"), ln.F{
"action": "no_backend_connected",
"remote": req.RemoteAddr,
"host": req.Host,
"uri": req.RequestURI,
})
resp := &http.Response{
StatusCode: http.StatusBadGateway,
Body: ioutil.NopCloser(strings.NewReader("no such domain")),
ContentLength: 14,
Close: true,
Request: req,
}
return resp, nil
}
}
c := conns[rand.Intn(len(conns))]