import github.com/koding/tunnel

This commit is contained in:
Cadey Ratio 2017-01-19 17:27:14 -08:00
parent 551017e893
commit 86be40fea0
21 changed files with 3823 additions and 0 deletions

19
lib/tunnel/.travis.yml Normal file
View File

@ -0,0 +1,19 @@
language: go
sudo: false
addons:
apt:
packages:
- moreutils
go:
- 1.4.3
- 1.6.3
- 1.7
script:
- export GOMAXPROCS=$(nproc)
- gofmt -s -l . | ifne false
- go build ./...
- go test -race ./...

28
lib/tunnel/LICENSE Normal file
View File

@ -0,0 +1,28 @@
Copyright (c) 2015 The Koding Authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Koding Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

91
lib/tunnel/README.md Normal file
View File

@ -0,0 +1,91 @@
# Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel)
Tunnel is a server/client package that enables to proxy public connections to
your local machine over a tunnel connection from the local machine to the
public server. What this means is, you can share your localhost even if it
doesn't have a Public IP or if it's not reachable from outside.
It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to
multiplex connections between server and client.
The project is under active development, please vendor it if you want to use it.
# Usage
The tunnel package consists of two parts. The `server` and the `client`.
Server is the public facing part. It's type that satisfies the `http.Handler`.
So it's easily pluggable into existing servers.
Let assume that you setup your DNS service so all `*.example.com` domains route
to your server at the public IP `203.0.113.0`. Let us first create the server
part:
```go
package main
import (
"net/http"
"github.com/koding/tunnel"
)
func main() {
cfg := &tunnel.ServerConfig{}
server, _ := tunnel.NewServer(cfg)
server.AddHost("sub.example.com", "1234")
http.ListenAndServe(":80", server)
}
```
Once you create the `server`, you just plug it into your server. The only
detail here is to map a virtualhost to a secret token. The secret token is the
only part that needs to be known for the client side.
Let us now create the client side part:
```go
package main
import "github.com/koding/tunnel"
func main() {
cfg := &tunnel.ClientConfig{
Identifier: "1234",
ServerAddr: "203.0.113.0:80",
}
client, err := tunnel.NewClient(cfg)
if err != nil {
panic(err)
}
client.Start()
}
```
The `Start()` method is by default blocking. As you see you, we just passed the
server address and the secret token.
Now whenever someone hit `sub.example.com`, the request will be proxied to the
machine where client is running and hit the local server running `127.0.0.1:80`
(assuming there is one). If someone hits `sub.example.com:3000` (assume your
server is running at this port), it'll be routed to `127.0.0.1:3000`
That's it.
There are many options that can be changed, such as a static local address for
your client. Have alook at the
[documentation](http://godoc.org/github.com/koding/tunnel)
# Protocol
The server/client protocol is written in the [spec.md](spec.md) file. Please
have a look for more detail.
## License
The BSD 3-Clause License - see LICENSE for more details

565
lib/tunnel/client.go Normal file
View File

@ -0,0 +1,565 @@
package tunnel
import (
"bufio"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
"github.com/hashicorp/yamux"
)
//go:generate stringer -type ClientState
// ErrRedialAborted is emitted on ClientClosed event, when backoff policy
// used by a client decided no more reconnection attempts must be made.
var ErrRedialAborted = errors.New("unable to restore the connection, aborting")
// ClientState represents client connection state to tunnel server.
type ClientState uint32
// ClientState enumeration.
const (
ClientUnknown ClientState = iota
ClientStarted
ClientConnecting
ClientConnected
ClientDisconnected
ClientClosed // keep it always last
)
// ClientStateChange represents single client state transition.
type ClientStateChange struct {
Identifier string
Previous ClientState
Current ClientState
Error error
}
// Strings implements the fmt.Stringer interface.
func (cs *ClientStateChange) String() string {
if cs.Error != nil {
return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error)
}
return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current)
}
// Backoff defines behavior of staggering reconnection retries.
type Backoff interface {
// Next returns the duration to sleep before retrying reconnections.
// If the returned value is negative, the retry is aborted.
NextBackOff() time.Duration
// Reset is used to signal a reconnection was successful and next
// call to Next should return desired time duration for 1st reconnection
// attempt.
Reset()
}
// Client is responsible for creating a control connection to a tunnel server,
// creating new tunnels and proxy them to tunnel server.
type Client struct {
// underlying yamux session
session *yamux.Session
// config holds the ClientConfig
config *ClientConfig
// yamuxConfig is passed to new yamux.Session's
yamuxConfig *yamux.Config
// proxy handles local server communication.
proxy ProxyFunc
// startNotify is a chanel user can get to be notified when client is
// connected to the server. The preferred way of doing this however,
// would be using StateChanges in ClientConfig where user can provide
// his own channel.
startNotify chan bool
// closed is a flag set when client calls Close() and quits.
closed bool
// closedMu guards both closed flag and startNotify channel. Since library
// owns the channel it's cleared when trying to reconnect.
closedMu sync.RWMutex
reqWg sync.WaitGroup
ctrlWg sync.WaitGroup
state ClientState
// redialBackoff is used to reconnect in exponential backoff intervals
redialBackoff Backoff
log logging.Logger
}
// ClientConfig defines the configuration for the Client
type ClientConfig struct {
// Identifier is the secret token that needs to be passed to the server.
// Required if FetchIdentifier is not set.
Identifier string
// FetchIdentifier can be used to fetch identifier. Required if Identifier
// is not set.
FetchIdentifier func() (string, error)
// ServerAddr defines the TCP address of the tunnel server to be connected.
// Required if FetchServerAddr is not set.
ServerAddr string
// FetchServerAddr can be used to fetch tunnel server address.
// Required if ServerAddress is not set.
FetchServerAddr func() (string, error)
// Dial provides custom transport layer for client server communication.
//
// If nil, default implementation is to return net.Dial("tcp", address).
//
// It can be used for connection monitoring, setting different timeouts or
// securing the connection.
Dial func(network, address string) (net.Conn, error)
// Proxy defines custom proxing logic. This is optional extension point
// where you can provide your local server selection or communication rules.
Proxy ProxyFunc
// StateChanges receives state transition details each time client
// connection state changes. The channel is expected to be sufficiently
// buffered to keep up with event pace.
//
// If nil, no information about state transitions are dispatched
// by the library.
StateChanges chan<- *ClientStateChange
// Backoff is used to control behavior of staggering reconnection loop.
//
// If nil, default backoff policy is used which makes a client to never
// give up on reconnection.
//
// If custom backoff is used, client will emit ErrRedialAborted set
// with ClientClosed event when no more reconnection atttemps should
// be made.
Backoff Backoff
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
// yamux.DefaultConfig() is used.
YamuxConfig *yamux.Config
// Log defines the logger. If nil a default logging.Logger is used.
Log logging.Logger
// Debug enables debug mode, enable only if you want to debug the server.
Debug bool
// DEPRECATED:
// LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details.
LocalAddr string
// FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details.
FetchLocalAddr func(port int) (string, error)
}
// verify is used to verify the ClientConfig
func (c *ClientConfig) verify() error {
if c.ServerAddr == "" && c.FetchServerAddr == nil {
return errors.New("neither ServerAddr nor FetchServerAddr is set")
}
if c.Identifier == "" && c.FetchIdentifier == nil {
return errors.New("neither Identifier nor FetchIdentifier is set")
}
if c.YamuxConfig != nil {
if err := yamux.VerifyConfig(c.YamuxConfig); err != nil {
return err
}
}
if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) {
return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set")
}
return nil
}
// NewClient creates a new tunnel that is established between the serverAddr
// and localAddr. It exits if it can't create a new control connection to the
// server. If localAddr is empty client will always try to proxy to a local
// port.
func NewClient(cfg *ClientConfig) (*Client, error) {
if err := cfg.verify(); err != nil {
return nil, err
}
yamuxConfig := yamux.DefaultConfig()
if cfg.YamuxConfig != nil {
yamuxConfig = cfg.YamuxConfig
}
var proxy = DefaultProxy
if cfg.Proxy != nil {
proxy = cfg.Proxy
}
// DEPRECATED API SUPPORT
if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil {
var f ProxyFuncs
if cfg.LocalAddr != "" {
f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
}
if cfg.FetchLocalAddr != nil {
f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy
}
proxy = Proxy(f)
}
var bo Backoff = newForeverBackoff()
if cfg.Backoff != nil {
bo = cfg.Backoff
}
log := newLogger("tunnel-client", cfg.Debug)
if cfg.Log != nil {
log = cfg.Log
}
client := &Client{
config: cfg,
yamuxConfig: yamuxConfig,
proxy: proxy,
startNotify: make(chan bool, 1),
redialBackoff: bo,
log: log,
}
return client, nil
}
// Start starts the client and connects to the server with the identifier.
// client.FetchIdentifier() will be used if it's not nil. It's supports
// reconnecting with exponential backoff intervals when the connection to the
// server disconnects. Call client.Close() to shutdown the client completely. A
// successful connection will cause StartNotify() to receive a value.
func (c *Client) Start() {
fetchIdent := func() (string, error) {
if c.config.FetchIdentifier != nil {
return c.config.FetchIdentifier()
}
return c.config.Identifier, nil
}
fetchServerAddr := func() (string, error) {
if c.config.FetchServerAddr != nil {
return c.config.FetchServerAddr()
}
return c.config.ServerAddr, nil
}
c.changeState(ClientStarted, nil)
c.redialBackoff.Reset()
var lastErr error
for {
prev := c.changeState(ClientConnecting, lastErr)
if c.isRetry(prev) {
dur := c.redialBackoff.NextBackOff()
if dur < 0 {
c.setClosed(true)
c.changeState(ClientClosed, ErrRedialAborted)
return
}
time.Sleep(dur)
// exit if closed
if c.isClosed() {
c.changeState(ClientClosed, lastErr)
return
}
}
identifier, err := fetchIdent()
if err != nil {
lastErr = err
c.log.Critical("client fetch identifier error: %s", err)
continue
}
serverAddr, err := fetchServerAddr()
if err != nil {
lastErr = err
c.log.Critical("client fetch server address error: %s", err)
continue
}
c.setClosed(false)
if err := c.connect(identifier, serverAddr); err != nil {
lastErr = err
c.log.Debug("client connect error: %s", err)
}
// exit if closed
if c.isClosed() {
c.changeState(ClientClosed, lastErr)
return
}
}
}
// Close closes the client and shutdowns the connection to the tunnel server
func (c *Client) Close() error {
defer c.setClosed(true)
if c.session == nil {
return errors.New("session is not initialized")
}
// wait until all connections are finished
waitCh := make(chan struct{})
go func() {
if err := c.session.GoAway(); err != nil {
c.log.Debug("Session go away failed: %s", err)
}
c.reqWg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
// ok
case <-time.After(time.Second * 10):
c.log.Info("Timeout waiting for connections to finish")
}
if err := c.session.Close(); err != nil {
return err
}
return nil
}
// isClosed securely checks if client is marked as closed.
func (c *Client) isClosed() bool {
c.closedMu.RLock()
defer c.closedMu.RUnlock()
return c.closed
}
// setClosed securely marks client as closed (or not closed). If not closed
// also empty the value inside the startNotify channel by retrieving it (if any),
// so it doesn't block during connect, when the client was closed and started again,
// and startNotify was never listened to.
func (c *Client) setClosed(closed bool) {
c.closedMu.Lock()
defer c.closedMu.Unlock()
c.closed = closed
if !closed {
// clear channel
select {
case <-c.startNotify:
default:
}
}
}
// startNotifyIfNeeded sends ok to startNotify channel if it's listened to.
// This function is called by connect when connection was successful.
func (c *Client) startNotifyIfNeeded() {
c.closedMu.RLock()
if !c.closed {
c.log.Debug("sending ok to startNotify chan")
select {
case c.startNotify <- true:
default:
// reaching here means the client never read the signal via
// StartNotify(). This is OK, we shouldn't except it the consumer
// to read from this channel. It's optional, so we just drop the
// signal.
c.log.Debug("startNotify message was dropped")
}
}
c.closedMu.RUnlock()
}
// StartNotify returns a channel that receives a single value when the client
// established a successful connection to the server.
func (c *Client) StartNotify() <-chan bool {
return c.startNotify
}
func (c *Client) changeState(state ClientState, err error) (prev ClientState) {
prev = ClientState(atomic.LoadUint32((*uint32)(&c.state)))
if c.config.StateChanges != nil {
change := &ClientStateChange{
Identifier: c.config.Identifier,
Previous: ClientState(prev),
Current: state,
Error: err,
}
select {
case c.config.StateChanges <- change:
default:
c.log.Warning("Dropping state change due to slow reader: %s", change)
}
}
atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state))
return prev
}
func (c *Client) isRetry(state ClientState) bool {
return state != ClientStarted && state != ClientClosed
}
func (c *Client) connect(identifier, serverAddr string) error {
c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier)
conn, err := c.dial(serverAddr)
if err != nil {
return err
}
remoteURL := controlURL(conn)
c.log.Debug("CONNECT to %q", remoteURL)
req, err := http.NewRequest("CONNECT", remoteURL, nil)
if err != nil {
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
}
req.Header.Set(proto.ClientIdentifierHeader, identifier)
c.log.Debug("Writing request to TCP: %+v", req)
if err := req.Write(conn); err != nil {
return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err)
}
c.log.Debug("Reading response from TCP")
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected {
out, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err)
}
return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out))
}
c.ctrlWg.Wait() // wait until previous listenControl observes disconnection
c.session, err = yamux.Client(conn, c.yamuxConfig)
if err != nil {
return fmt.Errorf("session initialization failed: %s", err)
}
var stream net.Conn
openStream := func() error {
// this is blocking until client opens a session to us
stream, err = c.session.Open()
return err
}
// if we don't receive anything from the server, we'll timeout
select {
case err := <-async(openStream):
if err != nil {
return fmt.Errorf("waiting for session to open failed: %s", err)
}
case <-time.After(time.Second * 10):
if stream != nil {
stream.Close()
}
return errors.New("timeout opening session")
}
if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil {
return fmt.Errorf("writing handshake request failed: %s", err)
}
buf := make([]byte, len(proto.HandshakeResponse))
if _, err := stream.Read(buf); err != nil {
return fmt.Errorf("reading handshake response failed: %s", err)
}
if string(buf) != proto.HandshakeResponse {
return fmt.Errorf("invalid handshake response, received: %s", string(buf))
}
ct := newControl(stream)
c.log.Debug("client has started successfully")
c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff
c.startNotifyIfNeeded()
return c.listenControl(ct)
}
func (c *Client) dial(serverAddr string) (net.Conn, error) {
if c.config.Dial != nil {
return c.config.Dial("tcp", serverAddr)
}
return net.Dial("tcp", serverAddr)
}
func (c *Client) listenControl(ct *control) error {
c.ctrlWg.Add(1)
defer c.ctrlWg.Done()
c.changeState(ClientConnected, nil)
for {
var msg proto.ControlMessage
if err := ct.dec.Decode(&msg); err != nil {
c.reqWg.Wait() // wait until all requests are finished
c.session.GoAway()
c.session.Close()
c.changeState(ClientDisconnected, err)
return fmt.Errorf("failure decoding control message: %s", err)
}
c.log.Debug("Received control msg %+v", msg)
c.log.Debug("Opening a new stream from server session")
remote, err := c.session.Open()
if err != nil {
return err
}
isHTTP := msg.Protocol == proto.HTTP
if isHTTP {
c.reqWg.Add(1)
}
go func() {
c.proxy(remote, &msg)
if isHTTP {
c.reqWg.Done()
}
remote.Close()
}()
}
}

