route/lib/tunnel/tunneltest/tunneltest.go

562 lines
13 KiB
Go

package tunneltest
import (
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"sort"
"strconv"
"sync"
"testing"
"time"
"git.xeserv.us/xena/route/lib/tunnel"
)
// debugNet turns on dumping of network traffic to stderr (see dbgListener)
// when the tests are run with DEBUGNET=1. It wraps the tunnel server's
// listener and is the default for TunnelTest.DebugNet.
var debugNet = os.Getenv("DEBUGNET") == "1"
// dbgListener wraps a net.Listener so that every accepted connection is
// wrapped in a dbgConn, which echoes its traffic to stderr.
type dbgListener struct {
	net.Listener
}

// Accept waits for the next connection and wraps it in a dbgConn.
func (l dbgListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	return dbgConn{conn}, nil
}

// dbgConn wraps a net.Conn and copies every byte read from or written to it
// onto stderr, for debugging network communication.
type dbgConn struct {
	net.Conn
}

// Read reads from the underlying connection and echoes the bytes actually
// received to stderr.
func (c dbgConn) Read(p []byte) (int, error) {
	n, err := c.Conn.Read(p)
	// Only p[:n] was filled by this read; the rest of p is stale buffer
	// content and must not be dumped.
	if n > 0 {
		os.Stderr.Write(p[:n]) // best-effort debug output; error deliberately ignored
	}
	return n, err
}

