tun2: use a better hashmap
This commit is contained in:
parent
ecc31b1eb7
commit
a02553e0ef
|
@ -7,15 +7,18 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/database"
|
"git.xeserv.us/xena/route/database"
|
||||||
"github.com/Xe/ln"
|
"github.com/Xe/ln"
|
||||||
"github.com/mtneug/pkg/ulid"
|
"github.com/mtneug/pkg/ulid"
|
||||||
|
cmap "github.com/streamrail/concurrent-map"
|
||||||
kcp "github.com/xtaci/kcp-go"
|
kcp "github.com/xtaci/kcp-go"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
)
|
)
|
||||||
|
@ -40,8 +43,7 @@ type Server struct {
|
||||||
connlock sync.Mutex
|
connlock sync.Mutex
|
||||||
conns map[net.Conn]*Connection
|
conns map[net.Conn]*Connection
|
||||||
|
|
||||||
domainlock sync.Mutex
|
domains cmap.ConcurrentMap
|
||||||
domains map[string][]*Connection
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
|
@ -79,7 +81,7 @@ func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
|
||||||
conns: map[net.Conn]*Connection{},
|
conns: map[net.Conn]*Connection{},
|
||||||
domains: map[string][]*Connection{},
|
domains: cmap.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return server, nil
|
return server, nil
|
||||||
|
@ -297,9 +299,21 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
s.conns[c] = connection
|
s.conns[c] = connection
|
||||||
s.connlock.Unlock()
|
s.connlock.Unlock()
|
||||||
|
|
||||||
s.domainlock.Lock()
|
var conns []*Connection
|
||||||
s.domains[auth.Domain] = append(s.domains[auth.Domain], connection)
|
|
||||||
s.domainlock.Unlock()
|
val, ok := s.domains.Get(auth.Domain)
|
||||||
|
if ok {
|
||||||
|
conns, ok = val.([]*Connection)
|
||||||
|
if !ok {
|
||||||
|
conns = nil
|
||||||
|
|
||||||
|
s.domains.Remove(auth.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conns = append(conns, connection)
|
||||||
|
|
||||||
|
s.domains.Set(auth.Domain, conns)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -307,16 +321,31 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
delete(s.conns, c)
|
delete(s.conns, c)
|
||||||
s.connlock.Unlock()
|
s.connlock.Unlock()
|
||||||
|
|
||||||
s.domainlock.Lock()
|
var conns []*Connection
|
||||||
|
|
||||||
for i, cntn := range s.domains[auth.Domain] {
|
val, ok := s.domains.Get(auth.Domain)
|
||||||
if cntn.id == connection.id {
|
if ok {
|
||||||
s.domains[auth.Domain][i] = s.domains[auth.Domain][len(s.domains[auth.Domain])-1]
|
conns, ok = val.([]*Connection)
|
||||||
s.domains[auth.Domain] = s.domains[auth.Domain][:len(s.domains[auth.Domain])-1]
|
if !ok {
|
||||||
|
ln.Error(err, connection.F(), ln.F{
|
||||||
|
"action": "looking_up_for_disconnect_removal",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.domainlock.Unlock()
|
for i, cntn := range conns {
|
||||||
|
if cntn.id == connection.id {
|
||||||
|
conns[i] = conns[len(conns)-1]
|
||||||
|
conns = conns[:len(conns)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conns) != 0 {
|
||||||
|
s.domains.Set(auth.Domain, conns)
|
||||||
|
} else {
|
||||||
|
s.domains.Remove(auth.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
ln.Log(connection.F(), ln.F{
|
ln.Log(connection.F(), ln.F{
|
||||||
"action": "client_disconnecting",
|
"action": "client_disconnecting",
|
||||||
|
@ -331,11 +360,29 @@ func (s *Server) HandleConn(c net.Conn, isKCP bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
s.domainlock.Lock()
|
var conns []*Connection
|
||||||
conns, ok := s.domains[req.Host]
|
|
||||||
s.domainlock.Unlock()
|
val, ok := s.domains.Get(req.Host)
|
||||||
if !ok {
|
if ok {
|
||||||
return nil, errors.New("domain not found")
|
conns, ok = val.([]*Connection)
|
||||||
|
if !ok {
|
||||||
|
ln.Error(errors.New("no backend connected"), ln.F{
|
||||||
|
"action": "no_backend_connected",
|
||||||
|
"remote": req.RemoteAddr,
|
||||||
|
"host": req.Host,
|
||||||
|
"uri": req.RequestURI,
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader("no such domain")),
|
||||||
|
ContentLength: 14,
|
||||||
|
Close: true,
|
||||||
|
Request: req,
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c := conns[rand.Intn(len(conns))]
|
c := conns[rand.Intn(len(conns))]
|
||||||
|
|
Loading…
Reference in New Issue