View File

@ -0,0 +1,16 @@
// Code generated by "stringer -type ClientState"; DO NOT EDIT
package tunnel
import "fmt"
const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed"
var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87}
func (i ClientState) String() string {
if i >= ClientState(len(_ClientState_index)-1) {
return fmt.Sprintf("ClientState(%d)", i)
}
return _ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]]
}

110
lib/tunnel/control.go Normal file
View File

@ -0,0 +1,110 @@
package tunnel
import (
"encoding/json"
"errors"
"net"
"sync"
)
var errControlClosed = errors.New("control connection is closed")
type control struct {
// enc and dec are responsible for encoding and decoding json values forth
// and back
enc *json.Encoder
dec *json.Decoder
// underlying connection responsible for encoder and decoder
nc net.Conn
// identifier associated with this control
identifier string
mu sync.Mutex // guards the following
closed bool // if Close() and quits
}
func newControl(nc net.Conn) *control {
c := &control{
enc: json.NewEncoder(nc),
dec: json.NewDecoder(nc),
nc: nc,
}
return c
}
func (c *control) send(v interface{}) error {
if c.enc == nil {
return errors.New("encoder is not initialized")
}
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return errControlClosed
}
c.mu.Unlock()
return c.enc.Encode(v)
}
func (c *control) recv(v interface{}) error {
if c.dec == nil {
return errors.New("decoder is not initialized")
}
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return errControlClosed
}
c.mu.Unlock()
return c.dec.Decode(v)
}
func (c *control) Close() error {
if c.nc == nil {
return nil
}
c.mu.Lock()
c.closed = true
c.mu.Unlock()
return c.nc.Close()
}
type controls struct {
sync.Mutex
controls map[string]*control
}
func newControls() *controls {
return &controls{
controls: make(map[string]*control),
}
}
func (c *controls) getControl(identifier string) (*control, bool) {
c.Lock()
control, ok := c.controls[identifier]
c.Unlock()
return control, ok
}
func (c *controls) addControl(identifier string, control *control) {
control.identifier = identifier
c.Lock()
c.controls[identifier] = control
c.Unlock()
}
func (c *controls) deleteControl(identifier string) {
c.Lock()
delete(c.controls, identifier)
c.Unlock()
}

263
lib/tunnel/helper_test.go Normal file
View File