// Write writes to the underlying connection and echoes the bytes actually
// sent to stderr (only p[:n] on a short write).
func (c dbgConn) Write(p []byte) (int, error) {
	n, err := c.Conn.Write(p)
	if n > 0 {
		os.Stderr.Write(p[:n]) // best-effort debug output; error deliberately ignored
	}
	return n, err
}
// logf prints a "[tunneltest]"-prefixed log line, but only when the tests
// run in verbose mode (testing.Verbose).
func logf(format string, args ...interface{}) {
	if !testing.Verbose() {
		return
	}
	log.Printf("[tunneltest] "+format, args...)
}
// nonil returns the first non-nil error among its arguments, or nil when
// there are no arguments or all of them are nil.
func nonil(err ...error) error {
	for i := 0; i < len(err); i++ {
		if err[i] != nil {
			return err[i]
		}
	}
	return nil
}
// parseHostPort splits addr ("host:port") into its host and numeric port.
// The port must be a decimal number that fits in 16 bits. On failure the
// host is empty and the port is zero.
func parseHostPort(addr string) (string, int, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err == nil {
		var port uint64
		if port, err = strconv.ParseUint(portStr, 10, 16); err == nil {
			return host, int(port), nil
		}
	}
	return "", 0, err
}
// UsableAddrs returns all tcp addresses that we can bind a listener to.
//
// It returns the IP of every interface address (as reported by
// net.InterfaceAddrs) that is not a link-local unicast address, each as a
// TCPAddr with port 0. An error is returned when no such address exists.
func UsableAddrs() ([]*net.TCPAddr, error) {
	ifaceAddrs, err := net.InterfaceAddrs()
	if err != nil {
		return nil, err
	}
	var usable []*net.TCPAddr
	for _, a := range ifaceAddrs {
		ipNet, ok := a.(*net.IPNet)
		if !ok || ipNet.IP.IsLinkLocalUnicast() {
			continue
		}
		usable = append(usable, &net.TCPAddr{IP: ipNet.IP})
	}
	if len(usable) == 0 {
		return nil, errors.New("no usable addresses found")
	}
	return usable, nil
}
// Tunnel types, used as the value of Tunnel.Type.
const (
	// TypeHTTP is an HTTP tunnel; being the zero value it is the default.
	TypeHTTP = iota
	// TypeTCP is a raw TCP tunnel.
	TypeTCP
)
// Tunnel represents a single HTTP or TCP tunnel that can be served
// by TunnelTest.
type Tunnel struct {
	// Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP.
	Type int

	// Handler is a handler to use for serving tunneled connections on
	// the local server. The value of this field is required to be of type:
	//
	//   - http.Handler, http.HandlerFunc or func(http.ResponseWriter, *http.Request)
	//     for HTTP tunnels
	//   - func(net.Conn) for TCP tunnels
	//
	// Required field.
	Handler interface{}

	// LocalAddr is a network address of the local server that handles
	// connections/requests with Handler.
	//
	// Optional field, documented to take value of "127.0.0.1:0" when empty.
	// NOTE(review): confirm serveSingle substitutes this default rather than
	// passing "" straight to net.Listen.
	LocalAddr string

	// ClientIdent is an identifier of a client that has already
	// registered an HTTP tunnel and has established a control connection.
	//
	// If the Type is TypeTCP, instead of creating a new client
	// for this TCP tunnel, we add it to the existing client
	// specified by this field.
	//
	// Optional field for TCP tunnels.
	// Ignored field for HTTP tunnels.
	ClientIdent string

	// RemoteAddr is a network address of the remote server, which accepts
	// connections on the tunnel server side.
	//
	// Required field for TCP tunnels.
	// Ignored field for HTTP tunnels.
	RemoteAddr string

	// RemoteAddrIdent is an identifier of an already existing TCP tunnel
	// whose remote listener (e.g. one that listens on multiple interfaces)
	// is reused by this tunnel; if RemoteAddrIdent is a valid identifier,
	// the IP field is required to be non-nil and RemoteAddr is ignored.
	//
	// Optional field for TCP tunnels.
	// Ignored field for HTTP tunnels.
	RemoteAddrIdent string

	// IP specifies an IP address value for IP-based routing for TCP tunnels.
	// For more details see inline documentation for (*tunnel.Server).AddAddr.
	//
	// Optional field for TCP tunnels.
	// Ignored field for HTTP tunnels.
	IP net.IP

	// StateChanges receives client state transitions.
	//
	// If a new client is created for this tunnel (always for HTTP tunnels,
	// and for TCP tunnels with an empty ClientIdent), StateChanges will
	// receive state transition events for that client.
	// Otherwise setting this field is a nop.
	StateChanges chan<- *tunnel.ClientStateChange
}
// TunnelTest is a test harness that runs a tunnel.Server together with the
// local servers, remote listeners and tunnel.Clients of the tunnels it serves.
type TunnelTest struct {
	// Server is the tunnel server under test; NewTunnelTest serves it over HTTP.
	Server *tunnel.Server
	// ServerStateRecorder records the state changes reported by Server.
	ServerStateRecorder *StateRecorder
	// Clients maps a tunnel ident to its client. TCP tunnels that set
	// Tunnel.ClientIdent share the client registered under that ident.
	Clients map[string]*tunnel.Client
	// Listeners maps a tunnel ident to its listeners. The "" key holds the
	// tunnel server's own listener.
	Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels)
	// Addrs are the interface addresses found by UsableAddrs.
	Addrs []*net.TCPAddr
	// Tunnels holds every successfully started tunnel, keyed by ident.
	Tunnels map[string]*Tunnel
	// DebugNet wraps local tunnel listeners in dbgListener when set.
	DebugNet bool // for debugging network communication
	// NOTE(review): ServerAddr, Addr, Request and Close read Listeners
	// without holding mu.
	mu sync.Mutex // protects Listeners
}
// NewTunnelTest creates a tunnel.Server, serves it over HTTP on a random
// 127.0.0.1 port, and returns a TunnelTest with no tunnels registered yet.
// Server state transitions are recorded by the returned ServerStateRecorder.
func NewTunnelTest() (*TunnelTest, error) {
	rec := NewStateRecorder()
	cfg := &tunnel.ServerConfig{
		StateChanges: rec.C(),
		Debug:        testing.Verbose(),
	}
	s, err := tunnel.NewServer(cfg)
	if err != nil {
		return nil, err
	}
	// Look up usable addresses before binding, so that a failure here cannot
	// leak the listener below.
	addrs, err := UsableAddrs()
	if err != nil {
		return nil, err
	}
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, err
	}
	if debugNet {
		l = dbgListener{l}
	}
	go (&http.Server{Handler: s}).Serve(l)
	return &TunnelTest{
		Server:              s,
		ServerStateRecorder: rec,
		Clients:             make(map[string]*tunnel.Client),
		Listeners:           map[string][2]net.Listener{"": {l, nil}},
		Addrs:               addrs,
		Tunnels:             make(map[string]*Tunnel),
		DebugNet:            debugNet,
	}, nil
}
// Serve creates a new TunnelTest and serves the given tunnels on it.
//
// An empty or nil tunnels map is rejected with an error. Started tunnels are
// removed from the tunnels map. If serving fails, the partially started
// TunnelTest is closed before the error is returned.
func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) {
	tt, err := NewTunnelTest()
	if err != nil {
		return nil, err
	}
	if err := tt.Serve(tunnels); err != nil {
		// Best-effort cleanup of whatever was started; the serve error is
		// the one worth reporting.
		tt.Close()
		return nil, err
	}
	return tt, nil
}
// serveSingle starts the tunnel t under ident, if its dependencies are met.
//
// It returns (false, nil) when t is a TCP tunnel whose dependency is not
// served yet (the client named by t.ClientIdent, or the TCP tunnel named by
// t.RemoteAddrIdent), so the caller can retry it on a later pass. It returns
// (true, nil) once the tunnel is registered. It returns an error for an
// invalid configuration or when listening or creating the client fails.
//
// Registering a tunnel means:
//
//   - listen on the local address and start the local server (tt.Listeners[ident][0])
//   - for TCP tunnels, listen on or reuse the remote address (tt.Listeners[ident][1])
//   - register the tunnel on tt.Server
//   - start a tunnel.Client (tt.Clients[ident]) or reuse tt.Clients[t.ClientIdent]
func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) {
	// Verify tunnel dependencies for TCP tunnels.
	if t.Type == TypeTCP {
		// If the client named by t.ClientIdent was not already started,
		// skip and move on.
		if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok {
			return false, nil
		}
		// Verify the TCP tunnel whose remote listener is to be shared
		// (t.RemoteAddrIdent) is already served.
		if t.RemoteAddrIdent != "" {
			if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok {
				return false, nil
			}
			// Guard against a missing entry, which would otherwise be
			// dereferenced; treat it like a non-TCP dependency.
			if dep := tt.Tunnels[t.RemoteAddrIdent]; dep == nil || dep.Type != TypeTCP {
				return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent)
			}
		}
	}

	// Resolve the handler, and with it the tunnel type, before binding any
	// socket so an invalid configuration cannot leak a listener.
	var (
		httpHandler http.Handler
		tcpHandler  func(net.Conn)
	)
	switch t.Type {
	case TypeHTTP:
		h, ok := resolveHTTPHandler(t.Handler)
		if !ok {
			return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
		}
		httpHandler = h
	case TypeTCP:
		h, ok := t.Handler.(func(net.Conn))
		if !ok {
			return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
		}
		tcpHandler = h
	default:
		return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type)
	}

	// Apply the documented default for Tunnel.LocalAddr. Passing "" to
	// net.Listen would bind every interface instead of loopback.
	listenAddr := t.LocalAddr
	if listenAddr == "" {
		listenAddr = "127.0.0.1:0"
	}
	l, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", listenAddr, ident, err)
	}
	if tt.DebugNet {
		l = dbgListener{l}
	}
	localAddr := l.Addr().String()
	httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr}
	tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr}
	cfg := &tunnel.ClientConfig{
		Identifier: ident,
		ServerAddr: tt.ServerAddr().String(),
		Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
			HTTP: httpProxy.Proxy,
			TCP:  tcpProxy.Proxy,
		}),
		StateChanges: t.StateChanges,
		Debug:        testing.Verbose(),
	}

	switch t.Type {
	case TypeHTTP:
		logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident)
		go (&http.Server{Handler: httpHandler}).Serve(l)
		tt.Server.AddHost(localAddr, ident)
		tt.mu.Lock()
		tt.Listeners[ident] = [2]net.Listener{l, nil}
		tt.mu.Unlock()
		if err := tt.addClient(ident, cfg); err != nil {
			return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
		}
		logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident)
	case TypeTCP:
		var remote net.Listener
		if t.RemoteAddrIdent != "" {
			tt.mu.Lock()
			remote = tt.Listeners[t.RemoteAddrIdent][1]
			tt.mu.Unlock()
		} else {
			remote, err = net.Listen("tcp", t.RemoteAddr)
			if err != nil {
				// l is not tracked in tt.Listeners yet, so Close would never
				// release it; do it here.
				l.Close()
				return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err)
			}
		}
		logf("serving on local %s for TCP tunnel %q", l.Addr(), ident)
		go func() {
			for {
				conn, err := l.Accept()
				if err != nil {
					log.Printf("failed accepting conn for %q tunnel: %s", ident, err)
					return
				}
				go tcpHandler(conn)
			}
		}()
		// addrIdent holds the identifier of the client that gets the tunnel
		// registered via (*tunnel.Server).AddAddr.
		addrIdent := ident
		if t.ClientIdent != "" {
			tt.Clients[ident] = tt.Clients[t.ClientIdent]
			addrIdent = t.ClientIdent
		}
		tt.Server.AddAddr(remote, t.IP, addrIdent)
		tt.mu.Lock()
		tt.Listeners[ident] = [2]net.Listener{l, remote}
		tt.mu.Unlock()
		if _, ok := tt.Clients[ident]; !ok {
			if err := tt.addClient(ident, cfg); err != nil {
				return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
			}
		}
		logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent)
	}
	return true, nil
}

