diff --git a/lib/tunnel/.travis.yml b/lib/tunnel/.travis.yml new file mode 100644 index 0000000..ca5e23b --- /dev/null +++ b/lib/tunnel/.travis.yml @@ -0,0 +1,19 @@ +language: go + +sudo: false + +addons: + apt: + packages: + - moreutils + +go: + - 1.4.3 + - 1.6.3 + - 1.7 + +script: + - export GOMAXPROCS=$(nproc) + - gofmt -s -l . | ifne false + - go build ./... + - go test -race ./... diff --git a/lib/tunnel/LICENSE b/lib/tunnel/LICENSE new file mode 100644 index 0000000..5f26881 --- /dev/null +++ b/lib/tunnel/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2015 The Koding Authors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of Koding Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/lib/tunnel/README.md b/lib/tunnel/README.md new file mode 100644 index 0000000..debc0ca --- /dev/null +++ b/lib/tunnel/README.md @@ -0,0 +1,91 @@ +# Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel) + +Tunnel is a server/client package that enables to proxy public connections to +your local machine over a tunnel connection from the local machine to the +public server. What this means is, you can share your localhost even if it +doesn't have a Public IP or if it's not reachable from outside. + +It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to +multiplex connections between server and client. + +The project is under active development, please vendor it if you want to use it. + +# Usage + +The tunnel package consists of two parts. The `server` and the `client`. + +Server is the public facing part. It's type that satisfies the `http.Handler`. +So it's easily pluggable into existing servers. + + +Let assume that you setup your DNS service so all `*.example.com` domains route +to your server at the public IP `203.0.113.0`. Let us first create the server +part: + +```go +package main + +import ( + "net/http" + + "github.com/koding/tunnel" +) + +func main() { + cfg := &tunnel.ServerConfig{} + server, _ := tunnel.NewServer(cfg) + server.AddHost("sub.example.com", "1234") + http.ListenAndServe(":80", server) +} +``` + +Once you create the `server`, you just plug it into your server. The only +detail here is to map a virtualhost to a secret token. The secret token is the +only part that needs to be known for the client side. + +Let us now create the client side part: + +```go +package main + +import "github.com/koding/tunnel" + +func main() { + cfg := &tunnel.ClientConfig{ + Identifier: "1234", + ServerAddr: "203.0.113.0:80", + } + + client, err := tunnel.NewClient(cfg) + if err != nil { + panic(err) + } + + client.Start() +} +``` + +The `Start()` method is by default blocking. As you see you, we just passed the +server address and the secret token. + +Now whenever someone hit `sub.example.com`, the request will be proxied to the +machine where client is running and hit the local server running `127.0.0.1:80` +(assuming there is one). If someone hits `sub.example.com:3000` (assume your +server is running at this port), it'll be routed to `127.0.0.1:3000` + +That's it. + +There are many options that can be changed, such as a static local address for +your client. Have alook at the +[documentation](http://godoc.org/github.com/koding/tunnel) + + +# Protocol + +The server/client protocol is written in the [spec.md](spec.md) file. Please +have a look for more detail. + + +## License + +The BSD 3-Clause License - see LICENSE for more details diff --git a/lib/tunnel/client.go b/lib/tunnel/client.go new file mode 100644 index 0000000..acd9c36 --- /dev/null +++ b/lib/tunnel/client.go @@ -0,0 +1,565 @@ +package tunnel + +import ( + "bufio" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/koding/logging" + "git.xeserv.us/xena/route/lib/tunnel/proto" + + "github.com/hashicorp/yamux" +) + +//go:generate stringer -type ClientState + +// ErrRedialAborted is emitted on ClientClosed event, when backoff policy +// used by a client decided no more reconnection attempts must be made. +var ErrRedialAborted = errors.New("unable to restore the connection, aborting") + +// ClientState represents client connection state to tunnel server. +type ClientState uint32 + +// ClientState enumeration. +const ( + ClientUnknown ClientState = iota + ClientStarted + ClientConnecting + ClientConnected + ClientDisconnected + ClientClosed // keep it always last +) + +// ClientStateChange represents single client state transition. +type ClientStateChange struct { + Identifier string + Previous ClientState + Current ClientState + Error error +} + +// Strings implements the fmt.Stringer interface. +func (cs *ClientStateChange) String() string { + if cs.Error != nil { + return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error) + } + return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current) +} + +// Backoff defines behavior of staggering reconnection retries. +type Backoff interface { + // Next returns the duration to sleep before retrying reconnections. + // If the returned value is negative, the retry is aborted. + NextBackOff() time.Duration + + // Reset is used to signal a reconnection was successful and next + // call to Next should return desired time duration for 1st reconnection + // attempt. + Reset() +} + +// Client is responsible for creating a control connection to a tunnel server, +// creating new tunnels and proxy them to tunnel server. +type Client struct { + // underlying yamux session + session *yamux.Session + + // config holds the ClientConfig + config *ClientConfig + + // yamuxConfig is passed to new yamux.Session's + yamuxConfig *yamux.Config + + // proxy handles local server communication. + proxy ProxyFunc + + // startNotify is a chanel user can get to be notified when client is + // connected to the server. The preferred way of doing this however, + // would be using StateChanges in ClientConfig where user can provide + // his own channel. + startNotify chan bool + // closed is a flag set when client calls Close() and quits. + closed bool + // closedMu guards both closed flag and startNotify channel. Since library + // owns the channel it's cleared when trying to reconnect. + closedMu sync.RWMutex + + reqWg sync.WaitGroup + ctrlWg sync.WaitGroup + + state ClientState + + // redialBackoff is used to reconnect in exponential backoff intervals + redialBackoff Backoff + + log logging.Logger +} + +// ClientConfig defines the configuration for the Client +type ClientConfig struct { + // Identifier is the secret token that needs to be passed to the server. + // Required if FetchIdentifier is not set. + Identifier string + + // FetchIdentifier can be used to fetch identifier. Required if Identifier + // is not set. + FetchIdentifier func() (string, error) + + // ServerAddr defines the TCP address of the tunnel server to be connected. + // Required if FetchServerAddr is not set. + ServerAddr string + + // FetchServerAddr can be used to fetch tunnel server address. + // Required if ServerAddress is not set. + FetchServerAddr func() (string, error) + + // Dial provides custom transport layer for client server communication. + // + // If nil, default implementation is to return net.Dial("tcp", address). + // + // It can be used for connection monitoring, setting different timeouts or + // securing the connection. + Dial func(network, address string) (net.Conn, error) + + // Proxy defines custom proxing logic. This is optional extension point + // where you can provide your local server selection or communication rules. + Proxy ProxyFunc + + // StateChanges receives state transition details each time client + // connection state changes. The channel is expected to be sufficiently + // buffered to keep up with event pace. + // + // If nil, no information about state transitions are dispatched + // by the library. + StateChanges chan<- *ClientStateChange + + // Backoff is used to control behavior of staggering reconnection loop. + // + // If nil, default backoff policy is used which makes a client to never + // give up on reconnection. + // + // If custom backoff is used, client will emit ErrRedialAborted set + // with ClientClosed event when no more reconnection atttemps should + // be made. + Backoff Backoff + + // YamuxConfig defines the config which passed to every new yamux.Session. If nil + // yamux.DefaultConfig() is used. + YamuxConfig *yamux.Config + + // Log defines the logger. If nil a default logging.Logger is used. + Log logging.Logger + + // Debug enables debug mode, enable only if you want to debug the server. + Debug bool + + // DEPRECATED: + + // LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details. + LocalAddr string + + // FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details. + FetchLocalAddr func(port int) (string, error) +} + +// verify is used to verify the ClientConfig +func (c *ClientConfig) verify() error { + if c.ServerAddr == "" && c.FetchServerAddr == nil { + return errors.New("neither ServerAddr nor FetchServerAddr is set") + } + + if c.Identifier == "" && c.FetchIdentifier == nil { + return errors.New("neither Identifier nor FetchIdentifier is set") + } + + if c.YamuxConfig != nil { + if err := yamux.VerifyConfig(c.YamuxConfig); err != nil { + return err + } + } + + if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) { + return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set") + } + + return nil +} + +// NewClient creates a new tunnel that is established between the serverAddr +// and localAddr. It exits if it can't create a new control connection to the +// server. If localAddr is empty client will always try to proxy to a local +// port. +func NewClient(cfg *ClientConfig) (*Client, error) { + if err := cfg.verify(); err != nil { + return nil, err + } + + yamuxConfig := yamux.DefaultConfig() + if cfg.YamuxConfig != nil { + yamuxConfig = cfg.YamuxConfig + } + + var proxy = DefaultProxy + if cfg.Proxy != nil { + proxy = cfg.Proxy + } + // DEPRECATED API SUPPORT + if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil { + var f ProxyFuncs + if cfg.LocalAddr != "" { + f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy + f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy + } + if cfg.FetchLocalAddr != nil { + f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy + } + proxy = Proxy(f) + } + + var bo Backoff = newForeverBackoff() + if cfg.Backoff != nil { + bo = cfg.Backoff + } + + log := newLogger("tunnel-client", cfg.Debug) + if cfg.Log != nil { + log = cfg.Log + } + + client := &Client{ + config: cfg, + yamuxConfig: yamuxConfig, + proxy: proxy, + startNotify: make(chan bool, 1), + redialBackoff: bo, + log: log, + } + + return client, nil +} + +// Start starts the client and connects to the server with the identifier. +// client.FetchIdentifier() will be used if it's not nil. It's supports +// reconnecting with exponential backoff intervals when the connection to the +// server disconnects. Call client.Close() to shutdown the client completely. A +// successful connection will cause StartNotify() to receive a value. +func (c *Client) Start() { + fetchIdent := func() (string, error) { + if c.config.FetchIdentifier != nil { + return c.config.FetchIdentifier() + } + + return c.config.Identifier, nil + } + + fetchServerAddr := func() (string, error) { + if c.config.FetchServerAddr != nil { + return c.config.FetchServerAddr() + } + + return c.config.ServerAddr, nil + } + + c.changeState(ClientStarted, nil) + + c.redialBackoff.Reset() + var lastErr error + for { + prev := c.changeState(ClientConnecting, lastErr) + + if c.isRetry(prev) { + dur := c.redialBackoff.NextBackOff() + if dur < 0 { + c.setClosed(true) + c.changeState(ClientClosed, ErrRedialAborted) + return + } + + time.Sleep(dur) + + // exit if closed + if c.isClosed() { + c.changeState(ClientClosed, lastErr) + return + } + } + + identifier, err := fetchIdent() + if err != nil { + lastErr = err + c.log.Critical("client fetch identifier error: %s", err) + continue + } + + serverAddr, err := fetchServerAddr() + if err != nil { + lastErr = err + c.log.Critical("client fetch server address error: %s", err) + continue + } + + c.setClosed(false) + + if err := c.connect(identifier, serverAddr); err != nil { + lastErr = err + c.log.Debug("client connect error: %s", err) + } + + // exit if closed + if c.isClosed() { + c.changeState(ClientClosed, lastErr) + return + } + } +} + +// Close closes the client and shutdowns the connection to the tunnel server +func (c *Client) Close() error { + defer c.setClosed(true) + + if c.session == nil { + return errors.New("session is not initialized") + } + + // wait until all connections are finished + waitCh := make(chan struct{}) + go func() { + if err := c.session.GoAway(); err != nil { + c.log.Debug("Session go away failed: %s", err) + } + + c.reqWg.Wait() + close(waitCh) + }() + select { + case <-waitCh: + // ok + case <-time.After(time.Second * 10): + c.log.Info("Timeout waiting for connections to finish") + } + + if err := c.session.Close(); err != nil { + return err + } + + return nil +} + +// isClosed securely checks if client is marked as closed. +func (c *Client) isClosed() bool { + c.closedMu.RLock() + defer c.closedMu.RUnlock() + return c.closed +} + +// setClosed securely marks client as closed (or not closed). If not closed +// also empty the value inside the startNotify channel by retrieving it (if any), +// so it doesn't block during connect, when the client was closed and started again, +// and startNotify was never listened to. +func (c *Client) setClosed(closed bool) { + c.closedMu.Lock() + defer c.closedMu.Unlock() + c.closed = closed + + if !closed { + // clear channel + select { + case <-c.startNotify: + default: + } + } +} + +// startNotifyIfNeeded sends ok to startNotify channel if it's listened to. +// This function is called by connect when connection was successful. +func (c *Client) startNotifyIfNeeded() { + c.closedMu.RLock() + if !c.closed { + c.log.Debug("sending ok to startNotify chan") + select { + case c.startNotify <- true: + default: + // reaching here means the client never read the signal via + // StartNotify(). This is OK, we shouldn't except it the consumer + // to read from this channel. It's optional, so we just drop the + // signal. + c.log.Debug("startNotify message was dropped") + } + } + c.closedMu.RUnlock() +} + +// StartNotify returns a channel that receives a single value when the client +// established a successful connection to the server. +func (c *Client) StartNotify() <-chan bool { + return c.startNotify +} + +func (c *Client) changeState(state ClientState, err error) (prev ClientState) { + prev = ClientState(atomic.LoadUint32((*uint32)(&c.state))) + + if c.config.StateChanges != nil { + change := &ClientStateChange{ + Identifier: c.config.Identifier, + Previous: ClientState(prev), + Current: state, + Error: err, + } + + select { + case c.config.StateChanges <- change: + default: + c.log.Warning("Dropping state change due to slow reader: %s", change) + } + } + + atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state)) + + return prev +} + +func (c *Client) isRetry(state ClientState) bool { + return state != ClientStarted && state != ClientClosed +} + +func (c *Client) connect(identifier, serverAddr string) error { + c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier) + + conn, err := c.dial(serverAddr) + if err != nil { + return err + } + + remoteURL := controlURL(conn) + c.log.Debug("CONNECT to %q", remoteURL) + req, err := http.NewRequest("CONNECT", remoteURL, nil) + if err != nil { + return fmt.Errorf("error creating request to %s: %s", remoteURL, err) + } + + req.Header.Set(proto.ClientIdentifierHeader, identifier) + + c.log.Debug("Writing request to TCP: %+v", req) + + if err := req.Write(conn); err != nil { + return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err) + } + + c.log.Debug("Reading response from TCP") + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected { + out, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err) + } + + return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out)) + } + + c.ctrlWg.Wait() // wait until previous listenControl observes disconnection + + c.session, err = yamux.Client(conn, c.yamuxConfig) + if err != nil { + return fmt.Errorf("session initialization failed: %s", err) + } + + var stream net.Conn + openStream := func() error { + // this is blocking until client opens a session to us + stream, err = c.session.Open() + return err + } + + // if we don't receive anything from the server, we'll timeout + select { + case err := <-async(openStream): + if err != nil { + return fmt.Errorf("waiting for session to open failed: %s", err) + } + case <-time.After(time.Second * 10): + if stream != nil { + stream.Close() + } + return errors.New("timeout opening session") + } + + if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil { + return fmt.Errorf("writing handshake request failed: %s", err) + } + + buf := make([]byte, len(proto.HandshakeResponse)) + if _, err := stream.Read(buf); err != nil { + return fmt.Errorf("reading handshake response failed: %s", err) + } + + if string(buf) != proto.HandshakeResponse { + return fmt.Errorf("invalid handshake response, received: %s", string(buf)) + } + + ct := newControl(stream) + c.log.Debug("client has started successfully") + c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff + + c.startNotifyIfNeeded() + + return c.listenControl(ct) +} + +func (c *Client) dial(serverAddr string) (net.Conn, error) { + if c.config.Dial != nil { + return c.config.Dial("tcp", serverAddr) + } + + return net.Dial("tcp", serverAddr) +} + +func (c *Client) listenControl(ct *control) error { + c.ctrlWg.Add(1) + defer c.ctrlWg.Done() + + c.changeState(ClientConnected, nil) + + for { + var msg proto.ControlMessage + if err := ct.dec.Decode(&msg); err != nil { + c.reqWg.Wait() // wait until all requests are finished + c.session.GoAway() + c.session.Close() + c.changeState(ClientDisconnected, err) + + return fmt.Errorf("failure decoding control message: %s", err) + } + + c.log.Debug("Received control msg %+v", msg) + c.log.Debug("Opening a new stream from server session") + + remote, err := c.session.Open() + if err != nil { + return err + } + + isHTTP := msg.Protocol == proto.HTTP + if isHTTP { + c.reqWg.Add(1) + } + go func() { + c.proxy(remote, &msg) + if isHTTP { + c.reqWg.Done() + } + remote.Close() + }() + } +} diff --git a/lib/tunnel/clientstate_string.go b/lib/tunnel/clientstate_string.go new file mode 100644 index 0000000..bac8019 --- /dev/null +++ b/lib/tunnel/clientstate_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type ClientState"; DO NOT EDIT + +package tunnel + +import "fmt" + +const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed" + +var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87} + +func (i ClientState) String() string { + if i >= ClientState(len(_ClientState_index)-1) { + return fmt.Sprintf("ClientState(%d)", i) + } + return _ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]] +} diff --git a/lib/tunnel/control.go b/lib/tunnel/control.go new file mode 100644 index 0000000..5abc387 --- /dev/null +++ b/lib/tunnel/control.go @@ -0,0 +1,110 @@ +package tunnel + +import ( + "encoding/json" + "errors" + "net" + "sync" +) + +var errControlClosed = errors.New("control connection is closed") + +type control struct { + // enc and dec are responsible for encoding and decoding json values forth + // and back + enc *json.Encoder + dec *json.Decoder + + // underlying connection responsible for encoder and decoder + nc net.Conn + + // identifier associated with this control + identifier string + + mu sync.Mutex // guards the following + closed bool // if Close() and quits +} + +func newControl(nc net.Conn) *control { + c := &control{ + enc: json.NewEncoder(nc), + dec: json.NewDecoder(nc), + nc: nc, + } + + return c +} + +func (c *control) send(v interface{}) error { + if c.enc == nil { + return errors.New("encoder is not initialized") + } + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return errControlClosed + } + c.mu.Unlock() + + return c.enc.Encode(v) +} + +func (c *control) recv(v interface{}) error { + if c.dec == nil { + return errors.New("decoder is not initialized") + } + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return errControlClosed + } + c.mu.Unlock() + + return c.dec.Decode(v) +} + +func (c *control) Close() error { + if c.nc == nil { + return nil + } + + c.mu.Lock() + c.closed = true + c.mu.Unlock() + + return c.nc.Close() +} + +type controls struct { + sync.Mutex + controls map[string]*control +} + +func newControls() *controls { + return &controls{ + controls: make(map[string]*control), + } +} + +func (c *controls) getControl(identifier string) (*control, bool) { + c.Lock() + control, ok := c.controls[identifier] + c.Unlock() + return control, ok +} + +func (c *controls) addControl(identifier string, control *control) { + control.identifier = identifier + + c.Lock() + c.controls[identifier] = control + c.Unlock() +} + +func (c *controls) deleteControl(identifier string) { + c.Lock() + delete(c.controls, identifier) + c.Unlock() +} diff --git a/lib/tunnel/helper_test.go b/lib/tunnel/helper_test.go new file mode 100644 index 0000000..adab80a --- /dev/null +++ b/lib/tunnel/helper_test.go @@ -0,0 +1,263 @@ +package tunnel_test + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/url" + "os" + "time" + + "git.xeserv.us/xena/route/lib/tunnel" + "git.xeserv.us/xena/route/lib/tunnel/tunneltest" + + "github.com/gorilla/websocket" +) + +func init() { + rand.Seed(time.Now().UnixNano() + int64(os.Getpid())) +} + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +type EchoMessage struct { + Value string `json:"value,omitempty"` + Close bool `json:"close,omitempty"` +} + +var timeout = 10 * time.Second + +var dialer = &websocket.Dialer{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: timeout, + NetDial: func(_, addr string) (net.Conn, error) { + return net.DialTimeout("tcp4", addr, timeout) + }, +} + +func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) { + req := tt.Request("http", url.Values{"echo": []string{echo}}) + if req == nil { + return "", fmt.Errorf(`tunnel "http" does not exist`) + } + + req.Close = rand.Int()%2 == 0 + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + p, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return string(bytes.TrimSpace(p)), nil +} + +func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) { + return echoTCPIdent(tt, echo, "tcp") +} + +func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) { + addr := tt.Addr(ident) + if addr == nil { + return "", fmt.Errorf("tunnel %q does not exist", ident) + } + s := addr.String() + ip := tt.Tunnels[ident].IP + if ip != nil { + _, port, err := net.SplitHostPort(s) + if err != nil { + return "", err + } + s = net.JoinHostPort(ip.String(), port) + } + + c, err := dialTCP(s) + if err != nil { + return "", err + } + + c.out <- echo + + select { + case reply := <-c.in: + return reply, nil + case <-time.After(tcpTimeout): + return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout) + } +} + +func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) { + req := tt.Request(ident, nil) + if req == nil { + return nil, fmt.Errorf("no client found for ident %q", ident) + } + + h := http.Header{"Host": {req.Host}} + wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr()) + + conn, _, err := dialer.Dial(wsurl, h) + return conn, err +} + +func sleep() { + time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond) +} + +func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) (e error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return err + } + defer func() { + err := conn.Close() + if e == nil { + e = err + } + }() + + if sleepFn != nil { + sleepFn() + } + + for { + var msg EchoMessage + err := conn.ReadJSON(&msg) + if err != nil { + return fmt.Errorf("ReadJSON error: %s", err) + } + + if sleepFn != nil { + sleepFn() + } + + err = conn.WriteJSON(&msg) + if err != nil { + return fmt.Errorf("WriteJSON error: %s", err) + } + + if msg.Close { + return nil + } + } + } +} + +func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.URL.Query().Get("echo")) +} + +func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) { + sleep() + handlerEchoHTTP(w, r) +} + +func handlerEchoTCP(conn net.Conn) { + io.Copy(conn, conn) +} + +func handlerLatencyEchoTCP(conn net.Conn) { + sleep() + handlerEchoTCP(conn) +} + +var tcpTimeout = 10 * time.Second + +type tcpClient struct { + conn net.Conn + scanner *bufio.Scanner + in chan string + out chan string +} + +func (c *tcpClient) loop() { + for out := range c.out { + if _, err := fmt.Fprintln(c.conn, out); err != nil { + log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err) + return + } + + if !c.scanner.Scan() { + log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err()) + return + } + + c.in <- c.scanner.Text() + } +} + +func (c *tcpClient) Close() error { + close(c.out) + return c.conn.Close() +} + +func dialTCP(addr string) (*tcpClient, error) { + conn, err := net.DialTimeout("tcp", addr, tcpTimeout) + if err != nil { + return nil, err + } + + c := &tcpClient{ + conn: conn, + scanner: bufio.NewScanner(conn), + in: make(chan string, 1), + out: make(chan string, 1), + } + + go c.loop() + + return c, nil +} + +func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel { + return singleRecHTTP(handler, nil) +} + +func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel { + return map[string]*tunneltest.Tunnel{ + "http": { + Type: tunneltest.TypeHTTP, + LocalAddr: "127.0.0.1:0", + Handler: handler, + StateChanges: stateChanges, + }, + } +} + +func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel { + return singleRecTCP(handler, nil) +} + +func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel { + return map[string]*tunneltest.Tunnel{ + "http": { + Type: tunneltest.TypeHTTP, + LocalAddr: "127.0.0.1:0", + Handler: handlerEchoHTTP, + StateChanges: stateChanges, + }, + "tcp": { + Type: tunneltest.TypeTCP, + ClientIdent: "http", + LocalAddr: "127.0.0.1:0", + RemoteAddr: "127.0.0.1:0", + Handler: handler, + }, + } +} diff --git a/lib/tunnel/httpproxy.go b/lib/tunnel/httpproxy.go new file mode 100644 index 0000000..2e18ab1 --- /dev/null +++ b/lib/tunnel/httpproxy.go @@ -0,0 +1,115 @@ +package tunnel + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + + "github.com/koding/logging" + "git.xeserv.us/xena/route/lib/tunnel/proto" +) + +var ( + httpLog = logging.NewLogger("http") +) + +// HTTPProxy forwards HTTP traffic. +// +// When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort +// where incomingPort is control message LocalPort. +// Usually this is tunnel server's public exposed Port. +// This behaviour can be changed by setting LocalAddr or FetchLocalAddr. +// FetchLocalAddr takes precedence over LocalAddr. +// +// When connection to local server cannot be established proxy responds with http error message. +type HTTPProxy struct { + // LocalAddr defines the TCP address of the local server. + // This is optional if you want to specify a single TCP address. + LocalAddr string + // FetchLocalAddr is used for looking up TCP address of the server. + // This is optional if you want to specify a dynamic TCP address based on incommig port. + FetchLocalAddr func(port int) (string, error) + // ErrorResp is custom response send to tunnel server when client cannot + // establish connection to local server. If not set a default "no local server" + // response is sent. + ErrorResp *http.Response + // Log is a custom logger that can be used for the proxy. + // If not set a "http" logger is used. + Log logging.Logger +} + +// Proxy is a ProxyFunc. +func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) { + if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS { + panic("Proxy mismatch") + } + + var log = p.log() + + var port = msg.LocalPort + if port == 0 { + port = 80 + } + + var localAddr = fmt.Sprintf("127.0.0.1:%d", port) + if p.LocalAddr != "" { + localAddr = p.LocalAddr + } else if p.FetchLocalAddr != nil { + l, err := p.FetchLocalAddr(msg.LocalPort) + if err != nil { + log.Warning("Failed to get custom local address: %s", err) + p.sendError(remote) + return + } + localAddr = l + } + + log.Debug("Dialing local server %q", localAddr) + local, err := net.DialTimeout("tcp", localAddr, defaultTimeout) + if err != nil { + log.Error("Dialing local server %q failed: %s", localAddr, err) + p.sendError(remote) + return + } + + Join(local, remote, log) +} + +func (p *HTTPProxy) sendError(remote net.Conn) { + var w = noLocalServer() + if p.ErrorResp != nil { + w = p.ErrorResp + } + + buf := new(bytes.Buffer) + w.Write(buf) + if _, err := io.Copy(remote, buf); err != nil { + var log = p.log() + log.Debug("Copy in-mem response error: %s", err) + } + + remote.Close() +} + +func noLocalServer() *http.Response { + body := bytes.NewBufferString("no local server") + return &http.Response{ + Status: http.StatusText(http.StatusServiceUnavailable), + StatusCode: http.StatusServiceUnavailable, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: ioutil.NopCloser(body), + ContentLength: int64(body.Len()), + } +} + +func (p *HTTPProxy) log() logging.Logger { + if p.Log != nil { + return p.Log + } + return httpLog +} diff --git a/lib/tunnel/proto/control_msg.go b/lib/tunnel/proto/control_msg.go new file mode 100644 index 0000000..283fcd9 --- /dev/null +++ b/lib/tunnel/proto/control_msg.go @@ -0,0 +1,26 @@ +package proto + +// ControlMessage is sent from server to client to establish tunneled connection. +type ControlMessage struct { + Action Action `json:"action"` + Protocol Type `json:"transportProtocol"` + LocalPort int `json:"localPort"` +} + +// Action represents type of ControlMsg request. +type Action int + +// ControlMessage actions. +const ( + RequestClientSession Action = iota + 1 +) + +// Type represents tunneled connection type. +type Type int + +// ControlMessage protocols. +const ( + HTTP Type = iota + 1 + TCP + WS +) diff --git a/lib/tunnel/proto/proto.go b/lib/tunnel/proto/proto.go new file mode 100644 index 0000000..7321c33 --- /dev/null +++ b/lib/tunnel/proto/proto.go @@ -0,0 +1,19 @@ +// Package proto defines tunnel client server communication protocol. +package proto + +const ( + // ControlPath is http.Handler url path for control connection. + ControlPath = "/_controlPath/" + + // ClientIdentifierHeader is header carrying information about tunnel identifier. + ClientIdentifierHeader = "X-KTunnel-Identifier" + + // control messages + + // Connected is message sent by server to client when control connection was established. + Connected = "200 Connected to Tunnel" + // HandshakeRequest is hello message sent by client to server. + HandshakeRequest = "controlHandshake" + // HandshakeResponse is response to HandshakeRequest sent by server to client. + HandshakeResponse = "controlOk" +) diff --git a/lib/tunnel/proxy.go b/lib/tunnel/proxy.go new file mode 100644 index 0000000..b37dd33 --- /dev/null +++ b/lib/tunnel/proxy.go @@ -0,0 +1,101 @@ +package tunnel + +import ( + "io" + "net" + "sync" + + "github.com/koding/logging" + "git.xeserv.us/xena/route/lib/tunnel/proto" +) + +// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back. +type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage) + +var ( + // DefaultProxyFuncs holds global default proxy functions for all transport protocols. + DefaultProxyFuncs = ProxyFuncs{ + HTTP: new(HTTPProxy).Proxy, + TCP: new(TCPProxy).Proxy, + WS: new(HTTPProxy).Proxy, + } + // DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs. + DefaultProxy = Proxy(ProxyFuncs{}) +) + +// ProxyFuncs is a collection of ProxyFunc. +type ProxyFuncs struct { + // HTTP is custom implementation of HTTP proxing. + HTTP ProxyFunc + // TCP is custom implementation of TCP proxing. + TCP ProxyFunc + // WS is custom implementation of web socket proxing. + WS ProxyFunc +} + +// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs. +func Proxy(p ProxyFuncs) ProxyFunc { + return func(remote net.Conn, msg *proto.ControlMessage) { + var f ProxyFunc + switch msg.Protocol { + case proto.HTTP: + f = DefaultProxyFuncs.HTTP + if p.HTTP != nil { + f = p.HTTP + } + case proto.TCP: + f = DefaultProxyFuncs.TCP + if p.TCP != nil { + f = p.TCP + } + case proto.WS: + f = DefaultProxyFuncs.WS + if p.WS != nil { + f = p.WS + } + } + + if f == nil { + logging.Error("Could not determine proxy function for %v", msg) + remote.Close() + } + + f(remote, msg) + } +} + +// Join copies data between local and remote connections. +// It reads from one connection and writes to the other. +// It's a building block for ProxyFunc implementations. +func Join(local, remote net.Conn, log logging.Logger) { + var wg sync.WaitGroup + wg.Add(2) + + transfer := func(side string, dst, src net.Conn) { + log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr()) + + n, err := io.Copy(dst, src) + if err != nil { + log.Error("%s: copy error: %s", side, err) + } + + if err := src.Close(); err != nil { + log.Debug("%s: close error: %s", side, err) + } + + // not for yamux streams, but for client to local server connections + if d, ok := dst.(*net.TCPConn); ok { + if err := d.CloseWrite(); err != nil { + log.Debug("%s: closeWrite error: %s", side, err) + } + + } + wg.Done() + log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n) + } + + go transfer("remote to local", local, remote) + go transfer("local to remote", remote, local) + + wg.Wait() +} diff --git a/lib/tunnel/server.go b/lib/tunnel/server.go new file mode 100644 index 0000000..b733320 --- /dev/null +++ b/lib/tunnel/server.go @@ -0,0 +1,755 @@ +// Package tunnel is a server/client package that enables to proxy public +// connections to your local machine over a tunnel connection from the local +// machine to the public server. +package tunnel + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/koding/logging" + "git.xeserv.us/xena/route/lib/tunnel/proto" + + "github.com/hashicorp/yamux" +) + +var ( + errNoClientSession = errors.New("no client session established") + defaultTimeout = 10 * time.Second +) + +// Server is responsible for proxying public connections to the client over a +// tunnel connection. It also listens to control messages from the client. +type Server struct { + // pending contains the channel that is associated with each new tunnel request. + pending map[string]chan net.Conn + // pendingMu protects the pending map. + pendingMu sync.Mutex + + // sessions contains a session per virtual host. + // Sessions provides multiplexing over one connection. + sessions map[string]*yamux.Session + // sessionsMu protects sessions. + sessionsMu sync.Mutex + + // controls contains the control connection from the client to the server. + controls *controls + + // virtualHosts is used to map public hosts to remote clients. + virtualHosts vhostStorage + + // virtualAddrs. + virtualAddrs *vaddrStorage + + // connCh is used to publish accepted connections for tcp tunnels. + connCh chan net.Conn + + // onConnectCallbacks contains client callbacks called when control + // session is established for a client with given identifier. + onConnectCallbacks *callbacks + + // onDisconnectCallbacks contains client callbacks called when control + // session is closed for a client with given identifier. + onDisconnectCallbacks *callbacks + + // states represents current clients' connections state. + states map[string]ClientState + // statesMu protects states. + statesMu sync.RWMutex + // stateCh notifies receiver about client state changes. + stateCh chan<- *ClientStateChange + + // httpDirector is provided by ServerConfig, if not nil decorates http requests + // before forwarding them to client. + httpDirector func(*http.Request) + + // yamuxConfig is passed to new yamux.Session's + yamuxConfig *yamux.Config + + log logging.Logger +} + +// ServerConfig defines the configuration for the Server +type ServerConfig struct { + // StateChanges receives state transition details each time client + // connection state changes. The channel is expected to be sufficiently + // buffered to keep up with event pace. + // + // If nil, no information about state transitions are dispatched + // by the library. + StateChanges chan<- *ClientStateChange + + // Director is a function that modifies HTTP request into a new HTTP request + // before sending to client. If nil no modifications are done. + Director func(*http.Request) + + // Debug enables debug mode, enable only if you want to debug the server + Debug bool + + // Log defines the logger. If nil a default logging.Logger is used. + Log logging.Logger + + // YamuxConfig defines the config which passed to every new yamux.Session. If nil + // yamux.DefaultConfig() is used. + YamuxConfig *yamux.Config +} + +// NewServer creates a new Server. The defaults are used if config is nil. +func NewServer(cfg *ServerConfig) (*Server, error) { + yamuxConfig := yamux.DefaultConfig() + if cfg.YamuxConfig != nil { + if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil { + return nil, err + } + + yamuxConfig = cfg.YamuxConfig + } + + log := newLogger("tunnel-server", cfg.Debug) + if cfg.Log != nil { + log = cfg.Log + } + + connCh := make(chan net.Conn) + + opts := &vaddrOptions{ + connCh: connCh, + log: log, + } + + s := &Server{ + pending: make(map[string]chan net.Conn), + sessions: make(map[string]*yamux.Session), + onConnectCallbacks: newCallbacks("OnConnect"), + onDisconnectCallbacks: newCallbacks("OnDisconnect"), + virtualHosts: newVirtualHosts(), + virtualAddrs: newVirtualAddrs(opts), + controls: newControls(), + states: make(map[string]ClientState), + stateCh: cfg.StateChanges, + httpDirector: cfg.Director, + yamuxConfig: yamuxConfig, + connCh: connCh, + log: log, + } + + go s.serveTCP() + + return s, nil +} + +// ServeHTTP is a tunnel that creates an http/websocket tunnel between a +// public connection and the client connection. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if the user didn't add the control and tunnel handler manually, we'll + // going to infer and call the respective path handlers. + switch path.Clean(r.URL.Path) + "/" { + case proto.ControlPath: + s.checkConnect(s.controlHandler).ServeHTTP(w, r) + return + } + + if err := s.handleHTTP(w, r); err != nil { + if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily + s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err) + } + http.Error(w, err.Error(), http.StatusBadGateway) + } +} + +// handleHTTP handles a single HTTP request +func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error { + s.log.Debug("HandleHTTP request:") + s.log.Debug("%v", r) + + if s.httpDirector != nil { + s.httpDirector(r) + } + + hostPort := strings.ToLower(r.Host) + if hostPort == "" { + return errors.New("request host is empty") + } + + // if someone hits foo.example.com:8080, this should be proxied to + // localhost:8080, so send the port to the client so it knows how to proxy + // correctly. If no port is available, it's up to client how to interpret it + host, port, err := parseHostPort(hostPort) + if err != nil { + // no need to return, just continue lazily, port will be 0, which in + // our case will be proxied to client's local servers port 80 + s.log.Debug("No port available for %q, sending port 80 to client", hostPort) + } + + // get the identifier associated with this host + identifier, ok := s.getIdentifier(hostPort) + if !ok { + // fallback to host + identifier, ok = s.getIdentifier(host) + if !ok { + return fmt.Errorf("no virtual host available for %q", hostPort) + } + } + + if isWebsocketConn(r) { + s.log.Debug("handling websocket connection") + + return s.handleWSConn(w, r, identifier, port) + } + + stream, err := s.dial(identifier, proto.HTTP, port) + if err != nil { + return err + } + defer func() { + s.log.Debug("Closing stream") + stream.Close() + }() + + if err := r.Write(stream); err != nil { + return err + } + + s.log.Debug("Session opened to client, writing request to client") + resp, err := http.ReadResponse(bufio.NewReader(stream), r) + if err != nil { + return fmt.Errorf("read from tunnel: %s", err.Error()) + } + + defer func() { + if resp.Body != nil { + if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF { + s.log.Error("resp.Body Close error: %s", err.Error()) + } + } + }() + + s.log.Debug("Response received, writing back to public connection: %+v", resp) + + copyHeader(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + + if _, err := io.Copy(w, resp.Body); err != nil { + if err == io.ErrUnexpectedEOF { + s.log.Debug("Client closed the connection, couldn't copy response") + } else { + s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers + } + } + + return nil +} + +func (s *Server) serveTCP() { + for conn := range s.connCh { + go s.serveTCPConn(conn) + } +} + +func (s *Server) serveTCPConn(conn net.Conn) { + err := s.handleTCPConn(conn) + if err != nil { + s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err) + conn.Close() + } +} + +func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error { + hj, ok := w.(http.Hijacker) + if !ok { + return errors.New("webserver doesn't support hijacking") + } + + conn, _, err := hj.Hijack() + if err != nil { + return fmt.Errorf("hijack not possible: %s", err) + } + + stream, err := s.dial(ident, proto.WS, port) + if err != nil { + return err + } + + if err := r.Write(stream); err != nil { + err = errors.New("unable to write upgrade request: " + err.Error()) + return nonil(err, stream.Close()) + } + + resp, err := http.ReadResponse(bufio.NewReader(stream), r) + if err != nil { + err = errors.New("unable to read upgrade response: " + err.Error()) + return nonil(err, stream.Close()) + } + + if err := resp.Write(conn); err != nil { + err = errors.New("unable to write upgrade response: " + err.Error()) + return nonil(err, stream.Close()) + } + + var wg sync.WaitGroup + wg.Add(2) + + go s.proxy(&wg, conn, stream) + go s.proxy(&wg, stream, conn) + + wg.Wait() + + return nonil(stream.Close(), conn.Close()) +} + +func (s *Server) handleTCPConn(conn net.Conn) error { + ident, ok := s.virtualAddrs.getIdent(conn) + if !ok { + return fmt.Errorf("no virtual address available for %s", conn.LocalAddr()) + } + + _, port, err := parseHostPort(conn.LocalAddr().String()) + if err != nil { + return err + } + + stream, err := s.dial(ident, proto.TCP, port) + if err != nil { + return err + } + + var wg sync.WaitGroup + wg.Add(2) + + go s.proxy(&wg, conn, stream) + go s.proxy(&wg, stream, conn) + + wg.Wait() + + return nonil(stream.Close(), conn.Close()) +} + +func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) { + defer wg.Done() + + s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr()) + n, err := io.Copy(dst, src) + s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err) +} + +func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) { + control, ok := s.getControl(identifier) + if !ok { + return nil, errNoClientSession + } + + session, err := s.getSession(identifier) + if err != nil { + return nil, err + } + + msg := proto.ControlMessage{ + Action: proto.RequestClientSession, + Protocol: p, + LocalPort: port, + } + + s.log.Debug("Sending control msg %+v", msg) + + // ask client to open a session to us, so we can accept it + if err := control.send(msg); err != nil { + // we might have several issues here, either the stream is closed, or + // the session is going be shut down, the underlying connection might + // be broken. In all cases, it's not reliable anymore having a client + // session. + control.Close() + s.deleteControl(identifier) + return nil, errNoClientSession + } + + var stream net.Conn + acceptStream := func() error { + stream, err = session.Accept() + return err + } + + // if we don't receive anything from the client, we'll timeout + s.log.Debug("Waiting for session accept") + + select { + case err := <-async(acceptStream): + return stream, err + case <-time.After(defaultTimeout): + return nil, errors.New("timeout getting session") + } +} + +// controlHandler is used to capture incoming tunnel connect requests into raw +// tunnel TCP connections. +func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) { + identifier := r.Header.Get(proto.ClientIdentifierHeader) + _, ok := s.getHost(identifier) + if !ok { + return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier) + } + + ct, ok := s.getControl(identifier) + if ok { + ct.Close() + s.deleteControl(identifier) + s.deleteSession(identifier) + s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier) + return fmt.Errorf("control conn for %s already exist. \n", identifier) + } + + s.log.Debug("Tunnel with identifier %s", identifier) + + hj, ok := w.(http.Hijacker) + if !ok { + return errors.New("webserver doesn't support hijacking") + } + + conn, _, err := hj.Hijack() + if err != nil { + return fmt.Errorf("hijack not possible: %s", err) + } + + if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil { + return fmt.Errorf("error writing response: %s", err) + } + + if err := conn.SetDeadline(time.Time{}); err != nil { + return fmt.Errorf("error setting connection deadline: %s", err) + } + + s.log.Debug("Creating control session") + session, err := yamux.Server(conn, s.yamuxConfig) + if err != nil { + return err + } + s.addSession(identifier, session) + + var stream net.Conn + + // close and delete the session/stream if something goes wrong + defer func() { + if ctErr != nil { + if stream != nil { + stream.Close() + } + s.deleteSession(identifier) + } + }() + + acceptStream := func() error { + stream, err = session.Accept() + return err + } + + // if we don't receive anything from the client, we'll timeout + select { + case err := <-async(acceptStream): + if err != nil { + return err + } + case <-time.After(time.Second * 10): + return errors.New("timeout getting session") + } + + s.log.Debug("Initiating handshake protocol") + buf := make([]byte, len(proto.HandshakeRequest)) + if _, err := stream.Read(buf); err != nil { + return err + } + + if string(buf) != proto.HandshakeRequest { + return fmt.Errorf("handshake aborted. got: %s", string(buf)) + } + + if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil { + return err + } + + // setup control stream and start to listen to messages + ct = newControl(stream) + s.addControl(identifier, ct) + go s.listenControl(ct) + + s.log.Debug("Control connection is setup") + return nil +} + +// listenControl listens to messages coming from the client. +func (s *Server) listenControl(ct *control) { + s.onConnect(ct.identifier) + + for { + var msg map[string]interface{} + err := ct.dec.Decode(&msg) + if err != nil { + host, _ := s.getHost(ct.identifier) + s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier) + + // close client connection so it reconnects again + ct.Close() + + // don't forget to cleanup anything + s.deleteControl(ct.identifier) + s.deleteSession(ct.identifier) + + s.onDisconnect(ct.identifier, err) + + if err != io.EOF { + s.log.Error("decode err: %s", err) + } + return + } + + // right now we don't do anything with the messages, but because the + // underlying connection needs to establihsed, we know when we have + // disconnection(above), so we can cleanup the connection. + s.log.Debug("msg: %s", msg) + } +} + +// OnConnect invokes a callback for client with given identifier, +// when it establishes a control session. +// After a client is connected, the associated function +// is also removed and needs to be added again. +func (s *Server) OnConnect(identifier string, fn func() error) { + s.onConnectCallbacks.add(identifier, fn) +} + +// onConnect sends notifications to listeners (registered in onConnectCallbacks +// or stateChanges chanel readers) when client connects. +func (s *Server) onConnect(identifier string) { + if err := s.onConnectCallbacks.call(identifier); err != nil { + s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err) + } + + s.changeState(identifier, ClientConnected, nil) +} + +// OnDisconnect calls the function when the client connected with the +// associated identifier disconnects from the server. +// After a client is disconnected, the associated function +// is also removed and needs to be added again. +func (s *Server) OnDisconnect(identifier string, fn func() error) { + s.onDisconnectCallbacks.add(identifier, fn) +} + +// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks +// or stateChanges chanel readers) when client disconnects. +func (s *Server) onDisconnect(identifier string, err error) { + if err := s.onDisconnectCallbacks.call(identifier); err != nil { + s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err) + } + + s.changeState(identifier, ClientClosed, err) +} + +func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) { + s.statesMu.Lock() + defer s.statesMu.Unlock() + + prev = s.states[identifier] + s.states[identifier] = state + + if s.stateCh != nil { + change := &ClientStateChange{ + Identifier: identifier, + Previous: prev, + Current: state, + Error: err, + } + + select { + case s.stateCh <- change: + default: + s.log.Warning("Dropping state change due to slow reader: %s", change) + } + } + + return prev +} + +// AddHost adds the given virtual host and maps it to the identifier. +func (s *Server) AddHost(host, identifier string) { + s.virtualHosts.AddHost(host, identifier) +} + +// DeleteHost deletes the given virtual host. Once removed any request to this +// host is denied. +func (s *Server) DeleteHost(host string) { + s.virtualHosts.DeleteHost(host) +} + +// AddAddr starts accepting connections on listener l, routing every connection +// to a tunnel client given by the identifier. +// +// When ip parameter is nil, all connections accepted from the listener are +// routed to the tunnel client specified by the identifier (port-based routing). +// +// When ip parameter is non-nil, only those connections are routed whose local +// address matches the specified ip (ip-based routing). +// +// If l listens on multiple interfaces it's desirable to call AddAddr multiple +// times with the same l value but different ip one. +func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) { + s.virtualAddrs.Add(l, ip, identifier) +} + +// DeleteAddr stops listening for connections on the given listener. +// +// Upon return no more connections will be tunneled, but as the method does not +// close the listener, so any ongoing connection won't get interrupted. +func (s *Server) DeleteAddr(l net.Listener, ip net.IP) { + s.virtualAddrs.Delete(l, ip) +} + +func (s *Server) getIdentifier(host string) (string, bool) { + identifier, ok := s.virtualHosts.GetIdentifier(host) + return identifier, ok +} + +func (s *Server) getHost(identifier string) (string, bool) { + host, ok := s.virtualHosts.GetHost(identifier) + return host, ok +} + +func (s *Server) addControl(identifier string, conn *control) { + s.controls.addControl(identifier, conn) +} + +func (s *Server) getControl(identifier string) (*control, bool) { + return s.controls.getControl(identifier) +} + +func (s *Server) deleteControl(identifier string) { + s.controls.deleteControl(identifier) +} + +func (s *Server) getSession(identifier string) (*yamux.Session, error) { + s.sessionsMu.Lock() + session, ok := s.sessions[identifier] + s.sessionsMu.Unlock() + + if !ok { + return nil, fmt.Errorf("no session available for identifier: '%s'", identifier) + } + + return session, nil +} + +func (s *Server) addSession(identifier string, session *yamux.Session) { + s.sessionsMu.Lock() + s.sessions[identifier] = session + s.sessionsMu.Unlock() +} + +func (s *Server) deleteSession(identifier string) { + s.sessionsMu.Lock() + defer s.sessionsMu.Unlock() + + session, ok := s.sessions[identifier] + + if !ok { + return // nothing to delete + } + + if session != nil { + session.GoAway() // don't accept any new connection + session.Close() + } + + delete(s.sessions, identifier) +} + +func copyHeader(dst, src http.Header) { + for k, v := range src { + vv := make([]string, len(v)) + copy(vv, v) + dst[k] = vv + } +} + +// checkConnect checks whether the incoming request is HTTP CONNECT method. +func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed) + return + } + + if err := fn(w, r); err != nil { + s.log.Error("Handler err: %v", err.Error()) + + if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" { + s.onDisconnect(identifier, err) + } + + http.Error(w, err.Error(), 502) + } + }) +} + +func parseHostPort(addr string) (string, int, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + + n, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return "", 0, err + } + + return host, int(n), nil +} + +func isWebsocketConn(r *http.Request) bool { + return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") && + headerContains(r.Header["Upgrade"], "websocket") +} + +// headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go +func headerContains(header []string, value string) bool { + for _, h := range header { + for _, v := range strings.Split(h, ",") { + if strings.EqualFold(strings.TrimSpace(v), value) { + return true + } + } + } + + return false +} + +func nonil(err ...error) error { + for _, e := range err { + if e != nil { + return e + } + } + + return nil +} + +func newLogger(name string, debug bool) logging.Logger { + log := logging.NewLogger(name) + logHandler := logging.NewWriterHandler(os.Stderr) + logHandler.Colorize = true + log.SetHandler(logHandler) + + if debug { + log.SetLevel(logging.DEBUG) + logHandler.SetLevel(logging.DEBUG) + } + + return log +} diff --git a/lib/tunnel/spec.md b/lib/tunnel/spec.md new file mode 100644 index 0000000..fdd0522 --- /dev/null +++ b/lib/tunnel/spec.md @@ -0,0 +1,100 @@ +# Specification + +# Naming conventions + +* `server` is listening to public connection and is responsible of routing + public HTTP requests to clients. +* `client` is a long running process, connected to a server and running on a local machine. +* `virtualHost` is a virtual domain that maps a domain to a single client. i.e: + `arslan.koding.io` is a virtualhost which is mapped to my `client` running on + my local machine. +* `identifier` is a secret token, which is not meant to be shared with others. + An identifier is responsible of mapping a virtualhost to a client. +* `session` is a single TCP connection which uses the library `yamux`. A + session can be created either via `yamux.Server()` or `yamux.Client` +* `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed + over the `session`. A session can have hundreds of thousands streams +* `control connection` is a single `stream` which is used to communicate and + handle messaging between server and client. It uses a custom protocol which + is JSON encoded. +* `tunnel connection` is a single `stream` which is used to proxy public HTTP + requests from the `server` to the `client` and vice versa. A single `tunnel` + connection is created for every single HTTP requests. +* `public connection` is a connection from a remote machine to the `server` +* `ControlHandler` is a http.Handler which listens to requests coming to + `/_controlPath_/`. It's used to setup the initial `session` connection from + `client` to `server`. And creates the `control connection` from this session. + server and client, and also for all additional new tunnel. It literally + captures the incoming HTTP request and hijacks it and converts it into RAW TCP, + which then is used as the foundation for all yamux `sessions.` + + +# Server +1. Server is created with `NewServer()` which returns `*Server`, a `http.Handler` + compatible type. Plug into any HTTP server you want. The root path `"/"` is + recommended to listen and proxy any tunnels. It also listens to any request + coming to `ControlHandler` +2. Tunneling is based on virtual hosts. A virtual hosts is identified with an + unique identifier. This identifier is the only piece that both client and + server needs to known ahead. Think of it as a secret token. +3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This + step needs to be done from the server itself. This can be could manually or + via custom auth based HTTP handlers, such as "/addhost", which adds + virtualhosts and returns the `identifier` to the requester (in our case `client`) +4. A DNS record and it's subdomains needs to point to a `server`, so it can + handle virtual hosts, i.e: `*.example.com` is routed to a server, which can + handle `foo.example.com`, `bar.example.com`, etc.. + + +# Client + +1. Client is created with `NewClient(serverAddr, localAddr)` which returns a + `*Client`. Here `serverAddr` is the TCP address to the server. `localAddr` + is the server in which all public requests are forwarded to. It's optional if + you want it to be done dynamically +2. Once a client is created, it starts with `client.Start(identifier)`. Here + `identifier` is needed upfront. This method creates the initial TCP + connection to the server. It sends the identifier back to the server. This + TCP connection is used as the foundation for `yamux.Client()`. Once a yamux + session is established, we are able to use this single connection to have + multiple streams, which are multiplexed over this one connection. A `control + connection` is created and client starts to listen it. `client.Start` is + blocking. + +# Control Handshake + +1. Client sends a `handshakeRequest` over the `control connection` stream +2. The server sends back a `handshakeResponse` to the client over the `control connection` stream +3. Once the client receives the `handshakeResponse` from the server, it starts + to listen from the `control connection` stream. +4. A `control connection` is json.Encoder/Decoder both for server and client + + +# Tunnel creation +1. When the server receives a public connection, it checks the HTTP host + headers and retrieves the corresponding identifier from the given host. +2. The server retrieves the `control connection` which was associated with this + `identifier` and sends a `ControlMsg` message with the action + `RequestClientSession`. This message is in the form of: + + type ControlMsg struct { + Action Action `json:"action"` + Protocol TransportProtocol `json:"transportProtocol"` + LocalPort string `json:"localPort"` + } + + Here the `LocalPort` is read from the HTTP Host header. If absent a zero + port is sent and client maps it to the local server running at port 80, unless + the `localAddr` is specified in `client.Start()` method. `Protocol` is + reserved for future features. +3. The server immediately starts to listen(accept) to a new `stream`. This is + blocking and it waits there. +4. When the client receives the `RequestClientSession` message, it opens a new + `virtual` TCP connection, a `stream` to the server. +5. The server which was waiting for a new stream in step 3, establish the stream. +6. The server copies the request over the stream to the client. +7. The client copies the request coming from the server to the local server and + copies back the result to the server +8. The server reads the response coming from the client and returns back it to + the public connection requester + diff --git a/lib/tunnel/tcpproxy.go b/lib/tunnel/tcpproxy.go new file mode 100644 index 0000000..5343c24 --- /dev/null +++ b/lib/tunnel/tcpproxy.go @@ -0,0 +1,78 @@ +package tunnel + +import ( + "fmt" + "net" + + "github.com/koding/logging" + "git.xeserv.us/xena/route/lib/tunnel/proto" +) + +var ( + tpcLog = logging.NewLogger("tcp") +) + +// TCPProxy forwards TCP streams. +// +// If port-based routing is used, LocalAddr or FetchLocalAddr field is required +// for tunneling to function properly. +// Otherwise you'll be forwarding traffic to random ports and this is usually not desired. +// +// If IP-based routing is used then tunnel server connection request is +// proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort. +// Usually this is tunnel server's public exposed Port. +// This behaviour can be changed by setting LocalAddr or FetchLocalAddr. +// FetchLocalAddr takes precedence over LocalAddr. +type TCPProxy struct { + // LocalAddr defines the TCP address of the local server. + // This is optional if you want to specify a single TCP address. + LocalAddr string + // FetchLocalAddr is used for looking up TCP address of the server. + // This is optional if you want to specify a dynamic TCP address based on incommig port. + FetchLocalAddr func(port int) (string, error) + // Log is a custom logger that can be used for the proxy. + // If not set a "tcp" logger is used. + Log logging.Logger +} + +// Proxy is a ProxyFunc. +func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) { + if msg.Protocol != proto.TCP { + panic("Proxy mismatch") + } + + var log = p.log() + + var port = msg.LocalPort + if port == 0 { + log.Warning("TCP proxy to port 0") + } + + var localAddr = fmt.Sprintf("127.0.0.1:%d", port) + if p.LocalAddr != "" { + localAddr = p.LocalAddr + } else if p.FetchLocalAddr != nil { + l, err := p.FetchLocalAddr(msg.LocalPort) + if err != nil { + log.Warning("Failed to get custom local address: %s", err) + return + } + localAddr = l + } + + log.Debug("Dialing local server: %q", localAddr) + local, err := net.DialTimeout("tcp", localAddr, defaultTimeout) + if err != nil { + log.Error("Dialing local server %q failed: %s", localAddr, err) + return + } + + Join(local, remote, log) +} + +func (p *TCPProxy) log() logging.Logger { + if p.Log != nil { + return p.Log + } + return tpcLog +} diff --git a/lib/tunnel/tunnel_test.go b/lib/tunnel/tunnel_test.go new file mode 100644 index 0000000..c053065 --- /dev/null +++ b/lib/tunnel/tunnel_test.go @@ -0,0 +1,412 @@ +package tunnel_test + +import ( + "fmt" + "strconv" + "sync" + "testing" + "time" + + "git.xeserv.us/xena/route/lib/tunnel" + "git.xeserv.us/xena/route/lib/tunnel/tunneltest" + + "github.com/cenkalti/backoff" +) + +func TestMultipleRequest(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + // make a request to tunnelserver, this should be tunneled to local server + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + msg := "hello" + strconv.Itoa(i) + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + }(i) + } + + wg.Wait() +} + +func TestMultipleLatencyRequest(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + // make a request to tunnelserver, this should be tunneled to local server + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + msg := "hello" + strconv.Itoa(i) + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + }(i) + } + + wg.Wait() +} + +func TestReconnectClient(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + msg := "hello" + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + + client := tt.Clients["http"] + + // close client, and start it again + client.Close() + + go client.Start() + <-client.StartNotify() + + msg = "helloagain" + res, err = echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestNoClient(t *testing.T) { + const expectedErr = "no client session established" + + rec := tunneltest.NewStateRecorder() + + tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C())) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + if err := rec.WaitTransitions( + tunnel.ClientStarted, + tunnel.ClientConnecting, + tunnel.ClientConnected, + ); err != nil { + t.Fatal(err) + } + + if err := tt.ServerStateRecorder.WaitTransition( + tunnel.ClientUnknown, + tunnel.ClientConnected, + ); err != nil { + t.Fatal(err) + } + + // close client, this is the main point of the test + if err := tt.Clients["http"].Close(); err != nil { + t.Fatal(err) + } + + if err := rec.WaitTransitions( + tunnel.ClientConnected, + tunnel.ClientDisconnected, + tunnel.ClientClosed, + ); err != nil { + t.Fatal(err) + } + + if err := tt.ServerStateRecorder.WaitTransition( + tunnel.ClientConnected, + tunnel.ClientClosed, + ); err != nil { + t.Fatal(err) + } + + msg := "hello" + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != expectedErr { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestNoHost(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + noBackoff := backoff.NewConstantBackOff(time.Duration(-1)) + + unknown, err := tunnel.NewClient(&tunnel.ClientConfig{ + Identifier: "unknown", + ServerAddr: tt.ServerAddr().String(), + Backoff: noBackoff, + Debug: testing.Verbose(), + }) + if err != nil { + t.Fatalf("client error: %s", err) + } + unknown.Start() + defer unknown.Close() + + if err := tt.ServerStateRecorder.WaitTransition( + tunnel.ClientUnknown, + tunnel.ClientClosed, + ); err != nil { + t.Fatal(err) + } + + unknown.Start() + if err := tt.ServerStateRecorder.WaitTransition( + tunnel.ClientClosed, + tunnel.ClientClosed, + ); err != nil { + t.Fatal(err) + } +} + +func TestNoLocalServer(t *testing.T) { + const expectedErr = "no local server" + + tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + // close local listener, this is the main point of the test + tt.Listeners["http"][0].Close() + + msg := "hello" + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != expectedErr { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestSingleRequest(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + msg := "hello" + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestSingleLatencyRequest(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + msg := "hello" + res, err := echoHTTP(tt, msg) + if err != nil { + t.Fatalf("echoHTTP error: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestSingleTCP(t *testing.T) { + tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + msg := "hello" + res, err := echoTCP(tt, msg) + if err != nil { + t.Fatalf("echoTCP error: %s", err) + } + + if msg != res { + t.Errorf("got %q, want %q", res, msg) + } +} + +func TestMultipleTCP(t *testing.T) { + tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + msg := "hello" + strconv.Itoa(i) + res, err := echoTCP(tt, msg) + if err != nil { + t.Errorf("echoTCP: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + }(i) + } + + wg.Wait() +} + +func TestMultipleLatencyTCP(t *testing.T) { + tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP)) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + msg := "hello" + strconv.Itoa(i) + res, err := echoTCP(tt, msg) + if err != nil { + t.Errorf("echoTCP: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + }(i) + } + + wg.Wait() +} + +func TestMultipleStreamTCP(t *testing.T) { + tunnels := map[string]*tunneltest.Tunnel{ + "http": { + Type: tunneltest.TypeHTTP, + LocalAddr: "127.0.0.1:0", + Handler: handlerEchoHTTP, + }, + "tcp": { + Type: tunneltest.TypeTCP, + ClientIdent: "http", + LocalAddr: "127.0.0.1:0", + RemoteAddr: "127.0.0.1:0", + Handler: handlerEchoTCP, + }, + "tcp_all": { + Type: tunneltest.TypeTCP, + ClientIdent: "http", + LocalAddr: "127.0.0.1:0", + RemoteAddr: "0.0.0.0:0", + Handler: handlerEchoTCP, + }, + } + + addrs, err := tunneltest.UsableAddrs() + if err != nil { + t.Fatal(err) + } + + clients := []string{"tcp"} + for i, addr := range addrs { + if addr.IP.IsLoopback() { + continue + } + + client := fmt.Sprintf("tcp_%d", i) + + tunnels[client] = &tunneltest.Tunnel{ + Type: tunneltest.TypeTCP, + ClientIdent: "http", + LocalAddr: "127.0.0.1:0", + RemoteAddrIdent: "tcp_all", + IP: addr.IP, + Handler: handlerEchoTCP, + } + + clients = append(clients, client) + } + + tt, err := tunneltest.Serve(tunnels) + if err != nil { + t.Fatal(err) + } + defer tt.Close() + + var wg sync.WaitGroup + for i := 0; i < 100/len(clients); i++ { + wg.Add(len(clients)) + + for j, ident := range clients { + go func(ident string, i, j int) { + defer wg.Done() + msg := fmt.Sprintf("hello_%d_client_%d", j, i) + res, err := echoTCPIdent(tt, msg, ident) + if err != nil { + t.Errorf("echoTCP: %s", err) + } + + if res != msg { + t.Errorf("got %q, want %q", res, msg) + } + }(ident, i, j) + } + } + + wg.Wait() +} diff --git a/lib/tunnel/tunneltest/state_recorder.go b/lib/tunnel/tunneltest/state_recorder.go new file mode 100644 index 0000000..d248823 --- /dev/null +++ b/lib/tunnel/tunneltest/state_recorder.go @@ -0,0 +1,118 @@ +package tunneltest + +import ( + "bytes" + "fmt" + "sync" + "time" + + "git.xeserv.us/xena/route/lib/tunnel" +) + +var ( + recWaitTimeout = 5 * time.Second + recBuffer = 32 +) + +// States is a sequence of client state changes. +type States []*tunnel.ClientStateChange + +func (s States) String() string { + if len(s) == 0 { + return "" + } + + var buf bytes.Buffer + + fmt.Fprintf(&buf, "[%s", s[0].String()) + + for _, s := range s[1:] { + fmt.Fprintf(&buf, ",%s", s.String()) + } + + buf.WriteRune(']') + + return buf.String() +} + +// StateRecorder saves state changes pushed to StateRecorder.C(). +type StateRecorder struct { + mu sync.Mutex + ch chan *tunnel.ClientStateChange + recorded []*tunnel.ClientStateChange + offset int +} + +func NewStateRecorder() *StateRecorder { + rec := &StateRecorder{ + ch: make(chan *tunnel.ClientStateChange, recBuffer), + } + + go rec.record() + + return rec +} + +func (rec *StateRecorder) record() { + for state := range rec.ch { + rec.mu.Lock() + rec.recorded = append(rec.recorded, state) + rec.mu.Unlock() + } +} + +func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange { + return rec.ch +} + +func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error { + from := states[0] + for _, to := range states[1:] { + if err := rec.WaitTransition(from, to); err != nil { + return err + } + + from = to + } + + return nil +} + +func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error { + timeout := time.After(recWaitTimeout) + + var lastStates []*tunnel.ClientStateChange + for { + select { + case <-timeout: + return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates)) + default: + time.Sleep(50 * time.Millisecond) + + lastStates = rec.States()[rec.offset:] + + for i, state := range lastStates { + if from != 0 && state.Previous != from { + continue + } + + if to != 0 && state.Current != to { + continue + } + + rec.offset += i + + return nil + } + } + } +} + +func (rec *StateRecorder) States() []*tunnel.ClientStateChange { + rec.mu.Lock() + defer rec.mu.Unlock() + + states := make([]*tunnel.ClientStateChange, len(rec.recorded)) + copy(states, rec.recorded) + return states +} diff --git a/lib/tunnel/tunneltest/tunneltest.go b/lib/tunnel/tunneltest/tunneltest.go new file mode 100644 index 0000000..b31be78 --- /dev/null +++ b/lib/tunnel/tunneltest/tunneltest.go @@ -0,0 +1,561 @@ +package tunneltest + +import ( + "errors" + "fmt" + "log" + "net" + "net/http" + "net/url" + "os" + "sort" + "strconv" + "sync" + "testing" + "time" + + "git.xeserv.us/xena/route/lib/tunnel" +) + +var debugNet = os.Getenv("DEBUGNET") == "1" + +type dbgListener struct { + net.Listener +} + +func (l dbgListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return dbgConn{conn}, nil +} + +type dbgConn struct { + net.Conn +} + +func (c dbgConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + os.Stderr.Write(p) + return n, err +} + +func (c dbgConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + os.Stderr.Write(p) + return n, err +} + +func logf(format string, args ...interface{}) { + if testing.Verbose() { + log.Printf("[tunneltest] "+format, args...) + } +} + +func nonil(err ...error) error { + for _, e := range err { + if e != nil { + return e + } + } + return nil +} + +func parseHostPort(addr string) (string, int, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + + n, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return "", 0, err + } + + return host, int(n), nil +} + +// UsableAddrs returns all tcp addresses that we can bind a listener to. +func UsableAddrs() ([]*net.TCPAddr, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + + var usable []*net.TCPAddr + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok { + if !ipNet.IP.IsLinkLocalUnicast() { + usable = append(usable, &net.TCPAddr{IP: ipNet.IP}) + } + } + } + + if len(usable) == 0 { + return nil, errors.New("no usable addresses found") + } + + return usable, nil +} + +const ( + TypeHTTP = iota + TypeTCP +) + +// Tunnel represents a single HTTP or TCP tunnel that can be served +// by TunnelTest. +type Tunnel struct { + // Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP. + Type int + + // Handler is a handler to use for serving tunneled connections on + // local server. The value of this field is required to be of type: + // + // - http.Handler or http.HandlerFunc for HTTP tunnels + // - func(net.Conn) for TCP tunnels + // + // Required field. + Handler interface{} + + // LocalAddr is a network address of local server that handles + // connections/requests with Handler. + // + // Optional field, takes value of "127.0.0.1:0" when empty. + LocalAddr string + + // ClientIdent is an identifier of a client that have already + // registered a HTTP tunnel and have established control connection. + // + // If the Type is TypeTCP, instead of creating new client + // for this TCP tunnel, we add it to an existing client + // specified by the field. + // + // Optional field for TCP tunnels. + // Ignored field for HTTP tunnels. + ClientIdent string + + // RemoteAddr is a network address of remote server, which accepts + // connections on a tunnel server side. + // + // Required field for TCP tunnels. + // Ignored field for HTTP tunnels. + RemoteAddr string + + // RemoteAddrIdent an identifier of an already existing listener, + // that listens on multiple interfaces; if the RemoteAddrIdent is valid + // identifier the IP field is required to be non-nil and RemoteAddr + // is ignored. + // + // Optional field for TCP tunnels. + // Ignored field for HTTP tunnels. + RemoteAddrIdent string + + // IP specifies an IP address value for IP-based routing for TCP tunnels. + // For more details see inline documentation for (*tunnel.Server).AddAddr. + // + // Optional field for TCP tunnels. + // Ignored field for HTTP tunnels. + IP net.IP + + // StateChanges listens on state transitions. + // + // If ClientIdent field is empty, the StateChanges will receive + // state transition events for the newly created client. + // Otherwise setting this field is a nop. + StateChanges chan<- *tunnel.ClientStateChange +} + +type TunnelTest struct { + Server *tunnel.Server + ServerStateRecorder *StateRecorder + Clients map[string]*tunnel.Client + Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels) + Addrs []*net.TCPAddr + Tunnels map[string]*Tunnel + DebugNet bool // for debugging network communication + + mu sync.Mutex // protects Listeners +} + +func NewTunnelTest() (*TunnelTest, error) { + rec := NewStateRecorder() + + cfg := &tunnel.ServerConfig{ + StateChanges: rec.C(), + Debug: testing.Verbose(), + } + s, err := tunnel.NewServer(cfg) + if err != nil { + return nil, err + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + + if debugNet { + l = dbgListener{l} + } + + addrs, err := UsableAddrs() + if err != nil { + return nil, err + } + + go (&http.Server{Handler: s}).Serve(l) + + return &TunnelTest{ + Server: s, + ServerStateRecorder: rec, + Clients: make(map[string]*tunnel.Client), + Listeners: map[string][2]net.Listener{"": {l, nil}}, + Addrs: addrs, + Tunnels: make(map[string]*Tunnel), + DebugNet: debugNet, + }, nil +} + +// Serve creates new TunnelTest that serves the given tunnels. +// +// If tunnels is nil, DefaultTunnels() are used instead. +func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) { + tt, err := NewTunnelTest() + if err != nil { + return nil, err + } + + if err = tt.Serve(tunnels); err != nil { + return nil, err + } + + return tt, nil +} + +func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) { + // Verify tunnel dependencies for TCP tunnels. + if t.Type == TypeTCP { + // If tunnel specified by t.Client was not already started, + // skip and move on. + if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok { + return false, nil + } + + // Verify the TCP tunnel whose remote endpoint listens on multiple + // interfaces is already served. + if t.RemoteAddrIdent != "" { + if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok { + return false, nil + } + + if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP { + return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent) + } + } + } + + l, err := net.Listen("tcp", t.LocalAddr) + if err != nil { + return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err) + } + + if tt.DebugNet { + l = dbgListener{l} + } + + localAddr := l.Addr().String() + httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr} + tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr} + + cfg := &tunnel.ClientConfig{ + Identifier: ident, + ServerAddr: tt.ServerAddr().String(), + Proxy: tunnel.Proxy(tunnel.ProxyFuncs{ + HTTP: httpProxy.Proxy, + TCP: tcpProxy.Proxy, + }), + StateChanges: t.StateChanges, + Debug: testing.Verbose(), + } + + // Register tunnel: + // + // - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient]) + // - listen on local address and start local server (tt.Listeners[ident][0]) + // - register tunnel on tunnel.Server + // + switch t.Type { + case TypeHTTP: + // TODO(rjeczalik): refactor to separate method + + h, ok := t.Handler.(http.Handler) + if !ok { + h, ok = t.Handler.(http.HandlerFunc) + if !ok { + fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request)) + if !ok { + return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler) + } + + h = http.HandlerFunc(fn) + } + + } + + logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident) + + go (&http.Server{Handler: h}).Serve(l) + + tt.Server.AddHost(localAddr, ident) + + tt.mu.Lock() + tt.Listeners[ident] = [2]net.Listener{l, nil} + tt.mu.Unlock() + + if err := tt.addClient(ident, cfg); err != nil { + return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err) + } + + logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident) + + case TypeTCP: + // TODO(rjeczalik): refactor to separate method + + h, ok := t.Handler.(func(net.Conn)) + if !ok { + return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler) + } + + logf("serving on local %s for TCP tunnel %q", l.Addr(), ident) + + go func() { + for { + conn, err := l.Accept() + if err != nil { + log.Printf("failed accepting conn for %q tunnel: %s", ident, err) + return + } + + go h(conn) + } + }() + + var remote net.Listener + + if t.RemoteAddrIdent != "" { + tt.mu.Lock() + remote = tt.Listeners[t.RemoteAddrIdent][1] + tt.mu.Unlock() + } else { + remote, err = net.Listen("tcp", t.RemoteAddr) + if err != nil { + return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err) + } + } + + // addrIdent holds identifier of client which is going to have registered + // tunnel via (*tunnel.Server).AddAddr + addrIdent := ident + if t.ClientIdent != "" { + tt.Clients[ident] = tt.Clients[t.ClientIdent] + addrIdent = t.ClientIdent + } + + tt.Server.AddAddr(remote, t.IP, addrIdent) + + tt.mu.Lock() + tt.Listeners[ident] = [2]net.Listener{l, remote} + tt.mu.Unlock() + + if _, ok := tt.Clients[ident]; !ok { + if err := tt.addClient(ident, cfg); err != nil { + return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err) + } + } + + logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent) + + default: + return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type) + } + + return true, nil +} + +func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error { + if _, ok := tt.Clients[ident]; ok { + return fmt.Errorf("tunnel %q is already being served", ident) + } + + c, err := tunnel.NewClient(cfg) + if err != nil { + return err + } + + done := make(chan struct{}) + + tt.Server.OnConnect(ident, func() error { + close(done) + return nil + }) + + go c.Start() + <-c.StartNotify() + + select { + case <-time.After(10 * time.Second): + return errors.New("timed out after 10s waiting on client to establish control conn") + case <-done: + } + + tt.Clients[ident] = c + return nil +} + +func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error { + if len(tunnels) == 0 { + return errors.New("no tunnels to serve") + } + + // Since one tunnels depends on others do 3 passes to start them + // all, each started tunnel is removed from the tunnels map. + // After 3 passes all of them must be started, otherwise the + // configuration is bad: + // + // - first pass starts HTTP tunnels as new client tunnels + // - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent) + // - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent) + // + for i := 0; i < 3; i++ { + if err := tt.popServedDeps(tunnels); err != nil { + return err + } + } + + if len(tunnels) != 0 { + unresolved := make([]string, len(tunnels)) + for ident := range tunnels { + unresolved = append(unresolved, ident) + } + sort.Strings(unresolved) + + return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved) + } + + return nil +} + +func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error { + for ident, t := range tunnels { + ok, err := tt.serveSingle(ident, t) + if err != nil { + return err + } + + if ok { + // Remove already started tunnels so they won't get started again. + delete(tunnels, ident) + tt.Tunnels[ident] = t + } + } + + return nil +} + +func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) { + tt.mu.Lock() + defer tt.mu.Unlock() + + for _, l := range tt.Listeners { + if l[1] == nil { + // this listener does not belong to a TCP tunnel + continue + } + + _, remotePort, err := parseHostPort(l[1].Addr().String()) + if err != nil { + return "", err + } + + if port == remotePort { + return l[0].Addr().String(), nil + } + } + + return "", fmt.Errorf("no route for %d port", port) +} + +func (tt *TunnelTest) ServerAddr() net.Addr { + return tt.Listeners[""][0].Addr() +} + +// Addr gives server endpoint of the TCP tunnel for the given ident. +// +// If the tunnel does not exist or is a HTTP one, TunnelAddr return nil. +func (tt *TunnelTest) Addr(ident string) net.Addr { + l, ok := tt.Listeners[ident] + if !ok { + return nil + } + + return l[1].Addr() +} + +// Request creates a HTTP request to a server endpoint of the HTTP tunnel +// for the given ident. +// +// If the tunnel does not exist, Request returns nil. +func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request { + l, ok := tt.Listeners[ident] + if !ok { + return nil + } + + var raw string + if query != nil { + raw = query.Encode() + } + + return &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: tt.ServerAddr().String(), + Path: "/", + RawQuery: raw, + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Host: l[0].Addr().String(), + } +} + +func (tt *TunnelTest) Close() (err error) { + // Close tunnel.Clients. + clients := make(map[*tunnel.Client]struct{}) + for _, c := range tt.Clients { + clients[c] = struct{}{} + } + for c := range clients { + err = nonil(err, c.Close()) + } + + // Stop all TCP/HTTP servers. + listeners := make(map[net.Listener]struct{}) + for _, l := range tt.Listeners { + for _, l := range l { + if l != nil { + listeners[l] = struct{}{} + } + } + } + for l := range listeners { + err = nonil(err, l.Close()) + } + + return err +} diff --git a/lib/tunnel/util.go b/lib/tunnel/util.go new file mode 100644 index 0000000..154bea2 --- /dev/null +++ b/lib/tunnel/util.go @@ -0,0 +1,121 @@ +package tunnel + +import ( + "crypto/tls" + "fmt" + "net" + "sync" + "time" + + "git.xeserv.us/xena/route/lib/tunnel/proto" + + "github.com/cenkalti/backoff" +) + +// async is a helper function to convert a blocking function to a function +// returning an error. Useful for plugging function closures into select and co +func async(fn func() error) <-chan error { + errChan := make(chan error, 0) + go func() { + select { + case errChan <- fn(): + default: + } + + close(errChan) + }() + + return errChan +} + +type expBackoff struct { + mu sync.Mutex + bk *backoff.ExponentialBackOff +} + +func newForeverBackoff() *expBackoff { + eb := &expBackoff{ + bk: backoff.NewExponentialBackOff(), + } + eb.bk.MaxElapsedTime = 0 // never stops + return eb +} + +func (eb *expBackoff) NextBackOff() time.Duration { + eb.mu.Lock() + defer eb.mu.Unlock() + + return eb.bk.NextBackOff() +} + +func (eb *expBackoff) Reset() { + eb.mu.Lock() + eb.bk.Reset() + eb.mu.Unlock() +} + +type callbacks struct { + mu sync.Mutex + name string + funcs map[string]func() error +} + +func newCallbacks(name string) *callbacks { + return &callbacks{ + name: name, + funcs: make(map[string]func() error), + } +} + +func (c *callbacks) add(identifier string, fn func() error) { + c.mu.Lock() + c.funcs[identifier] = fn + c.mu.Unlock() +} + +func (c *callbacks) pop(identifier string) (func() error, error) { + c.mu.Lock() + defer c.mu.Unlock() + + fn, ok := c.funcs[identifier] + if !ok { + return nil, nil // nop + } + + delete(c.funcs, identifier) + + if fn == nil { + return nil, fmt.Errorf("nil callback set for %q client", identifier) + } + + return fn, nil +} + +func (c *callbacks) call(identifier string) error { + fn, err := c.pop(identifier) + if err != nil { + return err + } + + if fn == nil { + return nil // nop + } + + return fn() +} + +// Returns server control url as a string. Reads scheme and remote address from connection. +func controlURL(conn net.Conn) string { + return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath) +} + +func scheme(conn net.Conn) (scheme string) { + switch conn.(type) { + case *tls.Conn: + scheme = "https" + default: + scheme = "http" + } + + return +} diff --git a/lib/tunnel/virtualaddr.go b/lib/tunnel/virtualaddr.go new file mode 100644 index 0000000..6c77b66 --- /dev/null +++ b/lib/tunnel/virtualaddr.go @@ -0,0 +1,179 @@ +package tunnel + +import ( + "net" + "strconv" + "sync" + "sync/atomic" + + "github.com/koding/logging" +) + +type listener struct { + net.Listener + *vaddrOptions + + done int32 + + // ips keeps track of registered clients for ip-based routing; + // when last client is deleted from the ip routing map, we stop + // listening on connections + ips map[string]struct{} +} + +type vaddrOptions struct { + connCh chan<- net.Conn + log logging.Logger +} + +type vaddrStorage struct { + *vaddrOptions + + listeners map[net.Listener]*listener + ports map[int]string // port-based routing: maps port number to identifier + ips map[string]string // ip-based routing: maps ip address to identifier + + mu sync.RWMutex +} + +func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage { + return &vaddrStorage{ + vaddrOptions: opts, + listeners: make(map[net.Listener]*listener), + ports: make(map[int]string), + ips: make(map[string]string), + } +} + +func (l *listener) serve() { + for { + conn, err := l.Accept() + if err != nil { + l.log.Error("failue listening on %q: %s", l.Addr(), err) + return + } + + if atomic.LoadInt32(&l.done) != 0 { + l.log.Debug("stopped serving %q", l.Addr()) + conn.Close() + return + } + + l.connCh <- conn + } +} + +func (l *listener) localAddr() string { + if addr, ok := l.Addr().(*net.TCPAddr); ok { + if addr.IP.Equal(net.IPv4zero) { + return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port)) + } + } + return l.Addr().String() +} + +func (l *listener) stop() { + if atomic.CompareAndSwapInt32(&l.done, 0, 1) { + // stop is called when no more connections should be accepted by + // the user-provided listener; as we can't simple close the listener + // to not break the guarantee given by the (*Server).DeleteAddr + // method, we make a dummy connection to break out of serve loop. + // It is safe to make a dummy connection, as either the following + // dial will time out when the listener is busy accepting connections, + // or will get closed immadiately after idle listeners accepts connection + // and returns from the serve loop. + conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout) + if err == nil { + conn.Close() + } + } +} + +func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) { + vaddr.mu.Lock() + defer vaddr.mu.Unlock() + + lis, ok := vaddr.listeners[l] + if !ok { + lis = vaddr.newListener(l) + vaddr.listeners[l] = lis + go lis.serve() + } + + if ip != nil { + lis.ips[ip.String()] = struct{}{} + vaddr.ips[ip.String()] = ident + } else { + vaddr.ports[mustPort(l)] = ident + } +} + +func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) { + vaddr.mu.Lock() + defer vaddr.mu.Unlock() + + lis, ok := vaddr.listeners[l] + if !ok { + return + } + + var stop bool + + if ip != nil { + delete(lis.ips, ip.String()) + delete(vaddr.ips, ip.String()) + + stop = len(lis.ips) == 0 + } else { + delete(vaddr.ports, mustPort(l)) + + stop = true + } + + // Only stop listening for connections when listener has clients + // registered to tunnel the connections to. + if stop { + lis.stop() + delete(vaddr.listeners, l) + } +} + +func (vaddr *vaddrStorage) newListener(l net.Listener) *listener { + return &listener{ + Listener: l, + vaddrOptions: vaddr.vaddrOptions, + ips: make(map[string]struct{}), + } +} + +func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) { + vaddr.mu.Lock() + defer vaddr.mu.Unlock() + + ip, port, err := parseHostPort(conn.LocalAddr().String()) + if err != nil { + vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err) + return "", false + } + + // First lookup if there's a ip-based route, then try port-base one. + + if ident, ok := vaddr.ips[ip]; ok { + return ident, true + } + + ident, ok := vaddr.ports[port] + return ident, ok +} + +func mustPort(l net.Listener) int { + _, port, err := parseHostPort(l.Addr().String()) + if err != nil { + // This can happened when user passed custom type that + // implements net.Listener, which returns ill-formed + // net.Addr value. + panic("ill-formed net.Addr: " + err.Error()) + } + + return port +} diff --git a/lib/tunnel/virtualhost.go b/lib/tunnel/virtualhost.go new file mode 100644 index 0000000..e0af5ce --- /dev/null +++ b/lib/tunnel/virtualhost.go @@ -0,0 +1,77 @@ +package tunnel + +import ( + "sync" +) + +type vhostStorage interface { + // AddHost adds the given host and identifier to the storage + AddHost(host, identifier string) + + // DeleteHost deletes the given host + DeleteHost(host string) + + // GetHost returns the host name for the given identifier + GetHost(identifier string) (string, bool) + + // GetIdentifier returns the identifier for the given host + GetIdentifier(host string) (string, bool) +} + +type virtualHost struct { + identifier string +} + +// virtualHosts is used for mapping host to users example: host +// "fs-1-fatih.kd.io" belongs to user "arslan" +type virtualHosts struct { + mapping map[string]*virtualHost + sync.Mutex +} + +// newVirtualHosts provides an in memory virtual host storage for mapping +// virtual hosts to identifiers. +func newVirtualHosts() *virtualHosts { + return &virtualHosts{ + mapping: make(map[string]*virtualHost), + } +} + +func (v *virtualHosts) AddHost(host, identifier string) { + v.Lock() + v.mapping[host] = &virtualHost{identifier: identifier} + v.Unlock() +} + +func (v *virtualHosts) DeleteHost(host string) { + v.Lock() + delete(v.mapping, host) + v.Unlock() +} + +// GetIdentifier returns the identifier associated with the given host +func (v *virtualHosts) GetIdentifier(host string) (string, bool) { + v.Lock() + ht, ok := v.mapping[host] + v.Unlock() + + if !ok { + return "", false + } + + return ht.identifier, true +} + +// GetHost returns the host associated with the given identifier +func (v *virtualHosts) GetHost(identifier string) (string, bool) { + v.Lock() + defer v.Unlock() + + for hostname, hst := range v.mapping { + if hst.identifier == identifier { + return hostname, true + } + } + + return "", false +} diff --git a/lib/tunnel/websocket_test.go b/lib/tunnel/websocket_test.go new file mode 100644 index 0000000..c730633 --- /dev/null +++ b/lib/tunnel/websocket_test.go @@ -0,0 +1,69 @@ +package tunnel_test + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "git.xeserv.us/xena/route/lib/tunnel/tunneltest" +) + +func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) { + conn, err := websocketDial(tt, "http") + if err != nil { + t.Fatalf("Dial()=%s", err) + } + defer conn.Close() + + for i := 0; i < n; i++ { + want := &EchoMessage{ + Value: fmt.Sprintf("message #%d", i), + Close: i == (n - 1), + } + + err := conn.WriteJSON(want) + if err != nil { + t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err) + continue + } + + got := &EchoMessage{} + + err = conn.ReadJSON(got) + if err != nil { + t.Errorf("(test %s) %d: failed reading: %s", name, i, err) + continue + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want) + } + } +} + +func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := fn(w, r); err != nil { + t.Errorf("handler func error: %s", err) + } + } +} + +func TestWebsocket(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil)))) + if err != nil { + t.Fatal(err) + } + + testWebsocket("handlerEchoWS", 100, t, tt) +} + +func TestLatencyWebsocket(t *testing.T) { + tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep)))) + if err != nil { + t.Fatal(err) + } + + testWebsocket("handlerLatencyEchoWS", 20, t, tt) +}