@ -0,0 +1,263 @@
package tunnel_test
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"time"
"git.xeserv.us/xena/route/lib/tunnel"
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
"github.com/gorilla/websocket"
)
func init() {
rand.Seed(time.Now().UnixNano() + int64(os.Getpid()))
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type EchoMessage struct {
Value string `json:"value,omitempty"`
Close bool `json:"close,omitempty"`
}
var timeout = 10 * time.Second
var dialer = &websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: timeout,
NetDial: func(_, addr string) (net.Conn, error) {
return net.DialTimeout("tcp4", addr, timeout)
},
}
func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) {
req := tt.Request("http", url.Values{"echo": []string{echo}})
if req == nil {
return "", fmt.Errorf(`tunnel "http" does not exist`)
}
req.Close = rand.Int()%2 == 0
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
p, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(bytes.TrimSpace(p)), nil
}
func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) {
return echoTCPIdent(tt, echo, "tcp")
}
func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) {
addr := tt.Addr(ident)
if addr == nil {
return "", fmt.Errorf("tunnel %q does not exist", ident)
}
s := addr.String()
ip := tt.Tunnels[ident].IP
if ip != nil {
_, port, err := net.SplitHostPort(s)
if err != nil {
return "", err
}
s = net.JoinHostPort(ip.String(), port)
}
c, err := dialTCP(s)
if err != nil {
return "", err
}
c.out <- echo
select {
case reply := <-c.in:
return reply, nil
case <-time.After(tcpTimeout):
return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout)
}
}
func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) {
req := tt.Request(ident, nil)
if req == nil {
return nil, fmt.Errorf("no client found for ident %q", ident)
}
h := http.Header{"Host": {req.Host}}
wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr())
conn, _, err := dialer.Dial(wsurl, h)
return conn, err
}
func sleep() {
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
}
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) (e error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return err
}
defer func() {
err := conn.Close()
if e == nil {
e = err
}
}()
if sleepFn != nil {
sleepFn()
}
for {
var msg EchoMessage
err := conn.ReadJSON(&msg)
if err != nil {
return fmt.Errorf("ReadJSON error: %s", err)
}
if sleepFn != nil {
sleepFn()
}
err = conn.WriteJSON(&msg)
if err != nil {
return fmt.Errorf("WriteJSON error: %s", err)
}
if msg.Close {
return nil
}
}
}
}
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.URL.Query().Get("echo"))
}
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
sleep()
handlerEchoHTTP(w, r)
}
func handlerEchoTCP(conn net.Conn) {
io.Copy(conn, conn)
}
func handlerLatencyEchoTCP(conn net.Conn) {
sleep()
handlerEchoTCP(conn)
}
var tcpTimeout = 10 * time.Second
type tcpClient struct {
conn net.Conn
scanner *bufio.Scanner
in chan string
out chan string
}
func (c *tcpClient) loop() {
for out := range c.out {
if _, err := fmt.Fprintln(c.conn, out); err != nil {
log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err)
return
}
if !c.scanner.Scan() {
log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err())
return
}
c.in <- c.scanner.Text()
}
}
func (c *tcpClient) Close() error {
close(c.out)
return c.conn.Close()
}
func dialTCP(addr string) (*tcpClient, error) {
conn, err := net.DialTimeout("tcp", addr, tcpTimeout)
if err != nil {
return nil, err
}
c := &tcpClient{
conn: conn,
scanner: bufio.NewScanner(conn),
in: make(chan string, 1),
out: make(chan string, 1),
}
go c.loop()
return c, nil
}
func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel {
return singleRecHTTP(handler, nil)
}
func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
return map[string]*tunneltest.Tunnel{
"http": {
Type: tunneltest.TypeHTTP,
LocalAddr: "127.0.0.1:0",
Handler: handler,
StateChanges: stateChanges,
},
}
}
func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel {
return singleRecTCP(handler, nil)
}
func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
return map[string]*tunneltest.Tunnel{
"http": {
Type: tunneltest.TypeHTTP,
LocalAddr: "127.0.0.1:0",
Handler: handlerEchoHTTP,
StateChanges: stateChanges,
},
"tcp": {
Type: tunneltest.TypeTCP,
ClientIdent: "http",
LocalAddr: "127.0.0.1:0",
RemoteAddr: "127.0.0.1:0",
Handler: handler,
},
}
}

115
lib/tunnel/httpproxy.go Normal file
View File

@ -0,0 +1,115 @@
package tunnel
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
)
var (
httpLog = logging.NewLogger("http")
)
// HTTPProxy forwards HTTP traffic.
//
// When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort
// where incomingPort is control message LocalPort.
// Usually this is tunnel server's public exposed Port.
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
// FetchLocalAddr takes precedence over LocalAddr.
//
// When connection to local server cannot be established proxy responds with http error message.
type HTTPProxy struct {
// LocalAddr defines the TCP address of the local server.
// This is optional if you want to specify a single TCP address.
LocalAddr string
// FetchLocalAddr is used for looking up TCP address of the server.
// This is optional if you want to specify a dynamic TCP address based on incommig port.
FetchLocalAddr func(port int) (string, error)
// ErrorResp is custom response send to tunnel server when client cannot
// establish connection to local server. If not set a default "no local server"
// response is sent.
ErrorResp *http.Response
// Log is a custom logger that can be used for the proxy.
// If not set a "http" logger is used.
Log logging.Logger
}
// Proxy is a ProxyFunc.
func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS {
panic("Proxy mismatch")
}
var log = p.log()
var port = msg.LocalPort
if port == 0 {
port = 80
}
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
if p.LocalAddr != "" {
localAddr = p.LocalAddr
} else if p.FetchLocalAddr != nil {
l, err := p.FetchLocalAddr(msg.LocalPort)
if err != nil {
log.Warning("Failed to get custom local address: %s", err)
p.sendError(remote)
return
}
localAddr = l
}
log.Debug("Dialing local server %q", localAddr)
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
if err != nil {
log.Error("Dialing local server %q failed: %s", localAddr, err)
p.sendError(remote)
return
}
Join(local, remote, log)
}
func (p *HTTPProxy) sendError(remote net.Conn) {
var w = noLocalServer()
if p.ErrorResp != nil {
w = p.ErrorResp
}
buf := new(bytes.Buffer)
w.Write(buf)
if _, err := io.Copy(remote, buf); err != nil {
var log = p.log()
log.Debug("Copy in-mem response error: %s", err)
}
remote.Close()
}
func noLocalServer() *http.Response {
body := bytes.NewBufferString("no local server")
return &http.Response{
Status: http.StatusText(http.StatusServiceUnavailable),
StatusCode: http.StatusServiceUnavailable,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Body: ioutil.NopCloser(body),
ContentLength: int64(body.Len()),
}
}
func (p *HTTPProxy) log() logging.Logger {
if p.Log != nil {
return p.Log
}
return httpLog
}

View File

@ -0,0 +1,26 @@
package proto
// ControlMessage is sent from server to client to establish tunneled connection.
type ControlMessage struct {
Action Action `json:"action"`
Protocol Type `json:"transportProtocol"`
LocalPort int `json:"localPort"`
}
// Action represents type of ControlMsg request.
type Action int
// ControlMessage actions.
const (
RequestClientSession Action = iota + 1
)
// Type represents tunneled connection type.
type Type int
// ControlMessage protocols.
const (
HTTP Type = iota + 1
TCP
WS
)

19
lib/tunnel/proto/proto.go Normal file
View File

@ -0,0 +1,19 @@
// Package proto defines tunnel client server communication protocol.
package proto
const (
// ControlPath is http.Handler url path for control connection.
ControlPath = "/_controlPath/"
// ClientIdentifierHeader is header carrying information about tunnel identifier.
ClientIdentifierHeader = "X-KTunnel-Identifier"
// control messages
// Connected is message sent by server to client when control connection was established.
Connected = "200 Connected to Tunnel"
// HandshakeRequest is hello message sent by client to server.
HandshakeRequest = "controlHandshake"
// HandshakeResponse is response to HandshakeRequest sent by server to client.
HandshakeResponse = "controlOk"
)

101
lib/tunnel/proxy.go Normal file
View File

@ -0,0 +1,101 @@
package tunnel
import (
"io"
"net"
"sync"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
)
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
var (
// DefaultProxyFuncs holds global default proxy functions for all transport protocols.
DefaultProxyFuncs = ProxyFuncs{
HTTP: new(HTTPProxy).Proxy,
TCP: new(TCPProxy).Proxy,
WS: new(HTTPProxy).Proxy,
}
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs.
DefaultProxy = Proxy(ProxyFuncs{})
)
// ProxyFuncs is a collection of ProxyFunc.
type ProxyFuncs struct {
// HTTP is custom implementation of HTTP proxing.
HTTP ProxyFunc
// TCP is custom implementation of TCP proxing.
TCP ProxyFunc
// WS is custom implementation of web socket proxing.
WS ProxyFunc
}
// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs.
func Proxy(p ProxyFuncs) ProxyFunc {
return func(remote net.Conn, msg *proto.ControlMessage) {
var f ProxyFunc
switch msg.Protocol {
case proto.HTTP:
f = DefaultProxyFuncs.HTTP
if p.HTTP != nil {
f = p.HTTP
}
case proto.TCP:
f = DefaultProxyFuncs.TCP
if p.TCP != nil {
f = p.TCP
}
case proto.WS:
f = DefaultProxyFuncs.WS
if p.WS != nil {
f = p.WS
}
}
if f == nil {
logging.Error("Could not determine proxy function for %v", msg)
remote.Close()
}
f(remote, msg)
}
}
// Join copies data between local and remote connections.
// It reads from one connection and writes to the other.
// It's a building block for ProxyFunc implementations.
func Join(local, remote net.Conn, log logging.Logger) {
var wg sync.WaitGroup
wg.Add(2)
transfer := func(side string, dst, src net.Conn) {
log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
n, err := io.Copy(dst, src)
if err != nil {
log.Error("%s: copy error: %s", side, err)
}
if err := src.Close(); err != nil {
log.Debug("%s: close error: %s", side, err)
}
// not for yamux streams, but for client to local server connections
if d, ok := dst.(*net.TCPConn); ok {
if err := d.CloseWrite(); err != nil {
log.Debug("%s: closeWrite error: %s", side, err)
}
}
wg.Done()
log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n)
}
go transfer("remote to local", local, remote)
go transfer("local to remote", remote, local)
wg.Wait()
}

755
lib/tunnel/server.go Normal file
View File

