180 lines
3.9 KiB
Go
180 lines
3.9 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/koding/logging"
|
|
)
|
|
|
|
type listener struct {
|
|
net.Listener
|
|
*vaddrOptions
|
|
|
|
done int32
|
|
|
|
// ips keeps track of registered clients for ip-based routing;
|
|
// when last client is deleted from the ip routing map, we stop
|
|
// listening on connections
|
|
ips map[string]struct{}
|
|
}
|
|
|
|
type vaddrOptions struct {
|
|
connCh chan<- net.Conn
|
|
log logging.Logger
|
|
}
|
|
|
|
type vaddrStorage struct {
|
|
*vaddrOptions
|
|
|
|
listeners map[net.Listener]*listener
|
|
ports map[int]string // port-based routing: maps port number to identifier
|
|
ips map[string]string // ip-based routing: maps ip address to identifier
|
|
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage {
|
|
return &vaddrStorage{
|
|
vaddrOptions: opts,
|
|
listeners: make(map[net.Listener]*listener),
|
|
ports: make(map[int]string),
|
|
ips: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
func (l *listener) serve() {
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
l.log.Error("failue listening on %q: %s", l.Addr(), err)
|
|
return
|
|
}
|
|
|
|
if atomic.LoadInt32(&l.done) != 0 {
|
|
l.log.Debug("stopped serving %q", l.Addr())
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
l.connCh <- conn
|
|
}
|
|
}
|
|
|
|
func (l *listener) localAddr() string {
|
|
if addr, ok := l.Addr().(*net.TCPAddr); ok {
|
|
if addr.IP.Equal(net.IPv4zero) {
|
|
return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port))
|
|
}
|
|
}
|
|
return l.Addr().String()
|
|
}
|
|
|
|
func (l *listener) stop() {
|
|
if atomic.CompareAndSwapInt32(&l.done, 0, 1) {
|
|
// stop is called when no more connections should be accepted by
|
|
// the user-provided listener; as we can't simple close the listener
|
|
// to not break the guarantee given by the (*Server).DeleteAddr
|
|
// method, we make a dummy connection to break out of serve loop.
|
|
// It is safe to make a dummy connection, as either the following
|
|
// dial will time out when the listener is busy accepting connections,
|
|
// or will get closed immadiately after idle listeners accepts connection
|
|
// and returns from the serve loop.
|
|
conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout)
|
|
if err == nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) {
|
|
vaddr.mu.Lock()
|
|
defer vaddr.mu.Unlock()
|
|
|
|
lis, ok := vaddr.listeners[l]
|
|
if !ok {
|
|
lis = vaddr.newListener(l)
|
|
vaddr.listeners[l] = lis
|
|
go lis.serve()
|
|
}
|
|
|
|
if ip != nil {
|
|
lis.ips[ip.String()] = struct{}{}
|
|
vaddr.ips[ip.String()] = ident
|
|
} else {
|
|
vaddr.ports[mustPort(l)] = ident
|
|
}
|
|
}
|
|
|
|
func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) {
|
|
vaddr.mu.Lock()
|
|
defer vaddr.mu.Unlock()
|
|
|
|
lis, ok := vaddr.listeners[l]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var stop bool
|
|
|
|
if ip != nil {
|
|
delete(lis.ips, ip.String())
|
|
delete(vaddr.ips, ip.String())
|
|
|
|
stop = len(lis.ips) == 0
|
|
} else {
|
|
delete(vaddr.ports, mustPort(l))
|
|
|
|
stop = true
|
|
}
|
|
|
|
// Only stop listening for connections when listener has clients
|
|
// registered to tunnel the connections to.
|
|
if stop {
|
|
lis.stop()
|
|
delete(vaddr.listeners, l)
|
|
}
|
|
}
|
|
|
|
func (vaddr *vaddrStorage) newListener(l net.Listener) *listener {
|
|
return &listener{
|
|
Listener: l,
|
|
vaddrOptions: vaddr.vaddrOptions,
|
|
ips: make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) {
|
|
vaddr.mu.Lock()
|
|
defer vaddr.mu.Unlock()
|
|
|
|
ip, port, err := parseHostPort(conn.LocalAddr().String())
|
|
if err != nil {
|
|
vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err)
|
|
return "", false
|
|
}
|
|
|
|
// First lookup if there's a ip-based route, then try port-base one.
|
|
|
|
if ident, ok := vaddr.ips[ip]; ok {
|
|
return ident, true
|
|
}
|
|
|
|
ident, ok := vaddr.ports[port]
|
|
return ident, ok
|
|
}
|
|
|
|
func mustPort(l net.Listener) int {
|
|
_, port, err := parseHostPort(l.Addr().String())
|
|
if err != nil {
|
|
// This can happened when user passed custom type that
|
|
// implements net.Listener, which returns ill-formed
|
|
// net.Addr value.
|
|
panic("ill-formed net.Addr: " + err.Error())
|
|
}
|
|
|
|
return port
|
|
}
|