// resolveHTTPHandler converts a Tunnel.Handler value into an http.Handler.
// It accepts any http.Handler (including http.HandlerFunc) or a plain
// func(http.ResponseWriter, *http.Request). It reports false otherwise.
func resolveHTTPHandler(v interface{}) (http.Handler, bool) {
	switch h := v.(type) {
	case http.Handler:
		return h, true
	case func(http.ResponseWriter, *http.Request):
		return http.HandlerFunc(h), true
	}
	return nil, false
}
// addClient creates a tunnel.Client for ident from cfg and starts it. It
// then waits up to 10s for tt.Server to report, via its OnConnect callback,
// that the client's control connection was established. On success the
// client is recorded in tt.Clients[ident]. It fails if ident already has a
// client.
func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error {
	if _, ok := tt.Clients[ident]; ok {
		return fmt.Errorf("tunnel %q is already being served", ident)
	}
	c, err := tunnel.NewClient(cfg)
	if err != nil {
		return err
	}
	done := make(chan struct{})
	var once sync.Once
	tt.Server.OnConnect(ident, func() error {
		// A second close(done) would panic. Guard against the server
		// invoking this callback more than once (e.g. on a reconnect).
		// Whether it does is not visible from here.
		once.Do(func() { close(done) })
		return nil
	})
	go c.Start()
	<-c.StartNotify()
	select {
	case <-time.After(10 * time.Second):
		// c never makes it into tt.Clients, so tt.Close would not stop it.
		// Shut it down here (best effort) instead of leaking it.
		c.Close()
		return errors.New("timed out after 10s waiting on client to establish control conn")
	case <-done:
	}
	tt.Clients[ident] = c
	return nil
}
// Serve starts all of the given tunnels on tt.
//
// Tunnels may depend on each other. A TCP tunnel can reuse the client of an
// already started tunnel (Tunnel.ClientIdent) or the remote listener of an
// already started TCP tunnel (Tunnel.RemoteAddrIdent). Serve therefore makes
// repeated passes over tunnels, starting every tunnel whose dependencies are
// met. It stops when all tunnels are started or when a pass makes no
// progress. Started tunnels are removed from the tunnels map (the caller's
// map is modified) and recorded in tt.Tunnels.
//
// It returns an error in three cases: tunnels is empty, starting any tunnel
// fails, or some tunnels can never be started because of unresolved
// dependencies.
func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error {
	if len(tunnels) == 0 {
		return errors.New("no tunnels to serve")
	}
	// Typically the first pass starts HTTP tunnels, the next ones the TCP
	// tunnels that rely on those clients, and then the TCP tunnels that
	// rely on those listeners. Looping until no progress also handles longer
	// RemoteAddrIdent chains.
	for len(tunnels) != 0 {
		pending := len(tunnels)
		if err := tt.popServedDeps(tunnels); err != nil {
			return err
		}
		if len(tunnels) == pending {
			break // nothing could be started; the rest have unresolved deps
		}
	}
	if len(tunnels) != 0 {
		unresolved := make([]string, 0, len(tunnels))
		for ident := range tunnels {
			unresolved = append(unresolved, ident)
		}
		sort.Strings(unresolved)
		return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved)
	}
	return nil
}
// popServedDeps makes one pass over tunnels, trying to start each of them.
// Every tunnel that serveSingle reports as started is moved from tunnels
// into tt.Tunnels, so it is not started again. Tunnels whose dependencies
// are not yet met stay in the map for a later pass. The first error from
// serveSingle aborts the pass.
func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error {
	for name, tun := range tunnels {
		started, err := tt.serveSingle(name, tun)
		if err != nil {
			return err
		}
		if !started {
			continue
		}
		tt.Tunnels[name] = tun
		delete(tunnels, name) // deleting the current key during range is safe
	}
	return nil
}
// fetchLocalAddr maps the port of a TCP tunnel's remote listener to the
// address of that tunnel's local listener. It is used as the
// FetchLocalAddr hook of the tunnel.TCPProxy built in serveSingle.
//
// NOTE(review): TCP tunnels that share a remote listener via
// RemoteAddrIdent all have the same remote port, so which local address is
// returned for that port depends on map iteration order.
func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) {
	tt.mu.Lock()
	defer tt.mu.Unlock()
	for _, pair := range tt.Listeners {
		local, remote := pair[0], pair[1]
		if remote == nil {
			// this listener does not belong to a TCP tunnel
			continue
		}
		_, remotePort, err := parseHostPort(remote.Addr().String())
		switch {
		case err != nil:
			return "", err
		case remotePort == port:
			return local.Addr().String(), nil
		}
	}
	return "", fmt.Errorf("no route for %d port", port)
}
// ServerAddr returns the address on which the tunnel server's HTTP endpoint
// listens (the listener stored under the "" key of tt.Listeners).
func (tt *TunnelTest) ServerAddr() net.Addr {
	server := tt.Listeners[""]
	return server[0].Addr()
}
// Addr gives the server endpoint (remote listener address) of the TCP tunnel
// for the given ident.
//
// If the tunnel does not exist or is an HTTP one, Addr returns nil.
func (tt *TunnelTest) Addr(ident string) net.Addr {
	l, ok := tt.Listeners[ident]
	if !ok || l[1] == nil {
		// HTTP tunnels (and the server's own listener under "") have no
		// remote listener; calling Addr on the nil interface would panic.
		return nil
	}
	return l[1].Addr()
}
// Request builds an HTTP GET request for the tunnel registered under ident.
//
// The request is addressed to the tunnel server (tt.ServerAddr) at path "/",
// with query encoded as the raw query when non-nil. Its Host header is set
// to the tunnel's local listener address, which the server uses to route it.
//
// If the tunnel does not exist, Request returns nil.
func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request {
	pair, found := tt.Listeners[ident]
	if !found {
		return nil
	}
	target := &url.URL{
		Scheme: "http",
		Host:   tt.ServerAddr().String(),
		Path:   "/",
	}
	if query != nil {
		target.RawQuery = query.Encode()
	}
	req := &http.Request{
		Method:     "GET",
		URL:        target,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
	}
	req.Host = pair[0].Addr().String()
	return req
}
// Close shuts down everything tt started. It first closes all tunnel
// clients, then all listeners (local, remote, and the tunnel server's own).
// Clients and listeners shared by several tunnels are closed only once.
// Every close is attempted, and the first error encountered (in map
// iteration order) is returned.
func (tt *TunnelTest) Close() (err error) {
	closedClients := make(map[*tunnel.Client]bool)
	for _, client := range tt.Clients {
		if closedClients[client] {
			continue
		}
		closedClients[client] = true
		err = nonil(err, client.Close())
	}
	closedListeners := make(map[net.Listener]bool)
	for _, pair := range tt.Listeners {
		for _, ln := range pair {
			if ln == nil || closedListeners[ln] {
				continue
			}
			closedListeners[ln] = true
			err = nonil(err, ln.Close())
		}
	}
	return err
}