@ -0,0 +1,755 @@
// Package tunnel is a server/client package that enables to proxy public
// connections to your local machine over a tunnel connection from the local
// machine to the public server.
package tunnel
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
"github.com/hashicorp/yamux"
)
var (
errNoClientSession = errors.New("no client session established")
defaultTimeout = 10 * time.Second
)
// Server is responsible for proxying public connections to the client over a
// tunnel connection. It also listens to control messages from the client.
type Server struct {
// pending contains the channel that is associated with each new tunnel request.
pending map[string]chan net.Conn
// pendingMu protects the pending map.
pendingMu sync.Mutex
// sessions contains a session per virtual host.
// Sessions provides multiplexing over one connection.
sessions map[string]*yamux.Session
// sessionsMu protects sessions.
sessionsMu sync.Mutex
// controls contains the control connection from the client to the server.
controls *controls
// virtualHosts is used to map public hosts to remote clients.
virtualHosts vhostStorage
// virtualAddrs.
virtualAddrs *vaddrStorage
// connCh is used to publish accepted connections for tcp tunnels.
connCh chan net.Conn
// onConnectCallbacks contains client callbacks called when control
// session is established for a client with given identifier.
onConnectCallbacks *callbacks
// onDisconnectCallbacks contains client callbacks called when control
// session is closed for a client with given identifier.
onDisconnectCallbacks *callbacks
// states represents current clients' connections state.
states map[string]ClientState
// statesMu protects states.
statesMu sync.RWMutex
// stateCh notifies receiver about client state changes.
stateCh chan<- *ClientStateChange
// httpDirector is provided by ServerConfig, if not nil decorates http requests
// before forwarding them to client.
httpDirector func(*http.Request)
// yamuxConfig is passed to new yamux.Session's
yamuxConfig *yamux.Config
log logging.Logger
}
// ServerConfig defines the configuration for the Server
type ServerConfig struct {
// StateChanges receives state transition details each time client
// connection state changes. The channel is expected to be sufficiently
// buffered to keep up with event pace.
//
// If nil, no information about state transitions are dispatched
// by the library.
StateChanges chan<- *ClientStateChange
// Director is a function that modifies HTTP request into a new HTTP request
// before sending to client. If nil no modifications are done.
Director func(*http.Request)
// Debug enables debug mode, enable only if you want to debug the server
Debug bool
// Log defines the logger. If nil a default logging.Logger is used.
Log logging.Logger
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
// yamux.DefaultConfig() is used.
YamuxConfig *yamux.Config
}
// NewServer creates a new Server. The defaults are used if config is nil.
func NewServer(cfg *ServerConfig) (*Server, error) {
yamuxConfig := yamux.DefaultConfig()
if cfg.YamuxConfig != nil {
if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil {
return nil, err
}
yamuxConfig = cfg.YamuxConfig
}
log := newLogger("tunnel-server", cfg.Debug)
if cfg.Log != nil {
log = cfg.Log
}
connCh := make(chan net.Conn)
opts := &vaddrOptions{
connCh: connCh,
log: log,
}
s := &Server{
pending: make(map[string]chan net.Conn),
sessions: make(map[string]*yamux.Session),
onConnectCallbacks: newCallbacks("OnConnect"),
onDisconnectCallbacks: newCallbacks("OnDisconnect"),
virtualHosts: newVirtualHosts(),
virtualAddrs: newVirtualAddrs(opts),
controls: newControls(),
states: make(map[string]ClientState),
stateCh: cfg.StateChanges,
httpDirector: cfg.Director,
yamuxConfig: yamuxConfig,
connCh: connCh,
log: log,
}
go s.serveTCP()
return s, nil
}
// ServeHTTP is a tunnel that creates an http/websocket tunnel between a
// public connection and the client connection.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// if the user didn't add the control and tunnel handler manually, we'll
// going to infer and call the respective path handlers.
switch path.Clean(r.URL.Path) + "/" {
case proto.ControlPath:
s.checkConnect(s.controlHandler).ServeHTTP(w, r)
return
}
if err := s.handleHTTP(w, r); err != nil {
if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily
s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err)
}
http.Error(w, err.Error(), http.StatusBadGateway)
}
}
// handleHTTP handles a single HTTP request
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error {
s.log.Debug("HandleHTTP request:")
s.log.Debug("%v", r)
if s.httpDirector != nil {
s.httpDirector(r)
}
hostPort := strings.ToLower(r.Host)
if hostPort == "" {
return errors.New("request host is empty")
}
// if someone hits foo.example.com:8080, this should be proxied to
// localhost:8080, so send the port to the client so it knows how to proxy
// correctly. If no port is available, it's up to client how to interpret it
host, port, err := parseHostPort(hostPort)
if err != nil {
// no need to return, just continue lazily, port will be 0, which in
// our case will be proxied to client's local servers port 80
s.log.Debug("No port available for %q, sending port 80 to client", hostPort)
}
// get the identifier associated with this host
identifier, ok := s.getIdentifier(hostPort)
if !ok {
// fallback to host
identifier, ok = s.getIdentifier(host)
if !ok {
return fmt.Errorf("no virtual host available for %q", hostPort)
}
}
if isWebsocketConn(r) {
s.log.Debug("handling websocket connection")
return s.handleWSConn(w, r, identifier, port)
}
stream, err := s.dial(identifier, proto.HTTP, port)
if err != nil {
return err
}
defer func() {
s.log.Debug("Closing stream")
stream.Close()
}()
if err := r.Write(stream); err != nil {
return err
}
s.log.Debug("Session opened to client, writing request to client")
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
if err != nil {
return fmt.Errorf("read from tunnel: %s", err.Error())
}
defer func() {
if resp.Body != nil {
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF {
s.log.Error("resp.Body Close error: %s", err.Error())
}
}
}()
s.log.Debug("Response received, writing back to public connection: %+v", resp)
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
if err == io.ErrUnexpectedEOF {
s.log.Debug("Client closed the connection, couldn't copy response")
} else {
s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers
}
}
return nil
}
func (s *Server) serveTCP() {
for conn := range s.connCh {
go s.serveTCPConn(conn)
}
}
func (s *Server) serveTCPConn(conn net.Conn) {
err := s.handleTCPConn(conn)
if err != nil {
s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err)
conn.Close()
}
}
func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error {
hj, ok := w.(http.Hijacker)
if !ok {
return errors.New("webserver doesn't support hijacking")
}
conn, _, err := hj.Hijack()
if err != nil {
return fmt.Errorf("hijack not possible: %s", err)
}
stream, err := s.dial(ident, proto.WS, port)
if err != nil {
return err
}
if err := r.Write(stream); err != nil {
err = errors.New("unable to write upgrade request: " + err.Error())
return nonil(err, stream.Close())
}
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
if err != nil {
err = errors.New("unable to read upgrade response: " + err.Error())
return nonil(err, stream.Close())
}
if err := resp.Write(conn); err != nil {
err = errors.New("unable to write upgrade response: " + err.Error())
return nonil(err, stream.Close())
}
var wg sync.WaitGroup
wg.Add(2)
go s.proxy(&wg, conn, stream)
go s.proxy(&wg, stream, conn)
wg.Wait()
return nonil(stream.Close(), conn.Close())
}
func (s *Server) handleTCPConn(conn net.Conn) error {
ident, ok := s.virtualAddrs.getIdent(conn)
if !ok {
return fmt.Errorf("no virtual address available for %s", conn.LocalAddr())
}
_, port, err := parseHostPort(conn.LocalAddr().String())
if err != nil {
return err
}
stream, err := s.dial(ident, proto.TCP, port)
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(2)
go s.proxy(&wg, conn, stream)
go s.proxy(&wg, stream, conn)
wg.Wait()
return nonil(stream.Close(), conn.Close())
}
func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) {
defer wg.Done()
s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
n, err := io.Copy(dst, src)
s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err)
}
func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) {
control, ok := s.getControl(identifier)
if !ok {
return nil, errNoClientSession
}
session, err := s.getSession(identifier)
if err != nil {
return nil, err
}
msg := proto.ControlMessage{
Action: proto.RequestClientSession,
Protocol: p,
LocalPort: port,
}
s.log.Debug("Sending control msg %+v", msg)
// ask client to open a session to us, so we can accept it
if err := control.send(msg); err != nil {
// we might have several issues here, either the stream is closed, or
// the session is going be shut down, the underlying connection might
// be broken. In all cases, it's not reliable anymore having a client
// session.
control.Close()
s.deleteControl(identifier)
return nil, errNoClientSession
}
var stream net.Conn
acceptStream := func() error {
stream, err = session.Accept()
return err
}
// if we don't receive anything from the client, we'll timeout
s.log.Debug("Waiting for session accept")
select {
case err := <-async(acceptStream):
return stream, err
case <-time.After(defaultTimeout):
return nil, errors.New("timeout getting session")
}
}
// controlHandler is used to capture incoming tunnel connect requests into raw
// tunnel TCP connections.
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
identifier := r.Header.Get(proto.ClientIdentifierHeader)
_, ok := s.getHost(identifier)
if !ok {
return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier)
}
ct, ok := s.getControl(identifier)
if ok {
ct.Close()
s.deleteControl(identifier)
s.deleteSession(identifier)
s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier)
return fmt.Errorf("control conn for %s already exist. \n", identifier)
}
s.log.Debug("Tunnel with identifier %s", identifier)
hj, ok := w.(http.Hijacker)
if !ok {
return errors.New("webserver doesn't support hijacking")
}
conn, _, err := hj.Hijack()
if err != nil {
return fmt.Errorf("hijack not possible: %s", err)
}
if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil {
return fmt.Errorf("error writing response: %s", err)
}
if err := conn.SetDeadline(time.Time{}); err != nil {
return fmt.Errorf("error setting connection deadline: %s", err)
}
s.log.Debug("Creating control session")
session, err := yamux.Server(conn, s.yamuxConfig)
if err != nil {
return err
}
s.addSession(identifier, session)
var stream net.Conn
// close and delete the session/stream if something goes wrong
defer func() {
if ctErr != nil {
if stream != nil {
stream.Close()
}
s.deleteSession(identifier)
}
}()
acceptStream := func() error {
stream, err = session.Accept()
return err
}
// if we don't receive anything from the client, we'll timeout
select {
case err := <-async(acceptStream):
if err != nil {
return err
}
case <-time.After(time.Second * 10):
return errors.New("timeout getting session")
}
s.log.Debug("Initiating handshake protocol")
buf := make([]byte, len(proto.HandshakeRequest))
if _, err := stream.Read(buf); err != nil {
return err
}
if string(buf) != proto.HandshakeRequest {
return fmt.Errorf("handshake aborted. got: %s", string(buf))
}
if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil {
return err
}
// setup control stream and start to listen to messages
ct = newControl(stream)
s.addControl(identifier, ct)
go s.listenControl(ct)
s.log.Debug("Control connection is setup")
return nil
}
// listenControl listens to messages coming from the client.
func (s *Server) listenControl(ct *control) {
s.onConnect(ct.identifier)
for {
var msg map[string]interface{}
err := ct.dec.Decode(&msg)
if err != nil {
host, _ := s.getHost(ct.identifier)
s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier)
// close client connection so it reconnects again
ct.Close()
// don't forget to cleanup anything
s.deleteControl(ct.identifier)
s.deleteSession(ct.identifier)
s.onDisconnect(ct.identifier, err)
if err != io.EOF {
s.log.Error("decode err: %s", err)
}
return
}
// right now we don't do anything with the messages, but because the
// underlying connection needs to establihsed, we know when we have
// disconnection(above), so we can cleanup the connection.
s.log.Debug("msg: %s", msg)
}
}
// OnConnect invokes a callback for client with given identifier,
// when it establishes a control session.
// After a client is connected, the associated function
// is also removed and needs to be added again.
func (s *Server) OnConnect(identifier string, fn func() error) {
s.onConnectCallbacks.add(identifier, fn)
}
// onConnect sends notifications to listeners (registered in onConnectCallbacks
// or stateChanges chanel readers) when client connects.
func (s *Server) onConnect(identifier string) {
if err := s.onConnectCallbacks.call(identifier); err != nil {
s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err)
}
s.changeState(identifier, ClientConnected, nil)
}
// OnDisconnect calls the function when the client connected with the
// associated identifier disconnects from the server.
// After a client is disconnected, the associated function
// is also removed and needs to be added again.
func (s *Server) OnDisconnect(identifier string, fn func() error) {
s.onDisconnectCallbacks.add(identifier, fn)
}
// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks
// or stateChanges chanel readers) when client disconnects.
func (s *Server) onDisconnect(identifier string, err error) {
if err := s.onDisconnectCallbacks.call(identifier); err != nil {
s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err)
}
s.changeState(identifier, ClientClosed, err)
}
func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) {
s.statesMu.Lock()
defer s.statesMu.Unlock()
prev = s.states[identifier]
s.states[identifier] = state
if s.stateCh != nil {
change := &ClientStateChange{
Identifier: identifier,
Previous: prev,
Current: state,
Error: err,
}
select {
case s.stateCh <- change:
default:
s.log.Warning("Dropping state change due to slow reader: %s", change)
}
}
return prev
}
// AddHost adds the given virtual host and maps it to the identifier.
func (s *Server) AddHost(host, identifier string) {
s.virtualHosts.AddHost(host, identifier)
}
// DeleteHost deletes the given virtual host. Once removed any request to this
// host is denied.
func (s *Server) DeleteHost(host string) {
s.virtualHosts.DeleteHost(host)
}
// AddAddr starts accepting connections on listener l, routing every connection
// to a tunnel client given by the identifier.
//
// When ip parameter is nil, all connections accepted from the listener are
// routed to the tunnel client specified by the identifier (port-based routing).
//
// When ip parameter is non-nil, only those connections are routed whose local
// address matches the specified ip (ip-based routing).
//
// If l listens on multiple interfaces it's desirable to call AddAddr multiple
// times with the same l value but different ip one.
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
s.virtualAddrs.Add(l, ip, identifier)
}
// DeleteAddr stops listening for connections on the given listener.
//
// Upon return no more connections will be tunneled, but as the method does not
// close the listener, so any ongoing connection won't get interrupted.
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
s.virtualAddrs.Delete(l, ip)
}
func (s *Server) getIdentifier(host string) (string, bool) {
identifier, ok := s.virtualHosts.GetIdentifier(host)
return identifier, ok
}
func (s *Server) getHost(identifier string) (string, bool) {
host, ok := s.virtualHosts.GetHost(identifier)
return host, ok
}
func (s *Server) addControl(identifier string, conn *control) {
s.controls.addControl(identifier, conn)
}
func (s *Server) getControl(identifier string) (*control, bool) {
return s.controls.getControl(identifier)
}
func (s *Server) deleteControl(identifier string) {
s.controls.deleteControl(identifier)
}
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
s.sessionsMu.Lock()
session, ok := s.sessions[identifier]
s.sessionsMu.Unlock()
if !ok {
return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
}
return session, nil
}
func (s *Server) addSession(identifier string, session *yamux.Session) {
s.sessionsMu.Lock()
s.sessions[identifier] = session
s.sessionsMu.Unlock()
}
func (s *Server) deleteSession(identifier string) {
s.sessionsMu.Lock()
defer s.sessionsMu.Unlock()
session, ok := s.sessions[identifier]
if !ok {
return // nothing to delete
}
if session != nil {
session.GoAway() // don't accept any new connection
session.Close()
}
delete(s.sessions, identifier)
}
func copyHeader(dst, src http.Header) {
for k, v := range src {
vv := make([]string, len(v))
copy(vv, v)
dst[k] = vv
}
}
// checkConnect checks whether the incoming request is HTTP CONNECT method.
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "CONNECT" {
http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed)
return
}
if err := fn(w, r); err != nil {
s.log.Error("Handler err: %v", err.Error())
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
s.onDisconnect(identifier, err)
}
http.Error(w, err.Error(), 502)
}
})
}
func parseHostPort(addr string) (string, int, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return "", 0, err
}
n, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return "", 0, err
}
return host, int(n), nil
}
func isWebsocketConn(r *http.Request) bool {
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
headerContains(r.Header["Upgrade"], "websocket")
}
// headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go
func headerContains(header []string, value string) bool {
for _, h := range header {
for _, v := range strings.Split(h, ",") {
if strings.EqualFold(strings.TrimSpace(v), value) {
return true
}
}
}
return false
}
func nonil(err ...error) error {
for _, e := range err {
if e != nil {
return e
}
}
return nil
}
func newLogger(name string, debug bool) logging.Logger {
log := logging.NewLogger(name)
logHandler := logging.NewWriterHandler(os.Stderr)
logHandler.Colorize = true
log.SetHandler(logHandler)
if debug {
log.SetLevel(logging.DEBUG)
logHandler.SetLevel(logging.DEBUG)
}
return log
}

100
lib/tunnel/spec.md Normal file
View File

@ -0,0 +1,100 @@
# Specification
# Naming conventions
* `server` is listening to public connection and is responsible of routing
public HTTP requests to clients.
* `client` is a long running process, connected to a server and running on a local machine.
* `virtualHost` is a virtual domain that maps a domain to a single client. i.e:
`arslan.koding.io` is a virtualhost which is mapped to my `client` running on
my local machine.
* `identifier` is a secret token, which is not meant to be shared with others.
An identifier is responsible of mapping a virtualhost to a client.
* `session` is a single TCP connection which uses the library `yamux`. A
session can be created either via `yamux.Server()` or `yamux.Client`
* `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed
over the `session`. A session can have hundreds of thousands streams
* `control connection` is a single `stream` which is used to communicate and
handle messaging between server and client. It uses a custom protocol which
is JSON encoded.
* `tunnel connection` is a single `stream` which is used to proxy public HTTP
requests from the `server` to the `client` and vice versa. A single `tunnel`
connection is created for every single HTTP requests.
* `public connection` is a connection from a remote machine to the `server`
* `ControlHandler` is a http.Handler which listens to requests coming to
`/_controlPath_/`. It's used to setup the initial `session` connection from
`client` to `server`. And creates the `control connection` from this session.
server and client, and also for all additional new tunnel. It literally
captures the incoming HTTP request and hijacks it and converts it into RAW TCP,
which then is used as the foundation for all yamux `sessions.`
# Server
1. Server is created with `NewServer()` which returns `*Server`, a `http.Handler`
compatible type. Plug into any HTTP server you want. The root path `"/"` is
recommended to listen and proxy any tunnels. It also listens to any request
coming to `ControlHandler`
2. Tunneling is based on virtual hosts. A virtual hosts is identified with an
unique identifier. This identifier is the only piece that both client and
server needs to known ahead. Think of it as a secret token.
3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This
step needs to be done from the server itself. This can be could manually or
via custom auth based HTTP handlers, such as "/addhost", which adds
virtualhosts and returns the `identifier` to the requester (in our case `client`)
4. A DNS record and it's subdomains needs to point to a `server`, so it can
handle virtual hosts, i.e: `*.example.com` is routed to a server, which can
handle `foo.example.com`, `bar.example.com`, etc..
# Client
1. Client is created with `NewClient(serverAddr, localAddr)` which returns a
`*Client`. Here `serverAddr` is the TCP address to the server. `localAddr`
is the server in which all public requests are forwarded to. It's optional if
you want it to be done dynamically
2. Once a client is created, it starts with `client.Start(identifier)`. Here
`identifier` is needed upfront. This method creates the initial TCP
connection to the server. It sends the identifier back to the server. This
TCP connection is used as the foundation for `yamux.Client()`. Once a yamux
session is established, we are able to use this single connection to have
multiple streams, which are multiplexed over this one connection. A `control
connection` is created and client starts to listen it. `client.Start` is
blocking.
# Control Handshake
1. Client sends a `handshakeRequest` over the `control connection` stream
2. The server sends back a `handshakeResponse` to the client over the `control connection` stream
3. Once the client receives the `handshakeResponse` from the server, it starts
to listen from the `control connection` stream.
4. A `control connection` is json.Encoder/Decoder both for server and client
# Tunnel creation
1. When the server receives a public connection, it checks the HTTP host
headers and retrieves the corresponding identifier from the given host.
2. The server retrieves the `control connection` which was associated with this
`identifier` and sends a `ControlMsg` message with the action
`RequestClientSession`. This message is in the form of:
type ControlMsg struct {
Action Action `json:"action"`
Protocol TransportProtocol `json:"transportProtocol"`
LocalPort string `json:"localPort"`
}
Here the `LocalPort` is read from the HTTP Host header. If absent a zero
port is sent and client maps it to the local server running at port 80, unless
the `localAddr` is specified in `client.Start()` method. `Protocol` is
reserved for future features.
3. The server immediately starts to listen(accept) to a new `stream`. This is
blocking and it waits there.
4. When the client receives the `RequestClientSession` message, it opens a new
`virtual` TCP connection, a `stream` to the server.
5. The server which was waiting for a new stream in step 3, establish the stream.
6. The server copies the request over the stream to the client.
7. The client copies the request coming from the server to the local server and
copies back the result to the server
8. The server reads the response coming from the client and returns back it to
the public connection requester

78
lib/tunnel/tcpproxy.go Normal file
View File

@ -0,0 +1,78 @@
package tunnel
import (
"fmt"
"net"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
)
var (
tpcLog = logging.NewLogger("tcp")
)
// TCPProxy forwards TCP streams.
//
// If port-based routing is used, LocalAddr or FetchLocalAddr field is required
// for tunneling to function properly.
// Otherwise you'll be forwarding traffic to random ports and this is usually not desired.
//
// If IP-based routing is used then tunnel server connection request is
// proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort.
// Usually this is tunnel server's public exposed Port.
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
// FetchLocalAddr takes precedence over LocalAddr.
type TCPProxy struct {
// LocalAddr defines the TCP address of the local server.
// This is optional if you want to specify a single TCP address.
LocalAddr string
// FetchLocalAddr is used for looking up TCP address of the server.
// This is optional if you want to specify a dynamic TCP address based on incommig port.
FetchLocalAddr func(port int) (string, error)
// Log is a custom logger that can be used for the proxy.
// If not set a "tcp" logger is used.
Log logging.Logger
}
// Proxy is a ProxyFunc.
func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
if msg.Protocol != proto.TCP {
panic("Proxy mismatch")
}
var log = p.log()
var port = msg.LocalPort
if port == 0 {
log.Warning("TCP proxy to port 0")
}
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
if p.LocalAddr != "" {
localAddr = p.LocalAddr
} else if p.FetchLocalAddr != nil {
l, err := p.FetchLocalAddr(msg.LocalPort)
if err != nil {
log.Warning("Failed to get custom local address: %s", err)
return
}
localAddr = l
}
log.Debug("Dialing local server: %q", localAddr)
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
if err != nil {
log.Error("Dialing local server %q failed: %s", localAddr, err)
return
}
Join(local, remote, log)
}
func (p *TCPProxy) log() logging.Logger {
if p.Log != nil {
return p.Log
}
return tpcLog
}

412
lib/tunnel/tunnel_test.go Normal file
View File

@ -0,0 +1,412 @@
package tunnel_test
import (
"fmt"
"strconv"
"sync"
"testing"
"time"
"git.xeserv.us/xena/route/lib/tunnel"
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
"github.com/cenkalti/backoff"
)
func TestMultipleRequest(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
// make a request to tunnelserver, this should be tunneled to local server
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
msg := "hello" + strconv.Itoa(i)
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}(i)
}
wg.Wait()
}
func TestMultipleLatencyRequest(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
// make a request to tunnelserver, this should be tunneled to local server
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
msg := "hello" + strconv.Itoa(i)
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}(i)
}
wg.Wait()
}
func TestReconnectClient(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
msg := "hello"
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
client := tt.Clients["http"]
// close client, and start it again
client.Close()
go client.Start()
<-client.StartNotify()
msg = "helloagain"
res, err = echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestNoClient(t *testing.T) {
const expectedErr = "no client session established"
rec := tunneltest.NewStateRecorder()
tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C()))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
if err := rec.WaitTransitions(
tunnel.ClientStarted,
tunnel.ClientConnecting,
tunnel.ClientConnected,
); err != nil {
t.Fatal(err)
}
if err := tt.ServerStateRecorder.WaitTransition(
tunnel.ClientUnknown,
tunnel.ClientConnected,
); err != nil {
t.Fatal(err)
}
// close client, this is the main point of the test
if err := tt.Clients["http"].Close(); err != nil {
t.Fatal(err)
}
if err := rec.WaitTransitions(
tunnel.ClientConnected,
tunnel.ClientDisconnected,
tunnel.ClientClosed,
); err != nil {
t.Fatal(err)
}
if err := tt.ServerStateRecorder.WaitTransition(
tunnel.ClientConnected,
tunnel.ClientClosed,
); err != nil {
t.Fatal(err)
}
msg := "hello"
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != expectedErr {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestNoHost(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
noBackoff := backoff.NewConstantBackOff(time.Duration(-1))
unknown, err := tunnel.NewClient(&tunnel.ClientConfig{
Identifier: "unknown",
ServerAddr: tt.ServerAddr().String(),
Backoff: noBackoff,
Debug: testing.Verbose(),
})
if err != nil {
t.Fatalf("client error: %s", err)
}
unknown.Start()
defer unknown.Close()
if err := tt.ServerStateRecorder.WaitTransition(
tunnel.ClientUnknown,
tunnel.ClientClosed,
); err != nil {
t.Fatal(err)
}
unknown.Start()
if err := tt.ServerStateRecorder.WaitTransition(
tunnel.ClientClosed,
tunnel.ClientClosed,
); err != nil {
t.Fatal(err)
}
}
func TestNoLocalServer(t *testing.T) {
const expectedErr = "no local server"
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
// close local listener, this is the main point of the test
tt.Listeners["http"][0].Close()
msg := "hello"
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != expectedErr {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestSingleRequest(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
msg := "hello"
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestSingleLatencyRequest(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
msg := "hello"
res, err := echoHTTP(tt, msg)
if err != nil {
t.Fatalf("echoHTTP error: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestSingleTCP(t *testing.T) {
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
msg := "hello"
res, err := echoTCP(tt, msg)
if err != nil {
t.Fatalf("echoTCP error: %s", err)
}
if msg != res {
t.Errorf("got %q, want %q", res, msg)
}
}
func TestMultipleTCP(t *testing.T) {
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
msg := "hello" + strconv.Itoa(i)
res, err := echoTCP(tt, msg)
if err != nil {
t.Errorf("echoTCP: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}(i)
}
wg.Wait()
}
func TestMultipleLatencyTCP(t *testing.T) {
tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP))
if err != nil {
t.Fatal(err)
}
defer tt.Close()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
msg := "hello" + strconv.Itoa(i)
res, err := echoTCP(tt, msg)
if err != nil {
t.Errorf("echoTCP: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}(i)
}
wg.Wait()
}
func TestMultipleStreamTCP(t *testing.T) {
tunnels := map[string]*tunneltest.Tunnel{
"http": {
Type: tunneltest.TypeHTTP,
LocalAddr: "127.0.0.1:0",
Handler: handlerEchoHTTP,
},
"tcp": {
Type: tunneltest.TypeTCP,
ClientIdent: "http",
LocalAddr: "127.0.0.1:0",
RemoteAddr: "127.0.0.1:0",
Handler: handlerEchoTCP,
},
"tcp_all": {
Type: tunneltest.TypeTCP,
ClientIdent: "http",
LocalAddr: "127.0.0.1:0",
RemoteAddr: "0.0.0.0:0",
Handler: handlerEchoTCP,
},
}
addrs, err := tunneltest.UsableAddrs()
if err != nil {
t.Fatal(err)
}
clients := []string{"tcp"}
for i, addr := range addrs {
if addr.IP.IsLoopback() {
continue
}
client := fmt.Sprintf("tcp_%d", i)
tunnels[client] = &tunneltest.Tunnel{
Type: tunneltest.TypeTCP,
ClientIdent: "http",
LocalAddr: "127.0.0.1:0",
RemoteAddrIdent: "tcp_all",
IP: addr.IP,
Handler: handlerEchoTCP,
}
clients = append(clients, client)
}
tt, err := tunneltest.Serve(tunnels)
if err != nil {
t.Fatal(err)
}
defer tt.Close()
var wg sync.WaitGroup
for i := 0; i < 100/len(clients); i++ {
wg.Add(len(clients))
for j, ident := range clients {
go func(ident string, i, j int) {
defer wg.Done()
msg := fmt.Sprintf("hello_%d_client_%d", j, i)
res, err := echoTCPIdent(tt, msg, ident)
if err != nil {
t.Errorf("echoTCP: %s", err)
}
if res != msg {
t.Errorf("got %q, want %q", res, msg)
}
}(ident, i, j)
}
}
wg.Wait()
}

View File

@ -0,0 +1,118 @@
package tunneltest
import (
"bytes"
"fmt"
"sync"
"time"
"git.xeserv.us/xena/route/lib/tunnel"
)
var (
recWaitTimeout = 5 * time.Second
recBuffer = 32
)
// States is a sequence of client state changes.
type States []*tunnel.ClientStateChange
func (s States) String() string {
if len(s) == 0 {
return ""
}
var buf bytes.Buffer
fmt.Fprintf(&buf, "[%s", s[0].String())
for _, s := range s[1:] {
fmt.Fprintf(&buf, ",%s", s.String())
}
buf.WriteRune(']')
return buf.String()
}
// StateRecorder saves state changes pushed to StateRecorder.C().
type StateRecorder struct {
mu sync.Mutex
ch chan *tunnel.ClientStateChange
recorded []*tunnel.ClientStateChange
offset int
}
func NewStateRecorder() *StateRecorder {
rec := &StateRecorder{
ch: make(chan *tunnel.ClientStateChange, recBuffer),
}
go rec.record()
return rec
}
func (rec *StateRecorder) record() {
for state := range rec.ch {
rec.mu.Lock()
rec.recorded = append(rec.recorded, state)
rec.mu.Unlock()
}
}
func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange {
return rec.ch
}
func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error {
from := states[0]
for _, to := range states[1:] {
if err := rec.WaitTransition(from, to); err != nil {
return err
}
from = to
}
return nil
}
func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error {
timeout := time.After(recWaitTimeout)
var lastStates []*tunnel.ClientStateChange
for {
select {
case <-timeout:
return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates))
default:
time.Sleep(50 * time.Millisecond)
lastStates = rec.States()[rec.offset:]
for i, state := range lastStates {
if from != 0 && state.Previous != from {
continue
}
if to != 0 && state.Current != to {
continue
}
rec.offset += i
return nil
}
}
}
}
func (rec *StateRecorder) States() []*tunnel.ClientStateChange {
rec.mu.Lock()
defer rec.mu.Unlock()
states := make([]*tunnel.ClientStateChange, len(rec.recorded))
copy(states, rec.recorded)
return states
}

View File

@ -0,0 +1,561 @@
package tunneltest
import (
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"sort"
"strconv"
"sync"
"testing"
"time"
"git.xeserv.us/xena/route/lib/tunnel"
)
var debugNet = os.Getenv("DEBUGNET") == "1"
type dbgListener struct {
net.Listener
}
func (l dbgListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return dbgConn{conn}, nil
}
type dbgConn struct {
net.Conn
}
func (c dbgConn) Read(p []byte) (int, error) {
n, err := c.Conn.Read(p)
os.Stderr.Write(p)
return n, err
}
func (c dbgConn) Write(p []byte) (int, error) {
n, err := c.Conn.Write(p)
os.Stderr.Write(p)
return n, err
}
func logf(format string, args ...interface{}) {
if testing.Verbose() {
log.Printf("[tunneltest] "+format, args...)
}
}
func nonil(err ...error) error {
for _, e := range err {
if e != nil {
return e
}
}
return nil
}
func parseHostPort(addr string) (string, int, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return "", 0, err
}
n, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return "", 0, err
}
return host, int(n), nil
}
// UsableAddrs returns all tcp addresses that we can bind a listener to.
func UsableAddrs() ([]*net.TCPAddr, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
var usable []*net.TCPAddr
for _, addr := range addrs {
if ipNet, ok := addr.(*net.IPNet); ok {
if !ipNet.IP.IsLinkLocalUnicast() {
usable = append(usable, &net.TCPAddr{IP: ipNet.IP})
}
}
}
if len(usable) == 0 {
return nil, errors.New("no usable addresses found")
}
return usable, nil
}
const (
TypeHTTP = iota
TypeTCP
)
// Tunnel represents a single HTTP or TCP tunnel that can be served
// by TunnelTest.
type Tunnel struct {
// Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP.
Type int
// Handler is a handler to use for serving tunneled connections on
// local server. The value of this field is required to be of type:
//
// - http.Handler or http.HandlerFunc for HTTP tunnels
// - func(net.Conn) for TCP tunnels
//
// Required field.
Handler interface{}
// LocalAddr is a network address of local server that handles
// connections/requests with Handler.
//
// Optional field, takes value of "127.0.0.1:0" when empty.
LocalAddr string
// ClientIdent is an identifier of a client that have already
// registered a HTTP tunnel and have established control connection.
//
// If the Type is TypeTCP, instead of creating new client
// for this TCP tunnel, we add it to an existing client
// specified by the field.
//
// Optional field for TCP tunnels.
// Ignored field for HTTP tunnels.
ClientIdent string
// RemoteAddr is a network address of remote server, which accepts
// connections on a tunnel server side.
//
// Required field for TCP tunnels.
// Ignored field for HTTP tunnels.
RemoteAddr string
// RemoteAddrIdent an identifier of an already existing listener,
// that listens on multiple interfaces; if the RemoteAddrIdent is valid
// identifier the IP field is required to be non-nil and RemoteAddr
// is ignored.
//
// Optional field for TCP tunnels.
// Ignored field for HTTP tunnels.
RemoteAddrIdent string
// IP specifies an IP address value for IP-based routing for TCP tunnels.
// For more details see inline documentation for (*tunnel.Server).AddAddr.
//
// Optional field for TCP tunnels.
// Ignored field for HTTP tunnels.
IP net.IP
// StateChanges listens on state transitions.
//
// If ClientIdent field is empty, the StateChanges will receive
// state transition events for the newly created client.
// Otherwise setting this field is a nop.
StateChanges chan<- *tunnel.ClientStateChange
}
type TunnelTest struct {
Server *tunnel.Server
ServerStateRecorder *StateRecorder
Clients map[string]*tunnel.Client
Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels)
Addrs []*net.TCPAddr
Tunnels map[string]*Tunnel
DebugNet bool // for debugging network communication
mu sync.Mutex // protects Listeners
}
func NewTunnelTest() (*TunnelTest, error) {
rec := NewStateRecorder()
cfg := &tunnel.ServerConfig{
StateChanges: rec.C(),
Debug: testing.Verbose(),
}
s, err := tunnel.NewServer(cfg)
if err != nil {
return nil, err
}
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
}
if debugNet {
l = dbgListener{l}
}
addrs, err := UsableAddrs()
if err != nil {
return nil, err
}
go (&http.Server{Handler: s}).Serve(l)
return &TunnelTest{
Server: s,
ServerStateRecorder: rec,
Clients: make(map[string]*tunnel.Client),
Listeners: map[string][2]net.Listener{"": {l, nil}},
Addrs: addrs,
Tunnels: make(map[string]*Tunnel),
DebugNet: debugNet,
}, nil
}
// Serve creates new TunnelTest that serves the given tunnels.
//
// If tunnels is nil, DefaultTunnels() are used instead.
func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) {
tt, err := NewTunnelTest()
if err != nil {
return nil, err
}
if err = tt.Serve(tunnels); err != nil {
return nil, err
}
return tt, nil
}
func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) {
// Verify tunnel dependencies for TCP tunnels.
if t.Type == TypeTCP {
// If tunnel specified by t.Client was not already started,
// skip and move on.
if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok {
return false, nil
}
// Verify the TCP tunnel whose remote endpoint listens on multiple
// interfaces is already served.
if t.RemoteAddrIdent != "" {
if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok {
return false, nil
}
if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP {
return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent)
}
}
}
l, err := net.Listen("tcp", t.LocalAddr)
if err != nil {
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err)
}
if tt.DebugNet {
l = dbgListener{l}
}
localAddr := l.Addr().String()
httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr}
tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr}
cfg := &tunnel.ClientConfig{
Identifier: ident,
ServerAddr: tt.ServerAddr().String(),
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
HTTP: httpProxy.Proxy,
TCP: tcpProxy.Proxy,
}),
StateChanges: t.StateChanges,
Debug: testing.Verbose(),
}
// Register tunnel:
//
// - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient])
// - listen on local address and start local server (tt.Listeners[ident][0])
// - register tunnel on tunnel.Server
//
switch t.Type {
case TypeHTTP:
// TODO(rjeczalik): refactor to separate method
h, ok := t.Handler.(http.Handler)
if !ok {
h, ok = t.Handler.(http.HandlerFunc)
if !ok {
fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request))
if !ok {
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
}
h = http.HandlerFunc(fn)
}
}
logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident)
go (&http.Server{Handler: h}).Serve(l)
tt.Server.AddHost(localAddr, ident)
tt.mu.Lock()
tt.Listeners[ident] = [2]net.Listener{l, nil}
tt.mu.Unlock()
if err := tt.addClient(ident, cfg); err != nil {
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
}
logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident)
case TypeTCP:
// TODO(rjeczalik): refactor to separate method
h, ok := t.Handler.(func(net.Conn))
if !ok {
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
}
logf("serving on local %s for TCP tunnel %q", l.Addr(), ident)
go func() {
for {
conn, err := l.Accept()
if err != nil {
log.Printf("failed accepting conn for %q tunnel: %s", ident, err)
return
}
go h(conn)
}
}()
var remote net.Listener
if t.RemoteAddrIdent != "" {
tt.mu.Lock()
remote = tt.Listeners[t.RemoteAddrIdent][1]
tt.mu.Unlock()
} else {
remote, err = net.Listen("tcp", t.RemoteAddr)
if err != nil {
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err)
}
}
// addrIdent holds identifier of client which is going to have registered
// tunnel via (*tunnel.Server).AddAddr
addrIdent := ident
if t.ClientIdent != "" {
tt.Clients[ident] = tt.Clients[t.ClientIdent]
addrIdent = t.ClientIdent
}
tt.Server.AddAddr(remote, t.IP, addrIdent)
tt.mu.Lock()
tt.Listeners[ident] = [2]net.Listener{l, remote}
tt.mu.Unlock()
if _, ok := tt.Clients[ident]; !ok {
if err := tt.addClient(ident, cfg); err != nil {
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
}
}
logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent)
default:
return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type)
}
return true, nil
}
func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error {
if _, ok := tt.Clients[ident]; ok {
return fmt.Errorf("tunnel %q is already being served", ident)
}
c, err := tunnel.NewClient(cfg)
if err != nil {
return err
}
done := make(chan struct{})
tt.Server.OnConnect(ident, func() error {
close(done)
return nil
})
go c.Start()
<-c.StartNotify()
select {
case <-time.After(10 * time.Second):
return errors.New("timed out after 10s waiting on client to establish control conn")
case <-done:
}
tt.Clients[ident] = c
return nil
}
func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error {
if len(tunnels) == 0 {
return errors.New("no tunnels to serve")
}
// Since one tunnels depends on others do 3 passes to start them
// all, each started tunnel is removed from the tunnels map.
// After 3 passes all of them must be started, otherwise the
// configuration is bad:
//
// - first pass starts HTTP tunnels as new client tunnels
// - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent)
// - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent)
//
for i := 0; i < 3; i++ {
if err := tt.popServedDeps(tunnels); err != nil {
return err
}
}
if len(tunnels) != 0 {
unresolved := make([]string, len(tunnels))
for ident := range tunnels {
unresolved = append(unresolved, ident)
}
sort.Strings(unresolved)
return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved)
}
return nil
}
func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error {
for ident, t := range tunnels {
ok, err := tt.serveSingle(ident, t)
if err != nil {
return err
}
if ok {
// Remove already started tunnels so they won't get started again.
delete(tunnels, ident)
tt.Tunnels[ident] = t
}
}
return nil
}
func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) {
tt.mu.Lock()
defer tt.mu.Unlock()
for _, l := range tt.Listeners {
if l[1] == nil {
// this listener does not belong to a TCP tunnel
continue
}
_, remotePort, err := parseHostPort(l[1].Addr().String())
if err != nil {
return "", err
}
if port == remotePort {
return l[0].Addr().String(), nil
}
}
return "", fmt.Errorf("no route for %d port", port)
}
func (tt *TunnelTest) ServerAddr() net.Addr {
return tt.Listeners[""][0].Addr()
}
// Addr gives server endpoint of the TCP tunnel for the given ident.
//
// If the tunnel does not exist or is a HTTP one, TunnelAddr return nil.
func (tt *TunnelTest) Addr(ident string) net.Addr {
l, ok := tt.Listeners[ident]
if !ok {
return nil
}
return l[1].Addr()
}
// Request creates a HTTP request to a server endpoint of the HTTP tunnel
// for the given ident.
//
// If the tunnel does not exist, Request returns nil.
func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request {
l, ok := tt.Listeners[ident]
if !ok {
return nil
}
var raw string
if query != nil {
raw = query.Encode()
}
return &http.Request{
Method: "GET",
URL: &url.URL{
Scheme: "http",
Host: tt.ServerAddr().String(),
Path: "/",
RawQuery: raw,
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Host: l[0].Addr().String(),
}
}
func (tt *TunnelTest) Close() (err error) {
// Close tunnel.Clients.
clients := make(map[*tunnel.Client]struct{})
for _, c := range tt.Clients {
clients[c] = struct{}{}
}
for c := range clients {
err = nonil(err, c.Close())
}
// Stop all TCP/HTTP servers.
listeners := make(map[net.Listener]struct{})
for _, l := range tt.Listeners {
for _, l := range l {
if l != nil {
listeners[l] = struct{}{}
}
}
}
for l := range listeners {
err = nonil(err, l.Close())
}
return err
}

121
lib/tunnel/util.go Normal file
View File

@ -0,0 +1,121 @@
package tunnel
import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"git.xeserv.us/xena/route/lib/tunnel/proto"
"github.com/cenkalti/backoff"
)
// async is a helper function to convert a blocking function to a function
// returning an error. Useful for plugging function closures into select and co
func async(fn func() error) <-chan error {
errChan := make(chan error, 0)
go func() {
select {
case errChan <- fn():
default:
}
close(errChan)
}()
return errChan
}
type expBackoff struct {
mu sync.Mutex
bk *backoff.ExponentialBackOff
}
func newForeverBackoff() *expBackoff {
eb := &expBackoff{
bk: backoff.NewExponentialBackOff(),
}
eb.bk.MaxElapsedTime = 0 // never stops
return eb
}
func (eb *expBackoff) NextBackOff() time.Duration {
eb.mu.Lock()
defer eb.mu.Unlock()
return eb.bk.NextBackOff()
}
func (eb *expBackoff) Reset() {
eb.mu.Lock()
eb.bk.Reset()
eb.mu.Unlock()
}
type callbacks struct {
mu sync.Mutex
name string
funcs map[string]func() error
}
func newCallbacks(name string) *callbacks {
return &callbacks{
name: name,
funcs: make(map[string]func() error),
}
}
func (c *callbacks) add(identifier string, fn func() error) {
c.mu.Lock()
c.funcs[identifier] = fn
c.mu.Unlock()
}
func (c *callbacks) pop(identifier string) (func() error, error) {
c.mu.Lock()
defer c.mu.Unlock()
fn, ok := c.funcs[identifier]
if !ok {
return nil, nil // nop
}
delete(c.funcs, identifier)
if fn == nil {
return nil, fmt.Errorf("nil callback set for %q client", identifier)
}
return fn, nil
}
func (c *callbacks) call(identifier string) error {
fn, err := c.pop(identifier)
if err != nil {
return err
}
if fn == nil {
return nil // nop
}
return fn()
}
// Returns server control url as a string. Reads scheme and remote address from connection.
func controlURL(conn net.Conn) string {
return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath)
}
func scheme(conn net.Conn) (scheme string) {
switch conn.(type) {
case *tls.Conn:
scheme = "https"
default:
scheme = "http"
}
return
}

179
lib/tunnel/virtualaddr.go Normal file
View File

@ -0,0 +1,179 @@
package tunnel
import (
"net"
"strconv"
"sync"
"sync/atomic"
"github.com/koding/logging"
)
type listener struct {
net.Listener
*vaddrOptions
done int32
// ips keeps track of registered clients for ip-based routing;
// when last client is deleted from the ip routing map, we stop
// listening on connections
ips map[string]struct{}
}
type vaddrOptions struct {
connCh chan<- net.Conn
log logging.Logger
}
type vaddrStorage struct {
*vaddrOptions
listeners map[net.Listener]*listener
ports map[int]string // port-based routing: maps port number to identifier
ips map[string]string // ip-based routing: maps ip address to identifier
mu sync.RWMutex
}
func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage {
return &vaddrStorage{
vaddrOptions: opts,
listeners: make(map[net.Listener]*listener),
ports: make(map[int]string),
ips: make(map[string]string),
}
}
func (l *listener) serve() {
for {
conn, err := l.Accept()
if err != nil {
l.log.Error("failue listening on %q: %s", l.Addr(), err)
return
}
if atomic.LoadInt32(&l.done) != 0 {
l.log.Debug("stopped serving %q", l.Addr())
conn.Close()
return
}
l.connCh <- conn
}
}
func (l *listener) localAddr() string {
if addr, ok := l.Addr().(*net.TCPAddr); ok {
if addr.IP.Equal(net.IPv4zero) {
return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port))
}
}
return l.Addr().String()
}
func (l *listener) stop() {
if atomic.CompareAndSwapInt32(&l.done, 0, 1) {
// stop is called when no more connections should be accepted by
// the user-provided listener; as we can't simple close the listener
// to not break the guarantee given by the (*Server).DeleteAddr
// method, we make a dummy connection to break out of serve loop.
// It is safe to make a dummy connection, as either the following
// dial will time out when the listener is busy accepting connections,
// or will get closed immadiately after idle listeners accepts connection
// and returns from the serve loop.
conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout)
if err == nil {
conn.Close()
}
}
}
func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) {
vaddr.mu.Lock()
defer vaddr.mu.Unlock()
lis, ok := vaddr.listeners[l]
if !ok {
lis = vaddr.newListener(l)
vaddr.listeners[l] = lis
go lis.serve()
}
if ip != nil {
lis.ips[ip.String()] = struct{}{}
vaddr.ips[ip.String()] = ident
} else {
vaddr.ports[mustPort(l)] = ident
}
}
func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) {
vaddr.mu.Lock()
defer vaddr.mu.Unlock()
lis, ok := vaddr.listeners[l]
if !ok {
return
}
var stop bool
if ip != nil {
delete(lis.ips, ip.String())
delete(vaddr.ips, ip.String())
stop = len(lis.ips) == 0
} else {
delete(vaddr.ports, mustPort(l))
stop = true
}
// Only stop listening for connections when listener has clients
// registered to tunnel the connections to.
if stop {
lis.stop()
delete(vaddr.listeners, l)
}
}
func (vaddr *vaddrStorage) newListener(l net.Listener) *listener {
return &listener{
Listener: l,
vaddrOptions: vaddr.vaddrOptions,
ips: make(map[string]struct{}),
}
}
func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) {
vaddr.mu.Lock()
defer vaddr.mu.Unlock()
ip, port, err := parseHostPort(conn.LocalAddr().String())
if err != nil {
vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err)
return "", false
}
// First lookup if there's a ip-based route, then try port-base one.
if ident, ok := vaddr.ips[ip]; ok {
return ident, true
}
ident, ok := vaddr.ports[port]
return ident, ok
}
func mustPort(l net.Listener) int {
_, port, err := parseHostPort(l.Addr().String())
if err != nil {
// This can happened when user passed custom type that
// implements net.Listener, which returns ill-formed
// net.Addr value.
panic("ill-formed net.Addr: " + err.Error())
}
return port
}

77
lib/tunnel/virtualhost.go Normal file
View File

@ -0,0 +1,77 @@
package tunnel
import (
"sync"
)
type vhostStorage interface {
// AddHost adds the given host and identifier to the storage
AddHost(host, identifier string)
// DeleteHost deletes the given host
DeleteHost(host string)
// GetHost returns the host name for the given identifier
GetHost(identifier string) (string, bool)
// GetIdentifier returns the identifier for the given host
GetIdentifier(host string) (string, bool)
}
type virtualHost struct {
identifier string
}
// virtualHosts is used for mapping host to users example: host
// "fs-1-fatih.kd.io" belongs to user "arslan"
type virtualHosts struct {
mapping map[string]*virtualHost
sync.Mutex
}
// newVirtualHosts provides an in memory virtual host storage for mapping
// virtual hosts to identifiers.
func newVirtualHosts() *virtualHosts {
return &virtualHosts{
mapping: make(map[string]*virtualHost),
}
}
func (v *virtualHosts) AddHost(host, identifier string) {
v.Lock()
v.mapping[host] = &virtualHost{identifier: identifier}
v.Unlock()
}
func (v *virtualHosts) DeleteHost(host string) {
v.Lock()
delete(v.mapping, host)
v.Unlock()
}
// GetIdentifier returns the identifier associated with the given host
func (v *virtualHosts) GetIdentifier(host string) (string, bool) {
v.Lock()
ht, ok := v.mapping[host]
v.Unlock()
if !ok {
return "", false
}
return ht.identifier, true
}
// GetHost returns the host associated with the given identifier
func (v *virtualHosts) GetHost(identifier string) (string, bool) {
v.Lock()
defer v.Unlock()
for hostname, hst := range v.mapping {
if hst.identifier == identifier {
return hostname, true
}
}
return "", false
}

View File

@ -0,0 +1,69 @@
package tunnel_test
import (
"fmt"
"net/http"
"reflect"
"testing"
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
)
func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) {
conn, err := websocketDial(tt, "http")
if err != nil {
t.Fatalf("Dial()=%s", err)
}
defer conn.Close()
for i := 0; i < n; i++ {
want := &EchoMessage{
Value: fmt.Sprintf("message #%d", i),
Close: i == (n - 1),
}
err := conn.WriteJSON(want)
if err != nil {
t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err)
continue
}
got := &EchoMessage{}
err = conn.ReadJSON(got)
if err != nil {
t.Errorf("(test %s) %d: failed reading: %s", name, i, err)
continue
}
if !reflect.DeepEqual(got, want) {
t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want)
}
}
}
func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
t.Errorf("handler func error: %s", err)
}
}
}
func TestWebsocket(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil))))
if err != nil {
t.Fatal(err)
}
testWebsocket("handlerEchoWS", 100, t, tt)
}
func TestLatencyWebsocket(t *testing.T) {
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep))))
if err != nil {
t.Fatal(err)
}
testWebsocket("handlerLatencyEchoWS", 20, t, tt)
}