import github.com/koding/tunnel
This commit is contained in:
parent
551017e893
commit
86be40fea0
|
@ -0,0 +1,19 @@
|
|||
language: go
|
||||
|
||||
sudo: false
|
||||
|
||||
addons:
|
||||
apt:
|
||||
packages:
|
||||
- moreutils
|
||||
|
||||
go:
|
||||
- 1.4.3
|
||||
- 1.6.3
|
||||
- 1.7
|
||||
|
||||
script:
|
||||
- export GOMAXPROCS=$(nproc)
|
||||
- gofmt -s -l . | ifne false
|
||||
- go build ./...
|
||||
- go test -race ./...
|
|
@ -0,0 +1,28 @@
|
|||
Copyright (c) 2015 The Koding Authors.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of Koding Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
# Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel)
|
||||
|
||||
Tunnel is a server/client package that enables to proxy public connections to
|
||||
your local machine over a tunnel connection from the local machine to the
|
||||
public server. What this means is, you can share your localhost even if it
|
||||
doesn't have a Public IP or if it's not reachable from outside.
|
||||
|
||||
It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to
|
||||
multiplex connections between server and client.
|
||||
|
||||
The project is under active development, please vendor it if you want to use it.
|
||||
|
||||
# Usage
|
||||
|
||||
The tunnel package consists of two parts. The `server` and the `client`.
|
||||
|
||||
Server is the public facing part. It's type that satisfies the `http.Handler`.
|
||||
So it's easily pluggable into existing servers.
|
||||
|
||||
|
||||
Let assume that you setup your DNS service so all `*.example.com` domains route
|
||||
to your server at the public IP `203.0.113.0`. Let us first create the server
|
||||
part:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/koding/tunnel"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := &tunnel.ServerConfig{}
|
||||
server, _ := tunnel.NewServer(cfg)
|
||||
server.AddHost("sub.example.com", "1234")
|
||||
http.ListenAndServe(":80", server)
|
||||
}
|
||||
```
|
||||
|
||||
Once you create the `server`, you just plug it into your server. The only
|
||||
detail here is to map a virtualhost to a secret token. The secret token is the
|
||||
only part that needs to be known for the client side.
|
||||
|
||||
Let us now create the client side part:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/koding/tunnel"
|
||||
|
||||
func main() {
|
||||
cfg := &tunnel.ClientConfig{
|
||||
Identifier: "1234",
|
||||
ServerAddr: "203.0.113.0:80",
|
||||
}
|
||||
|
||||
client, err := tunnel.NewClient(cfg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
client.Start()
|
||||
}
|
||||
```
|
||||
|
||||
The `Start()` method is by default blocking. As you see you, we just passed the
|
||||
server address and the secret token.
|
||||
|
||||
Now whenever someone hit `sub.example.com`, the request will be proxied to the
|
||||
machine where client is running and hit the local server running `127.0.0.1:80`
|
||||
(assuming there is one). If someone hits `sub.example.com:3000` (assume your
|
||||
server is running at this port), it'll be routed to `127.0.0.1:3000`
|
||||
|
||||
That's it.
|
||||
|
||||
There are many options that can be changed, such as a static local address for
|
||||
your client. Have alook at the
|
||||
[documentation](http://godoc.org/github.com/koding/tunnel)
|
||||
|
||||
|
||||
# Protocol
|
||||
|
||||
The server/client protocol is written in the [spec.md](spec.md) file. Please
|
||||
have a look for more detail.
|
||||
|
||||
|
||||
## License
|
||||
|
||||
The BSD 3-Clause License - see LICENSE for more details
|
|
@ -0,0 +1,565 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
//go:generate stringer -type ClientState
|
||||
|
||||
// ErrRedialAborted is emitted on ClientClosed event, when backoff policy
|
||||
// used by a client decided no more reconnection attempts must be made.
|
||||
var ErrRedialAborted = errors.New("unable to restore the connection, aborting")
|
||||
|
||||
// ClientState represents client connection state to tunnel server.
|
||||
type ClientState uint32
|
||||
|
||||
// ClientState enumeration.
|
||||
const (
|
||||
ClientUnknown ClientState = iota
|
||||
ClientStarted
|
||||
ClientConnecting
|
||||
ClientConnected
|
||||
ClientDisconnected
|
||||
ClientClosed // keep it always last
|
||||
)
|
||||
|
||||
// ClientStateChange represents single client state transition.
|
||||
type ClientStateChange struct {
|
||||
Identifier string
|
||||
Previous ClientState
|
||||
Current ClientState
|
||||
Error error
|
||||
}
|
||||
|
||||
// Strings implements the fmt.Stringer interface.
|
||||
func (cs *ClientStateChange) String() string {
|
||||
if cs.Error != nil {
|
||||
return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error)
|
||||
}
|
||||
return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current)
|
||||
}
|
||||
|
||||
// Backoff defines behavior of staggering reconnection retries.
|
||||
type Backoff interface {
|
||||
// Next returns the duration to sleep before retrying reconnections.
|
||||
// If the returned value is negative, the retry is aborted.
|
||||
NextBackOff() time.Duration
|
||||
|
||||
// Reset is used to signal a reconnection was successful and next
|
||||
// call to Next should return desired time duration for 1st reconnection
|
||||
// attempt.
|
||||
Reset()
|
||||
}
|
||||
|
||||
// Client is responsible for creating a control connection to a tunnel server,
|
||||
// creating new tunnels and proxy them to tunnel server.
|
||||
type Client struct {
|
||||
// underlying yamux session
|
||||
session *yamux.Session
|
||||
|
||||
// config holds the ClientConfig
|
||||
config *ClientConfig
|
||||
|
||||
// yamuxConfig is passed to new yamux.Session's
|
||||
yamuxConfig *yamux.Config
|
||||
|
||||
// proxy handles local server communication.
|
||||
proxy ProxyFunc
|
||||
|
||||
// startNotify is a chanel user can get to be notified when client is
|
||||
// connected to the server. The preferred way of doing this however,
|
||||
// would be using StateChanges in ClientConfig where user can provide
|
||||
// his own channel.
|
||||
startNotify chan bool
|
||||
// closed is a flag set when client calls Close() and quits.
|
||||
closed bool
|
||||
// closedMu guards both closed flag and startNotify channel. Since library
|
||||
// owns the channel it's cleared when trying to reconnect.
|
||||
closedMu sync.RWMutex
|
||||
|
||||
reqWg sync.WaitGroup
|
||||
ctrlWg sync.WaitGroup
|
||||
|
||||
state ClientState
|
||||
|
||||
// redialBackoff is used to reconnect in exponential backoff intervals
|
||||
redialBackoff Backoff
|
||||
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
// ClientConfig defines the configuration for the Client
|
||||
type ClientConfig struct {
|
||||
// Identifier is the secret token that needs to be passed to the server.
|
||||
// Required if FetchIdentifier is not set.
|
||||
Identifier string
|
||||
|
||||
// FetchIdentifier can be used to fetch identifier. Required if Identifier
|
||||
// is not set.
|
||||
FetchIdentifier func() (string, error)
|
||||
|
||||
// ServerAddr defines the TCP address of the tunnel server to be connected.
|
||||
// Required if FetchServerAddr is not set.
|
||||
ServerAddr string
|
||||
|
||||
// FetchServerAddr can be used to fetch tunnel server address.
|
||||
// Required if ServerAddress is not set.
|
||||
FetchServerAddr func() (string, error)
|
||||
|
||||
// Dial provides custom transport layer for client server communication.
|
||||
//
|
||||
// If nil, default implementation is to return net.Dial("tcp", address).
|
||||
//
|
||||
// It can be used for connection monitoring, setting different timeouts or
|
||||
// securing the connection.
|
||||
Dial func(network, address string) (net.Conn, error)
|
||||
|
||||
// Proxy defines custom proxing logic. This is optional extension point
|
||||
// where you can provide your local server selection or communication rules.
|
||||
Proxy ProxyFunc
|
||||
|
||||
// StateChanges receives state transition details each time client
|
||||
// connection state changes. The channel is expected to be sufficiently
|
||||
// buffered to keep up with event pace.
|
||||
//
|
||||
// If nil, no information about state transitions are dispatched
|
||||
// by the library.
|
||||
StateChanges chan<- *ClientStateChange
|
||||
|
||||
// Backoff is used to control behavior of staggering reconnection loop.
|
||||
//
|
||||
// If nil, default backoff policy is used which makes a client to never
|
||||
// give up on reconnection.
|
||||
//
|
||||
// If custom backoff is used, client will emit ErrRedialAborted set
|
||||
// with ClientClosed event when no more reconnection atttemps should
|
||||
// be made.
|
||||
Backoff Backoff
|
||||
|
||||
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
||||
// yamux.DefaultConfig() is used.
|
||||
YamuxConfig *yamux.Config
|
||||
|
||||
// Log defines the logger. If nil a default logging.Logger is used.
|
||||
Log logging.Logger
|
||||
|
||||
// Debug enables debug mode, enable only if you want to debug the server.
|
||||
Debug bool
|
||||
|
||||
// DEPRECATED:
|
||||
|
||||
// LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details.
|
||||
LocalAddr string
|
||||
|
||||
// FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details.
|
||||
FetchLocalAddr func(port int) (string, error)
|
||||
}
|
||||
|
||||
// verify is used to verify the ClientConfig
|
||||
func (c *ClientConfig) verify() error {
|
||||
if c.ServerAddr == "" && c.FetchServerAddr == nil {
|
||||
return errors.New("neither ServerAddr nor FetchServerAddr is set")
|
||||
}
|
||||
|
||||
if c.Identifier == "" && c.FetchIdentifier == nil {
|
||||
return errors.New("neither Identifier nor FetchIdentifier is set")
|
||||
}
|
||||
|
||||
if c.YamuxConfig != nil {
|
||||
if err := yamux.VerifyConfig(c.YamuxConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) {
|
||||
return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewClient creates a new tunnel that is established between the serverAddr
|
||||
// and localAddr. It exits if it can't create a new control connection to the
|
||||
// server. If localAddr is empty client will always try to proxy to a local
|
||||
// port.
|
||||
func NewClient(cfg *ClientConfig) (*Client, error) {
|
||||
if err := cfg.verify(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
yamuxConfig := yamux.DefaultConfig()
|
||||
if cfg.YamuxConfig != nil {
|
||||
yamuxConfig = cfg.YamuxConfig
|
||||
}
|
||||
|
||||
var proxy = DefaultProxy
|
||||
if cfg.Proxy != nil {
|
||||
proxy = cfg.Proxy
|
||||
}
|
||||
// DEPRECATED API SUPPORT
|
||||
if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil {
|
||||
var f ProxyFuncs
|
||||
if cfg.LocalAddr != "" {
|
||||
f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
||||
f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
||||
}
|
||||
if cfg.FetchLocalAddr != nil {
|
||||
f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy
|
||||
}
|
||||
proxy = Proxy(f)
|
||||
}
|
||||
|
||||
var bo Backoff = newForeverBackoff()
|
||||
if cfg.Backoff != nil {
|
||||
bo = cfg.Backoff
|
||||
}
|
||||
|
||||
log := newLogger("tunnel-client", cfg.Debug)
|
||||
if cfg.Log != nil {
|
||||
log = cfg.Log
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
config: cfg,
|
||||
yamuxConfig: yamuxConfig,
|
||||
proxy: proxy,
|
||||
startNotify: make(chan bool, 1),
|
||||
redialBackoff: bo,
|
||||
log: log,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Start starts the client and connects to the server with the identifier.
|
||||
// client.FetchIdentifier() will be used if it's not nil. It's supports
|
||||
// reconnecting with exponential backoff intervals when the connection to the
|
||||
// server disconnects. Call client.Close() to shutdown the client completely. A
|
||||
// successful connection will cause StartNotify() to receive a value.
|
||||
func (c *Client) Start() {
|
||||
fetchIdent := func() (string, error) {
|
||||
if c.config.FetchIdentifier != nil {
|
||||
return c.config.FetchIdentifier()
|
||||
}
|
||||
|
||||
return c.config.Identifier, nil
|
||||
}
|
||||
|
||||
fetchServerAddr := func() (string, error) {
|
||||
if c.config.FetchServerAddr != nil {
|
||||
return c.config.FetchServerAddr()
|
||||
}
|
||||
|
||||
return c.config.ServerAddr, nil
|
||||
}
|
||||
|
||||
c.changeState(ClientStarted, nil)
|
||||
|
||||
c.redialBackoff.Reset()
|
||||
var lastErr error
|
||||
for {
|
||||
prev := c.changeState(ClientConnecting, lastErr)
|
||||
|
||||
if c.isRetry(prev) {
|
||||
dur := c.redialBackoff.NextBackOff()
|
||||
if dur < 0 {
|
||||
c.setClosed(true)
|
||||
c.changeState(ClientClosed, ErrRedialAborted)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(dur)
|
||||
|
||||
// exit if closed
|
||||
if c.isClosed() {
|
||||
c.changeState(ClientClosed, lastErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := fetchIdent()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
c.log.Critical("client fetch identifier error: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr, err := fetchServerAddr()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
c.log.Critical("client fetch server address error: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.setClosed(false)
|
||||
|
||||
if err := c.connect(identifier, serverAddr); err != nil {
|
||||
lastErr = err
|
||||
c.log.Debug("client connect error: %s", err)
|
||||
}
|
||||
|
||||
// exit if closed
|
||||
if c.isClosed() {
|
||||
c.changeState(ClientClosed, lastErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the client and shutdowns the connection to the tunnel server
|
||||
func (c *Client) Close() error {
|
||||
defer c.setClosed(true)
|
||||
|
||||
if c.session == nil {
|
||||
return errors.New("session is not initialized")
|
||||
}
|
||||
|
||||
// wait until all connections are finished
|
||||
waitCh := make(chan struct{})
|
||||
go func() {
|
||||
if err := c.session.GoAway(); err != nil {
|
||||
c.log.Debug("Session go away failed: %s", err)
|
||||
}
|
||||
|
||||
c.reqWg.Wait()
|
||||
close(waitCh)
|
||||
}()
|
||||
select {
|
||||
case <-waitCh:
|
||||
// ok
|
||||
case <-time.After(time.Second * 10):
|
||||
c.log.Info("Timeout waiting for connections to finish")
|
||||
}
|
||||
|
||||
if err := c.session.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed securely checks if client is marked as closed.
|
||||
func (c *Client) isClosed() bool {
|
||||
c.closedMu.RLock()
|
||||
defer c.closedMu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// setClosed securely marks client as closed (or not closed). If not closed
|
||||
// also empty the value inside the startNotify channel by retrieving it (if any),
|
||||
// so it doesn't block during connect, when the client was closed and started again,
|
||||
// and startNotify was never listened to.
|
||||
func (c *Client) setClosed(closed bool) {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
c.closed = closed
|
||||
|
||||
if !closed {
|
||||
// clear channel
|
||||
select {
|
||||
case <-c.startNotify:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startNotifyIfNeeded sends ok to startNotify channel if it's listened to.
|
||||
// This function is called by connect when connection was successful.
|
||||
func (c *Client) startNotifyIfNeeded() {
|
||||
c.closedMu.RLock()
|
||||
if !c.closed {
|
||||
c.log.Debug("sending ok to startNotify chan")
|
||||
select {
|
||||
case c.startNotify <- true:
|
||||
default:
|
||||
// reaching here means the client never read the signal via
|
||||
// StartNotify(). This is OK, we shouldn't except it the consumer
|
||||
// to read from this channel. It's optional, so we just drop the
|
||||
// signal.
|
||||
c.log.Debug("startNotify message was dropped")
|
||||
}
|
||||
}
|
||||
c.closedMu.RUnlock()
|
||||
}
|
||||
|
||||
// StartNotify returns a channel that receives a single value when the client
|
||||
// established a successful connection to the server.
|
||||
func (c *Client) StartNotify() <-chan bool {
|
||||
return c.startNotify
|
||||
}
|
||||
|
||||
func (c *Client) changeState(state ClientState, err error) (prev ClientState) {
|
||||
prev = ClientState(atomic.LoadUint32((*uint32)(&c.state)))
|
||||
|
||||
if c.config.StateChanges != nil {
|
||||
change := &ClientStateChange{
|
||||
Identifier: c.config.Identifier,
|
||||
Previous: ClientState(prev),
|
||||
Current: state,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.config.StateChanges <- change:
|
||||
default:
|
||||
c.log.Warning("Dropping state change due to slow reader: %s", change)
|
||||
}
|
||||
}
|
||||
|
||||
atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state))
|
||||
|
||||
return prev
|
||||
}
|
||||
|
||||
func (c *Client) isRetry(state ClientState) bool {
|
||||
return state != ClientStarted && state != ClientClosed
|
||||
}
|
||||
|
||||
func (c *Client) connect(identifier, serverAddr string) error {
|
||||
c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier)
|
||||
|
||||
conn, err := c.dial(serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteURL := controlURL(conn)
|
||||
c.log.Debug("CONNECT to %q", remoteURL)
|
||||
req, err := http.NewRequest("CONNECT", remoteURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
|
||||
}
|
||||
|
||||
req.Header.Set(proto.ClientIdentifierHeader, identifier)
|
||||
|
||||
c.log.Debug("Writing request to TCP: %+v", req)
|
||||
|
||||
if err := req.Write(conn); err != nil {
|
||||
return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err)
|
||||
}
|
||||
|
||||
c.log.Debug("Reading response from TCP")
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected {
|
||||
out, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out))
|
||||
}
|
||||
|
||||
c.ctrlWg.Wait() // wait until previous listenControl observes disconnection
|
||||
|
||||
c.session, err = yamux.Client(conn, c.yamuxConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session initialization failed: %s", err)
|
||||
}
|
||||
|
||||
var stream net.Conn
|
||||
openStream := func() error {
|
||||
// this is blocking until client opens a session to us
|
||||
stream, err = c.session.Open()
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't receive anything from the server, we'll timeout
|
||||
select {
|
||||
case err := <-async(openStream):
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting for session to open failed: %s", err)
|
||||
}
|
||||
case <-time.After(time.Second * 10):
|
||||
if stream != nil {
|
||||
stream.Close()
|
||||
}
|
||||
return errors.New("timeout opening session")
|
||||
}
|
||||
|
||||
if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil {
|
||||
return fmt.Errorf("writing handshake request failed: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, len(proto.HandshakeResponse))
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
return fmt.Errorf("reading handshake response failed: %s", err)
|
||||
}
|
||||
|
||||
if string(buf) != proto.HandshakeResponse {
|
||||
return fmt.Errorf("invalid handshake response, received: %s", string(buf))
|
||||
}
|
||||
|
||||
ct := newControl(stream)
|
||||
c.log.Debug("client has started successfully")
|
||||
c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff
|
||||
|
||||
c.startNotifyIfNeeded()
|
||||
|
||||
return c.listenControl(ct)
|
||||
}
|
||||
|
||||
func (c *Client) dial(serverAddr string) (net.Conn, error) {
|
||||
if c.config.Dial != nil {
|
||||
return c.config.Dial("tcp", serverAddr)
|
||||
}
|
||||
|
||||
return net.Dial("tcp", serverAddr)
|
||||
}
|
||||
|
||||
func (c *Client) listenControl(ct *control) error {
|
||||
c.ctrlWg.Add(1)
|
||||
defer c.ctrlWg.Done()
|
||||
|
||||
c.changeState(ClientConnected, nil)
|
||||
|
||||
for {
|
||||
var msg proto.ControlMessage
|
||||
if err := ct.dec.Decode(&msg); err != nil {
|
||||
c.reqWg.Wait() // wait until all requests are finished
|
||||
c.session.GoAway()
|
||||
c.session.Close()
|
||||
c.changeState(ClientDisconnected, err)
|
||||
|
||||
return fmt.Errorf("failure decoding control message: %s", err)
|
||||
}
|
||||
|
||||
c.log.Debug("Received control msg %+v", msg)
|
||||
c.log.Debug("Opening a new stream from server session")
|
||||
|
||||
remote, err := c.session.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isHTTP := msg.Protocol == proto.HTTP
|
||||
if isHTTP {
|
||||
c.reqWg.Add(1)
|
||||
}
|
||||
go func() {
|
||||
c.proxy(remote, &msg)
|
||||
if isHTTP {
|
||||
c.reqWg.Done()
|
||||
}
|
||||
remote.Close()
|
||||
}()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// Code generated by "stringer -type ClientState"; DO NOT EDIT
|
||||
|
||||
package tunnel
|
||||
|
||||
import "fmt"
|
||||
|
||||
const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed"
|
||||
|
||||
var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87}
|
||||
|
||||
func (i ClientState) String() string {
|
||||
if i >= ClientState(len(_ClientState_index)-1) {
|
||||
return fmt.Sprintf("ClientState(%d)", i)
|
||||
}
|
||||
return _ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]]
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var errControlClosed = errors.New("control connection is closed")
|
||||
|
||||
type control struct {
|
||||
// enc and dec are responsible for encoding and decoding json values forth
|
||||
// and back
|
||||
enc *json.Encoder
|
||||
dec *json.Decoder
|
||||
|
||||
// underlying connection responsible for encoder and decoder
|
||||
nc net.Conn
|
||||
|
||||
// identifier associated with this control
|
||||
identifier string
|
||||
|
||||
mu sync.Mutex // guards the following
|
||||
closed bool // if Close() and quits
|
||||
}
|
||||
|
||||
func newControl(nc net.Conn) *control {
|
||||
c := &control{
|
||||
enc: json.NewEncoder(nc),
|
||||
dec: json.NewDecoder(nc),
|
||||
nc: nc,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *control) send(v interface{}) error {
|
||||
if c.enc == nil {
|
||||
return errors.New("encoder is not initialized")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return errControlClosed
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return c.enc.Encode(v)
|
||||
}
|
||||
|
||||
func (c *control) recv(v interface{}) error {
|
||||
if c.dec == nil {
|
||||
return errors.New("decoder is not initialized")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return errControlClosed
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return c.dec.Decode(v)
|
||||
}
|
||||
|
||||
func (c *control) Close() error {
|
||||
if c.nc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.closed = true
|
||||
c.mu.Unlock()
|
||||
|
||||
return c.nc.Close()
|
||||
}
|
||||
|
||||
type controls struct {
|
||||
sync.Mutex
|
||||
controls map[string]*control
|
||||
}
|
||||
|
||||
func newControls() *controls {
|
||||
return &controls{
|
||||
controls: make(map[string]*control),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controls) getControl(identifier string) (*control, bool) {
|
||||
c.Lock()
|
||||
control, ok := c.controls[identifier]
|
||||
c.Unlock()
|
||||
return control, ok
|
||||
}
|
||||
|
||||
func (c *controls) addControl(identifier string, control *control) {
|
||||
control.identifier = identifier
|
||||
|
||||
c.Lock()
|
||||
c.controls[identifier] = control
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
func (c *controls) deleteControl(identifier string) {
|
||||
c.Lock()
|
||||
delete(c.controls, identifier)
|
||||
c.Unlock()
|
||||
}
|
|
@ -0,0 +1,263 @@
|
|||
package tunnel_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano() + int64(os.Getpid()))
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
type EchoMessage struct {
|
||||
Value string `json:"value,omitempty"`
|
||||
Close bool `json:"close,omitempty"`
|
||||
}
|
||||
|
||||
var timeout = 10 * time.Second
|
||||
|
||||
var dialer = &websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
HandshakeTimeout: timeout,
|
||||
NetDial: func(_, addr string) (net.Conn, error) {
|
||||
return net.DialTimeout("tcp4", addr, timeout)
|
||||
},
|
||||
}
|
||||
|
||||
func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
||||
req := tt.Request("http", url.Values{"echo": []string{echo}})
|
||||
if req == nil {
|
||||
return "", fmt.Errorf(`tunnel "http" does not exist`)
|
||||
}
|
||||
|
||||
req.Close = rand.Int()%2 == 0
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
p, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(bytes.TrimSpace(p)), nil
|
||||
}
|
||||
|
||||
func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
||||
return echoTCPIdent(tt, echo, "tcp")
|
||||
}
|
||||
|
||||
func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) {
|
||||
addr := tt.Addr(ident)
|
||||
if addr == nil {
|
||||
return "", fmt.Errorf("tunnel %q does not exist", ident)
|
||||
}
|
||||
s := addr.String()
|
||||
ip := tt.Tunnels[ident].IP
|
||||
if ip != nil {
|
||||
_, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s = net.JoinHostPort(ip.String(), port)
|
||||
}
|
||||
|
||||
c, err := dialTCP(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c.out <- echo
|
||||
|
||||
select {
|
||||
case reply := <-c.in:
|
||||
return reply, nil
|
||||
case <-time.After(tcpTimeout):
|
||||
return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) {
|
||||
req := tt.Request(ident, nil)
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("no client found for ident %q", ident)
|
||||
}
|
||||
|
||||
h := http.Header{"Host": {req.Host}}
|
||||
wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr())
|
||||
|
||||
conn, _, err := dialer.Dial(wsurl, h)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func sleep() {
|
||||
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w http.ResponseWriter, r *http.Request) (e error) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if e == nil {
|
||||
e = err
|
||||
}
|
||||
}()
|
||||
|
||||
if sleepFn != nil {
|
||||
sleepFn()
|
||||
}
|
||||
|
||||
for {
|
||||
var msg EchoMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadJSON error: %s", err)
|
||||
}
|
||||
|
||||
if sleepFn != nil {
|
||||
sleepFn()
|
||||
}
|
||||
|
||||
err = conn.WriteJSON(&msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("WriteJSON error: %s", err)
|
||||
}
|
||||
|
||||
if msg.Close {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, r.URL.Query().Get("echo"))
|
||||
}
|
||||
|
||||
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
sleep()
|
||||
handlerEchoHTTP(w, r)
|
||||
}
|
||||
|
||||
func handlerEchoTCP(conn net.Conn) {
|
||||
io.Copy(conn, conn)
|
||||
}
|
||||
|
||||
func handlerLatencyEchoTCP(conn net.Conn) {
|
||||
sleep()
|
||||
handlerEchoTCP(conn)
|
||||
}
|
||||
|
||||
var tcpTimeout = 10 * time.Second
|
||||
|
||||
type tcpClient struct {
|
||||
conn net.Conn
|
||||
scanner *bufio.Scanner
|
||||
in chan string
|
||||
out chan string
|
||||
}
|
||||
|
||||
func (c *tcpClient) loop() {
|
||||
for out := range c.out {
|
||||
if _, err := fmt.Fprintln(c.conn, out); err != nil {
|
||||
log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if !c.scanner.Scan() {
|
||||
log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err())
|
||||
return
|
||||
}
|
||||
|
||||
c.in <- c.scanner.Text()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *tcpClient) Close() error {
|
||||
close(c.out)
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func dialTCP(addr string) (*tcpClient, error) {
|
||||
conn, err := net.DialTimeout("tcp", addr, tcpTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &tcpClient{
|
||||
conn: conn,
|
||||
scanner: bufio.NewScanner(conn),
|
||||
in: make(chan string, 1),
|
||||
out: make(chan string, 1),
|
||||
}
|
||||
|
||||
go c.loop()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel {
|
||||
return singleRecHTTP(handler, nil)
|
||||
}
|
||||
|
||||
func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
||||
return map[string]*tunneltest.Tunnel{
|
||||
"http": {
|
||||
Type: tunneltest.TypeHTTP,
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
Handler: handler,
|
||||
StateChanges: stateChanges,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel {
|
||||
return singleRecTCP(handler, nil)
|
||||
}
|
||||
|
||||
func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
||||
return map[string]*tunneltest.Tunnel{
|
||||
"http": {
|
||||
Type: tunneltest.TypeHTTP,
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
Handler: handlerEchoHTTP,
|
||||
StateChanges: stateChanges,
|
||||
},
|
||||
"tcp": {
|
||||
Type: tunneltest.TypeTCP,
|
||||
ClientIdent: "http",
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
RemoteAddr: "127.0.0.1:0",
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
httpLog = logging.NewLogger("http")
|
||||
)
|
||||
|
||||
// HTTPProxy forwards HTTP traffic.
|
||||
//
|
||||
// When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort
|
||||
// where incomingPort is control message LocalPort.
|
||||
// Usually this is tunnel server's public exposed Port.
|
||||
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
||||
// FetchLocalAddr takes precedence over LocalAddr.
|
||||
//
|
||||
// When connection to local server cannot be established proxy responds with http error message.
|
||||
type HTTPProxy struct {
|
||||
// LocalAddr defines the TCP address of the local server.
|
||||
// This is optional if you want to specify a single TCP address.
|
||||
LocalAddr string
|
||||
// FetchLocalAddr is used for looking up TCP address of the server.
|
||||
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
||||
FetchLocalAddr func(port int) (string, error)
|
||||
// ErrorResp is custom response send to tunnel server when client cannot
|
||||
// establish connection to local server. If not set a default "no local server"
|
||||
// response is sent.
|
||||
ErrorResp *http.Response
|
||||
// Log is a custom logger that can be used for the proxy.
|
||||
// If not set a "http" logger is used.
|
||||
Log logging.Logger
|
||||
}
|
||||
|
||||
// Proxy is a ProxyFunc.
|
||||
func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
||||
if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS {
|
||||
panic("Proxy mismatch")
|
||||
}
|
||||
|
||||
var log = p.log()
|
||||
|
||||
var port = msg.LocalPort
|
||||
if port == 0 {
|
||||
port = 80
|
||||
}
|
||||
|
||||
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
||||
if p.LocalAddr != "" {
|
||||
localAddr = p.LocalAddr
|
||||
} else if p.FetchLocalAddr != nil {
|
||||
l, err := p.FetchLocalAddr(msg.LocalPort)
|
||||
if err != nil {
|
||||
log.Warning("Failed to get custom local address: %s", err)
|
||||
p.sendError(remote)
|
||||
return
|
||||
}
|
||||
localAddr = l
|
||||
}
|
||||
|
||||
log.Debug("Dialing local server %q", localAddr)
|
||||
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
||||
if err != nil {
|
||||
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
||||
p.sendError(remote)
|
||||
return
|
||||
}
|
||||
|
||||
Join(local, remote, log)
|
||||
}
|
||||
|
||||
func (p *HTTPProxy) sendError(remote net.Conn) {
|
||||
var w = noLocalServer()
|
||||
if p.ErrorResp != nil {
|
||||
w = p.ErrorResp
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
w.Write(buf)
|
||||
if _, err := io.Copy(remote, buf); err != nil {
|
||||
var log = p.log()
|
||||
log.Debug("Copy in-mem response error: %s", err)
|
||||
}
|
||||
|
||||
remote.Close()
|
||||
}
|
||||
|
||||
func noLocalServer() *http.Response {
|
||||
body := bytes.NewBufferString("no local server")
|
||||
return &http.Response{
|
||||
Status: http.StatusText(http.StatusServiceUnavailable),
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Body: ioutil.NopCloser(body),
|
||||
ContentLength: int64(body.Len()),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProxy) log() logging.Logger {
|
||||
if p.Log != nil {
|
||||
return p.Log
|
||||
}
|
||||
return httpLog
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package proto
|
||||
|
||||
// ControlMessage is sent from server to client to establish tunneled connection.
|
||||
type ControlMessage struct {
|
||||
Action Action `json:"action"`
|
||||
Protocol Type `json:"transportProtocol"`
|
||||
LocalPort int `json:"localPort"`
|
||||
}
|
||||
|
||||
// Action represents type of ControlMsg request.
|
||||
type Action int
|
||||
|
||||
// ControlMessage actions.
|
||||
const (
|
||||
RequestClientSession Action = iota + 1
|
||||
)
|
||||
|
||||
// Type represents tunneled connection type.
|
||||
type Type int
|
||||
|
||||
// ControlMessage protocols.
|
||||
const (
|
||||
HTTP Type = iota + 1
|
||||
TCP
|
||||
WS
|
||||
)
|
|
@ -0,0 +1,19 @@
|
|||
// Package proto defines tunnel client server communication protocol.
|
||||
package proto
|
||||
|
||||
const (
|
||||
// ControlPath is http.Handler url path for control connection.
|
||||
ControlPath = "/_controlPath/"
|
||||
|
||||
// ClientIdentifierHeader is header carrying information about tunnel identifier.
|
||||
ClientIdentifierHeader = "X-KTunnel-Identifier"
|
||||
|
||||
// control messages
|
||||
|
||||
// Connected is message sent by server to client when control connection was established.
|
||||
Connected = "200 Connected to Tunnel"
|
||||
// HandshakeRequest is hello message sent by client to server.
|
||||
HandshakeRequest = "controlHandshake"
|
||||
// HandshakeResponse is response to HandshakeRequest sent by server to client.
|
||||
HandshakeResponse = "controlOk"
|
||||
)
|
|
@ -0,0 +1,101 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
)
|
||||
|
||||
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
|
||||
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
|
||||
|
||||
var (
|
||||
// DefaultProxyFuncs holds global default proxy functions for all transport protocols.
|
||||
DefaultProxyFuncs = ProxyFuncs{
|
||||
HTTP: new(HTTPProxy).Proxy,
|
||||
TCP: new(TCPProxy).Proxy,
|
||||
WS: new(HTTPProxy).Proxy,
|
||||
}
|
||||
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs.
|
||||
DefaultProxy = Proxy(ProxyFuncs{})
|
||||
)
|
||||
|
||||
// ProxyFuncs is a collection of ProxyFunc.
|
||||
type ProxyFuncs struct {
|
||||
// HTTP is custom implementation of HTTP proxing.
|
||||
HTTP ProxyFunc
|
||||
// TCP is custom implementation of TCP proxing.
|
||||
TCP ProxyFunc
|
||||
// WS is custom implementation of web socket proxing.
|
||||
WS ProxyFunc
|
||||
}
|
||||
|
||||
// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs.
|
||||
func Proxy(p ProxyFuncs) ProxyFunc {
|
||||
return func(remote net.Conn, msg *proto.ControlMessage) {
|
||||
var f ProxyFunc
|
||||
switch msg.Protocol {
|
||||
case proto.HTTP:
|
||||
f = DefaultProxyFuncs.HTTP
|
||||
if p.HTTP != nil {
|
||||
f = p.HTTP
|
||||
}
|
||||
case proto.TCP:
|
||||
f = DefaultProxyFuncs.TCP
|
||||
if p.TCP != nil {
|
||||
f = p.TCP
|
||||
}
|
||||
case proto.WS:
|
||||
f = DefaultProxyFuncs.WS
|
||||
if p.WS != nil {
|
||||
f = p.WS
|
||||
}
|
||||
}
|
||||
|
||||
if f == nil {
|
||||
logging.Error("Could not determine proxy function for %v", msg)
|
||||
remote.Close()
|
||||
}
|
||||
|
||||
f(remote, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Join copies data between local and remote connections.
|
||||
// It reads from one connection and writes to the other.
|
||||
// It's a building block for ProxyFunc implementations.
|
||||
func Join(local, remote net.Conn, log logging.Logger) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
transfer := func(side string, dst, src net.Conn) {
|
||||
log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
||||
|
||||
n, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
log.Error("%s: copy error: %s", side, err)
|
||||
}
|
||||
|
||||
if err := src.Close(); err != nil {
|
||||
log.Debug("%s: close error: %s", side, err)
|
||||
}
|
||||
|
||||
// not for yamux streams, but for client to local server connections
|
||||
if d, ok := dst.(*net.TCPConn); ok {
|
||||
if err := d.CloseWrite(); err != nil {
|
||||
log.Debug("%s: closeWrite error: %s", side, err)
|
||||
}
|
||||
|
||||
}
|
||||
wg.Done()
|
||||
log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n)
|
||||
}
|
||||
|
||||
go transfer("remote to local", local, remote)
|
||||
go transfer("local to remote", remote, local)
|
||||
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,755 @@
|
|||
// Package tunnel is a server/client package that enables to proxy public
|
||||
// connections to your local machine over a tunnel connection from the local
|
||||
// machine to the public server.
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoClientSession = errors.New("no client session established")
|
||||
defaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// Server is responsible for proxying public connections to the client over a
|
||||
// tunnel connection. It also listens to control messages from the client.
|
||||
type Server struct {
|
||||
// pending contains the channel that is associated with each new tunnel request.
|
||||
pending map[string]chan net.Conn
|
||||
// pendingMu protects the pending map.
|
||||
pendingMu sync.Mutex
|
||||
|
||||
// sessions contains a session per virtual host.
|
||||
// Sessions provides multiplexing over one connection.
|
||||
sessions map[string]*yamux.Session
|
||||
// sessionsMu protects sessions.
|
||||
sessionsMu sync.Mutex
|
||||
|
||||
// controls contains the control connection from the client to the server.
|
||||
controls *controls
|
||||
|
||||
// virtualHosts is used to map public hosts to remote clients.
|
||||
virtualHosts vhostStorage
|
||||
|
||||
// virtualAddrs.
|
||||
virtualAddrs *vaddrStorage
|
||||
|
||||
// connCh is used to publish accepted connections for tcp tunnels.
|
||||
connCh chan net.Conn
|
||||
|
||||
// onConnectCallbacks contains client callbacks called when control
|
||||
// session is established for a client with given identifier.
|
||||
onConnectCallbacks *callbacks
|
||||
|
||||
// onDisconnectCallbacks contains client callbacks called when control
|
||||
// session is closed for a client with given identifier.
|
||||
onDisconnectCallbacks *callbacks
|
||||
|
||||
// states represents current clients' connections state.
|
||||
states map[string]ClientState
|
||||
// statesMu protects states.
|
||||
statesMu sync.RWMutex
|
||||
// stateCh notifies receiver about client state changes.
|
||||
stateCh chan<- *ClientStateChange
|
||||
|
||||
// httpDirector is provided by ServerConfig, if not nil decorates http requests
|
||||
// before forwarding them to client.
|
||||
httpDirector func(*http.Request)
|
||||
|
||||
// yamuxConfig is passed to new yamux.Session's
|
||||
yamuxConfig *yamux.Config
|
||||
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
// ServerConfig defines the configuration for the Server
|
||||
type ServerConfig struct {
|
||||
// StateChanges receives state transition details each time client
|
||||
// connection state changes. The channel is expected to be sufficiently
|
||||
// buffered to keep up with event pace.
|
||||
//
|
||||
// If nil, no information about state transitions are dispatched
|
||||
// by the library.
|
||||
StateChanges chan<- *ClientStateChange
|
||||
|
||||
// Director is a function that modifies HTTP request into a new HTTP request
|
||||
// before sending to client. If nil no modifications are done.
|
||||
Director func(*http.Request)
|
||||
|
||||
// Debug enables debug mode, enable only if you want to debug the server
|
||||
Debug bool
|
||||
|
||||
// Log defines the logger. If nil a default logging.Logger is used.
|
||||
Log logging.Logger
|
||||
|
||||
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
||||
// yamux.DefaultConfig() is used.
|
||||
YamuxConfig *yamux.Config
|
||||
}
|
||||
|
||||
// NewServer creates a new Server. The defaults are used if config is nil.
|
||||
func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||
yamuxConfig := yamux.DefaultConfig()
|
||||
if cfg.YamuxConfig != nil {
|
||||
if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
yamuxConfig = cfg.YamuxConfig
|
||||
}
|
||||
|
||||
log := newLogger("tunnel-server", cfg.Debug)
|
||||
if cfg.Log != nil {
|
||||
log = cfg.Log
|
||||
}
|
||||
|
||||
connCh := make(chan net.Conn)
|
||||
|
||||
opts := &vaddrOptions{
|
||||
connCh: connCh,
|
||||
log: log,
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
pending: make(map[string]chan net.Conn),
|
||||
sessions: make(map[string]*yamux.Session),
|
||||
onConnectCallbacks: newCallbacks("OnConnect"),
|
||||
onDisconnectCallbacks: newCallbacks("OnDisconnect"),
|
||||
virtualHosts: newVirtualHosts(),
|
||||
virtualAddrs: newVirtualAddrs(opts),
|
||||
controls: newControls(),
|
||||
states: make(map[string]ClientState),
|
||||
stateCh: cfg.StateChanges,
|
||||
httpDirector: cfg.Director,
|
||||
yamuxConfig: yamuxConfig,
|
||||
connCh: connCh,
|
||||
log: log,
|
||||
}
|
||||
|
||||
go s.serveTCP()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is a tunnel that creates an http/websocket tunnel between a
|
||||
// public connection and the client connection.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// if the user didn't add the control and tunnel handler manually, we'll
|
||||
// going to infer and call the respective path handlers.
|
||||
switch path.Clean(r.URL.Path) + "/" {
|
||||
case proto.ControlPath:
|
||||
s.checkConnect(s.controlHandler).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.handleHTTP(w, r); err != nil {
|
||||
if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily
|
||||
s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err)
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
}
|
||||
|
||||
// handleHTTP handles a single HTTP request
|
||||
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
s.log.Debug("HandleHTTP request:")
|
||||
s.log.Debug("%v", r)
|
||||
|
||||
if s.httpDirector != nil {
|
||||
s.httpDirector(r)
|
||||
}
|
||||
|
||||
hostPort := strings.ToLower(r.Host)
|
||||
if hostPort == "" {
|
||||
return errors.New("request host is empty")
|
||||
}
|
||||
|
||||
// if someone hits foo.example.com:8080, this should be proxied to
|
||||
// localhost:8080, so send the port to the client so it knows how to proxy
|
||||
// correctly. If no port is available, it's up to client how to interpret it
|
||||
host, port, err := parseHostPort(hostPort)
|
||||
if err != nil {
|
||||
// no need to return, just continue lazily, port will be 0, which in
|
||||
// our case will be proxied to client's local servers port 80
|
||||
s.log.Debug("No port available for %q, sending port 80 to client", hostPort)
|
||||
}
|
||||
|
||||
// get the identifier associated with this host
|
||||
identifier, ok := s.getIdentifier(hostPort)
|
||||
if !ok {
|
||||
// fallback to host
|
||||
identifier, ok = s.getIdentifier(host)
|
||||
if !ok {
|
||||
return fmt.Errorf("no virtual host available for %q", hostPort)
|
||||
}
|
||||
}
|
||||
|
||||
if isWebsocketConn(r) {
|
||||
s.log.Debug("handling websocket connection")
|
||||
|
||||
return s.handleWSConn(w, r, identifier, port)
|
||||
}
|
||||
|
||||
stream, err := s.dial(identifier, proto.HTTP, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
s.log.Debug("Closing stream")
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Debug("Session opened to client, writing request to client")
|
||||
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read from tunnel: %s", err.Error())
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
||||
s.log.Error("resp.Body Close error: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.log.Debug("Response received, writing back to public connection: %+v", resp)
|
||||
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
s.log.Debug("Client closed the connection, couldn't copy response")
|
||||
} else {
|
||||
s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) serveTCP() {
|
||||
for conn := range s.connCh {
|
||||
go s.serveTCPConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) serveTCPConn(conn net.Conn) {
|
||||
err := s.handleTCPConn(conn)
|
||||
if err != nil {
|
||||
s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err)
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return errors.New("webserver doesn't support hijacking")
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("hijack not possible: %s", err)
|
||||
}
|
||||
|
||||
stream, err := s.dial(ident, proto.WS, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
err = errors.New("unable to write upgrade request: " + err.Error())
|
||||
return nonil(err, stream.Close())
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||
if err != nil {
|
||||
err = errors.New("unable to read upgrade response: " + err.Error())
|
||||
return nonil(err, stream.Close())
|
||||
}
|
||||
|
||||
if err := resp.Write(conn); err != nil {
|
||||
err = errors.New("unable to write upgrade response: " + err.Error())
|
||||
return nonil(err, stream.Close())
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go s.proxy(&wg, conn, stream)
|
||||
go s.proxy(&wg, stream, conn)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nonil(stream.Close(), conn.Close())
|
||||
}
|
||||
|
||||
func (s *Server) handleTCPConn(conn net.Conn) error {
|
||||
ident, ok := s.virtualAddrs.getIdent(conn)
|
||||
if !ok {
|
||||
return fmt.Errorf("no virtual address available for %s", conn.LocalAddr())
|
||||
}
|
||||
|
||||
_, port, err := parseHostPort(conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := s.dial(ident, proto.TCP, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go s.proxy(&wg, conn, stream)
|
||||
go s.proxy(&wg, stream, conn)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nonil(stream.Close(), conn.Close())
|
||||
}
|
||||
|
||||
func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
||||
n, err := io.Copy(dst, src)
|
||||
s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) {
|
||||
control, ok := s.getControl(identifier)
|
||||
if !ok {
|
||||
return nil, errNoClientSession
|
||||
}
|
||||
|
||||
session, err := s.getSession(identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := proto.ControlMessage{
|
||||
Action: proto.RequestClientSession,
|
||||
Protocol: p,
|
||||
LocalPort: port,
|
||||
}
|
||||
|
||||
s.log.Debug("Sending control msg %+v", msg)
|
||||
|
||||
// ask client to open a session to us, so we can accept it
|
||||
if err := control.send(msg); err != nil {
|
||||
// we might have several issues here, either the stream is closed, or
|
||||
// the session is going be shut down, the underlying connection might
|
||||
// be broken. In all cases, it's not reliable anymore having a client
|
||||
// session.
|
||||
control.Close()
|
||||
s.deleteControl(identifier)
|
||||
return nil, errNoClientSession
|
||||
}
|
||||
|
||||
var stream net.Conn
|
||||
acceptStream := func() error {
|
||||
stream, err = session.Accept()
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't receive anything from the client, we'll timeout
|
||||
s.log.Debug("Waiting for session accept")
|
||||
|
||||
select {
|
||||
case err := <-async(acceptStream):
|
||||
return stream, err
|
||||
case <-time.After(defaultTimeout):
|
||||
return nil, errors.New("timeout getting session")
|
||||
}
|
||||
}
|
||||
|
||||
// controlHandler is used to capture incoming tunnel connect requests into raw
|
||||
// tunnel TCP connections.
|
||||
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
|
||||
identifier := r.Header.Get(proto.ClientIdentifierHeader)
|
||||
_, ok := s.getHost(identifier)
|
||||
if !ok {
|
||||
return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier)
|
||||
}
|
||||
|
||||
ct, ok := s.getControl(identifier)
|
||||
if ok {
|
||||
ct.Close()
|
||||
s.deleteControl(identifier)
|
||||
s.deleteSession(identifier)
|
||||
s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier)
|
||||
return fmt.Errorf("control conn for %s already exist. \n", identifier)
|
||||
}
|
||||
|
||||
s.log.Debug("Tunnel with identifier %s", identifier)
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return errors.New("webserver doesn't support hijacking")
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("hijack not possible: %s", err)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil {
|
||||
return fmt.Errorf("error writing response: %s", err)
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("error setting connection deadline: %s", err)
|
||||
}
|
||||
|
||||
s.log.Debug("Creating control session")
|
||||
session, err := yamux.Server(conn, s.yamuxConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.addSession(identifier, session)
|
||||
|
||||
var stream net.Conn
|
||||
|
||||
// close and delete the session/stream if something goes wrong
|
||||
defer func() {
|
||||
if ctErr != nil {
|
||||
if stream != nil {
|
||||
stream.Close()
|
||||
}
|
||||
s.deleteSession(identifier)
|
||||
}
|
||||
}()
|
||||
|
||||
acceptStream := func() error {
|
||||
stream, err = session.Accept()
|
||||
return err
|
||||
}
|
||||
|
||||
// if we don't receive anything from the client, we'll timeout
|
||||
select {
|
||||
case err := <-async(acceptStream):
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(time.Second * 10):
|
||||
return errors.New("timeout getting session")
|
||||
}
|
||||
|
||||
s.log.Debug("Initiating handshake protocol")
|
||||
buf := make([]byte, len(proto.HandshakeRequest))
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(buf) != proto.HandshakeRequest {
|
||||
return fmt.Errorf("handshake aborted. got: %s", string(buf))
|
||||
}
|
||||
|
||||
if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// setup control stream and start to listen to messages
|
||||
ct = newControl(stream)
|
||||
s.addControl(identifier, ct)
|
||||
go s.listenControl(ct)
|
||||
|
||||
s.log.Debug("Control connection is setup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenControl listens to messages coming from the client.
|
||||
func (s *Server) listenControl(ct *control) {
|
||||
s.onConnect(ct.identifier)
|
||||
|
||||
for {
|
||||
var msg map[string]interface{}
|
||||
err := ct.dec.Decode(&msg)
|
||||
if err != nil {
|
||||
host, _ := s.getHost(ct.identifier)
|
||||
s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier)
|
||||
|
||||
// close client connection so it reconnects again
|
||||
ct.Close()
|
||||
|
||||
// don't forget to cleanup anything
|
||||
s.deleteControl(ct.identifier)
|
||||
s.deleteSession(ct.identifier)
|
||||
|
||||
s.onDisconnect(ct.identifier, err)
|
||||
|
||||
if err != io.EOF {
|
||||
s.log.Error("decode err: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// right now we don't do anything with the messages, but because the
|
||||
// underlying connection needs to establihsed, we know when we have
|
||||
// disconnection(above), so we can cleanup the connection.
|
||||
s.log.Debug("msg: %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect invokes a callback for client with given identifier,
|
||||
// when it establishes a control session.
|
||||
// After a client is connected, the associated function
|
||||
// is also removed and needs to be added again.
|
||||
func (s *Server) OnConnect(identifier string, fn func() error) {
|
||||
s.onConnectCallbacks.add(identifier, fn)
|
||||
}
|
||||
|
||||
// onConnect sends notifications to listeners (registered in onConnectCallbacks
|
||||
// or stateChanges chanel readers) when client connects.
|
||||
func (s *Server) onConnect(identifier string) {
|
||||
if err := s.onConnectCallbacks.call(identifier); err != nil {
|
||||
s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err)
|
||||
}
|
||||
|
||||
s.changeState(identifier, ClientConnected, nil)
|
||||
}
|
||||
|
||||
// OnDisconnect calls the function when the client connected with the
|
||||
// associated identifier disconnects from the server.
|
||||
// After a client is disconnected, the associated function
|
||||
// is also removed and needs to be added again.
|
||||
func (s *Server) OnDisconnect(identifier string, fn func() error) {
|
||||
s.onDisconnectCallbacks.add(identifier, fn)
|
||||
}
|
||||
|
||||
// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks
|
||||
// or stateChanges chanel readers) when client disconnects.
|
||||
func (s *Server) onDisconnect(identifier string, err error) {
|
||||
if err := s.onDisconnectCallbacks.call(identifier); err != nil {
|
||||
s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err)
|
||||
}
|
||||
|
||||
s.changeState(identifier, ClientClosed, err)
|
||||
}
|
||||
|
||||
func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) {
|
||||
s.statesMu.Lock()
|
||||
defer s.statesMu.Unlock()
|
||||
|
||||
prev = s.states[identifier]
|
||||
s.states[identifier] = state
|
||||
|
||||
if s.stateCh != nil {
|
||||
change := &ClientStateChange{
|
||||
Identifier: identifier,
|
||||
Previous: prev,
|
||||
Current: state,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.stateCh <- change:
|
||||
default:
|
||||
s.log.Warning("Dropping state change due to slow reader: %s", change)
|
||||
}
|
||||
}
|
||||
|
||||
return prev
|
||||
}
|
||||
|
||||
// AddHost adds the given virtual host and maps it to the identifier.
|
||||
func (s *Server) AddHost(host, identifier string) {
|
||||
s.virtualHosts.AddHost(host, identifier)
|
||||
}
|
||||
|
||||
// DeleteHost deletes the given virtual host. Once removed any request to this
|
||||
// host is denied.
|
||||
func (s *Server) DeleteHost(host string) {
|
||||
s.virtualHosts.DeleteHost(host)
|
||||
}
|
||||
|
||||
// AddAddr starts accepting connections on listener l, routing every connection
|
||||
// to a tunnel client given by the identifier.
|
||||
//
|
||||
// When ip parameter is nil, all connections accepted from the listener are
|
||||
// routed to the tunnel client specified by the identifier (port-based routing).
|
||||
//
|
||||
// When ip parameter is non-nil, only those connections are routed whose local
|
||||
// address matches the specified ip (ip-based routing).
|
||||
//
|
||||
// If l listens on multiple interfaces it's desirable to call AddAddr multiple
|
||||
// times with the same l value but different ip one.
|
||||
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
|
||||
s.virtualAddrs.Add(l, ip, identifier)
|
||||
}
|
||||
|
||||
// DeleteAddr stops listening for connections on the given listener.
|
||||
//
|
||||
// Upon return no more connections will be tunneled, but as the method does not
|
||||
// close the listener, so any ongoing connection won't get interrupted.
|
||||
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
|
||||
s.virtualAddrs.Delete(l, ip)
|
||||
}
|
||||
|
||||
func (s *Server) getIdentifier(host string) (string, bool) {
|
||||
identifier, ok := s.virtualHosts.GetIdentifier(host)
|
||||
return identifier, ok
|
||||
}
|
||||
|
||||
func (s *Server) getHost(identifier string) (string, bool) {
|
||||
host, ok := s.virtualHosts.GetHost(identifier)
|
||||
return host, ok
|
||||
}
|
||||
|
||||
func (s *Server) addControl(identifier string, conn *control) {
|
||||
s.controls.addControl(identifier, conn)
|
||||
}
|
||||
|
||||
func (s *Server) getControl(identifier string) (*control, bool) {
|
||||
return s.controls.getControl(identifier)
|
||||
}
|
||||
|
||||
func (s *Server) deleteControl(identifier string) {
|
||||
s.controls.deleteControl(identifier)
|
||||
}
|
||||
|
||||
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
|
||||
s.sessionsMu.Lock()
|
||||
session, ok := s.sessions[identifier]
|
||||
s.sessionsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *Server) addSession(identifier string, session *yamux.Session) {
|
||||
s.sessionsMu.Lock()
|
||||
s.sessions[identifier] = session
|
||||
s.sessionsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) deleteSession(identifier string) {
|
||||
s.sessionsMu.Lock()
|
||||
defer s.sessionsMu.Unlock()
|
||||
|
||||
session, ok := s.sessions[identifier]
|
||||
|
||||
if !ok {
|
||||
return // nothing to delete
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
session.GoAway() // don't accept any new connection
|
||||
session.Close()
|
||||
}
|
||||
|
||||
delete(s.sessions, identifier)
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, v := range src {
|
||||
vv := make([]string, len(v))
|
||||
copy(vv, v)
|
||||
dst[k] = vv
|
||||
}
|
||||
}
|
||||
|
||||
// checkConnect checks whether the incoming request is HTTP CONNECT method.
|
||||
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "CONNECT" {
|
||||
http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := fn(w, r); err != nil {
|
||||
s.log.Error("Handler err: %v", err.Error())
|
||||
|
||||
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
|
||||
s.onDisconnect(identifier, err)
|
||||
}
|
||||
|
||||
http.Error(w, err.Error(), 502)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseHostPort(addr string) (string, int, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
n, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return host, int(n), nil
|
||||
}
|
||||
|
||||
func isWebsocketConn(r *http.Request) bool {
|
||||
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
|
||||
headerContains(r.Header["Upgrade"], "websocket")
|
||||
}
|
||||
|
||||
// headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go
|
||||
func headerContains(header []string, value string) bool {
|
||||
for _, h := range header {
|
||||
for _, v := range strings.Split(h, ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(v), value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func nonil(err ...error) error {
|
||||
for _, e := range err {
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newLogger(name string, debug bool) logging.Logger {
|
||||
log := logging.NewLogger(name)
|
||||
logHandler := logging.NewWriterHandler(os.Stderr)
|
||||
logHandler.Colorize = true
|
||||
log.SetHandler(logHandler)
|
||||
|
||||
if debug {
|
||||
log.SetLevel(logging.DEBUG)
|
||||
logHandler.SetLevel(logging.DEBUG)
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
# Specification
|
||||
|
||||
# Naming conventions
|
||||
|
||||
* `server` is listening to public connection and is responsible of routing
|
||||
public HTTP requests to clients.
|
||||
* `client` is a long running process, connected to a server and running on a local machine.
|
||||
* `virtualHost` is a virtual domain that maps a domain to a single client. i.e:
|
||||
`arslan.koding.io` is a virtualhost which is mapped to my `client` running on
|
||||
my local machine.
|
||||
* `identifier` is a secret token, which is not meant to be shared with others.
|
||||
An identifier is responsible of mapping a virtualhost to a client.
|
||||
* `session` is a single TCP connection which uses the library `yamux`. A
|
||||
session can be created either via `yamux.Server()` or `yamux.Client`
|
||||
* `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed
|
||||
over the `session`. A session can have hundreds of thousands streams
|
||||
* `control connection` is a single `stream` which is used to communicate and
|
||||
handle messaging between server and client. It uses a custom protocol which
|
||||
is JSON encoded.
|
||||
* `tunnel connection` is a single `stream` which is used to proxy public HTTP
|
||||
requests from the `server` to the `client` and vice versa. A single `tunnel`
|
||||
connection is created for every single HTTP requests.
|
||||
* `public connection` is a connection from a remote machine to the `server`
|
||||
* `ControlHandler` is a http.Handler which listens to requests coming to
|
||||
`/_controlPath_/`. It's used to setup the initial `session` connection from
|
||||
`client` to `server`. And creates the `control connection` from this session.
|
||||
server and client, and also for all additional new tunnel. It literally
|
||||
captures the incoming HTTP request and hijacks it and converts it into RAW TCP,
|
||||
which then is used as the foundation for all yamux `sessions.`
|
||||
|
||||
|
||||
# Server
|
||||
1. Server is created with `NewServer()` which returns `*Server`, a `http.Handler`
|
||||
compatible type. Plug into any HTTP server you want. The root path `"/"` is
|
||||
recommended to listen and proxy any tunnels. It also listens to any request
|
||||
coming to `ControlHandler`
|
||||
2. Tunneling is based on virtual hosts. A virtual hosts is identified with an
|
||||
unique identifier. This identifier is the only piece that both client and
|
||||
server needs to known ahead. Think of it as a secret token.
|
||||
3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This
|
||||
step needs to be done from the server itself. This can be could manually or
|
||||
via custom auth based HTTP handlers, such as "/addhost", which adds
|
||||
virtualhosts and returns the `identifier` to the requester (in our case `client`)
|
||||
4. A DNS record and it's subdomains needs to point to a `server`, so it can
|
||||
handle virtual hosts, i.e: `*.example.com` is routed to a server, which can
|
||||
handle `foo.example.com`, `bar.example.com`, etc..
|
||||
|
||||
|
||||
# Client
|
||||
|
||||
1. Client is created with `NewClient(serverAddr, localAddr)` which returns a
|
||||
`*Client`. Here `serverAddr` is the TCP address to the server. `localAddr`
|
||||
is the server in which all public requests are forwarded to. It's optional if
|
||||
you want it to be done dynamically
|
||||
2. Once a client is created, it starts with `client.Start(identifier)`. Here
|
||||
`identifier` is needed upfront. This method creates the initial TCP
|
||||
connection to the server. It sends the identifier back to the server. This
|
||||
TCP connection is used as the foundation for `yamux.Client()`. Once a yamux
|
||||
session is established, we are able to use this single connection to have
|
||||
multiple streams, which are multiplexed over this one connection. A `control
|
||||
connection` is created and client starts to listen it. `client.Start` is
|
||||
blocking.
|
||||
|
||||
# Control Handshake
|
||||
|
||||
1. Client sends a `handshakeRequest` over the `control connection` stream
|
||||
2. The server sends back a `handshakeResponse` to the client over the `control connection` stream
|
||||
3. Once the client receives the `handshakeResponse` from the server, it starts
|
||||
to listen from the `control connection` stream.
|
||||
4. A `control connection` is json.Encoder/Decoder both for server and client
|
||||
|
||||
|
||||
# Tunnel creation
|
||||
1. When the server receives a public connection, it checks the HTTP host
|
||||
headers and retrieves the corresponding identifier from the given host.
|
||||
2. The server retrieves the `control connection` which was associated with this
|
||||
`identifier` and sends a `ControlMsg` message with the action
|
||||
`RequestClientSession`. This message is in the form of:
|
||||
|
||||
type ControlMsg struct {
|
||||
Action Action `json:"action"`
|
||||
Protocol TransportProtocol `json:"transportProtocol"`
|
||||
LocalPort string `json:"localPort"`
|
||||
}
|
||||
|
||||
Here the `LocalPort` is read from the HTTP Host header. If absent a zero
|
||||
port is sent and client maps it to the local server running at port 80, unless
|
||||
the `localAddr` is specified in `client.Start()` method. `Protocol` is
|
||||
reserved for future features.
|
||||
3. The server immediately starts to listen(accept) to a new `stream`. This is
|
||||
blocking and it waits there.
|
||||
4. When the client receives the `RequestClientSession` message, it opens a new
|
||||
`virtual` TCP connection, a `stream` to the server.
|
||||
5. The server which was waiting for a new stream in step 3, establish the stream.
|
||||
6. The server copies the request over the stream to the client.
|
||||
7. The client copies the request coming from the server to the local server and
|
||||
copies back the result to the server
|
||||
8. The server reads the response coming from the client and returns back it to
|
||||
the public connection requester
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
tpcLog = logging.NewLogger("tcp")
|
||||
)
|
||||
|
||||
// TCPProxy forwards TCP streams.
|
||||
//
|
||||
// If port-based routing is used, LocalAddr or FetchLocalAddr field is required
|
||||
// for tunneling to function properly.
|
||||
// Otherwise you'll be forwarding traffic to random ports and this is usually not desired.
|
||||
//
|
||||
// If IP-based routing is used then tunnel server connection request is
|
||||
// proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort.
|
||||
// Usually this is tunnel server's public exposed Port.
|
||||
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
||||
// FetchLocalAddr takes precedence over LocalAddr.
|
||||
type TCPProxy struct {
|
||||
// LocalAddr defines the TCP address of the local server.
|
||||
// This is optional if you want to specify a single TCP address.
|
||||
LocalAddr string
|
||||
// FetchLocalAddr is used for looking up TCP address of the server.
|
||||
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
||||
FetchLocalAddr func(port int) (string, error)
|
||||
// Log is a custom logger that can be used for the proxy.
|
||||
// If not set a "tcp" logger is used.
|
||||
Log logging.Logger
|
||||
}
|
||||
|
||||
// Proxy is a ProxyFunc.
|
||||
func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
||||
if msg.Protocol != proto.TCP {
|
||||
panic("Proxy mismatch")
|
||||
}
|
||||
|
||||
var log = p.log()
|
||||
|
||||
var port = msg.LocalPort
|
||||
if port == 0 {
|
||||
log.Warning("TCP proxy to port 0")
|
||||
}
|
||||
|
||||
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
||||
if p.LocalAddr != "" {
|
||||
localAddr = p.LocalAddr
|
||||
} else if p.FetchLocalAddr != nil {
|
||||
l, err := p.FetchLocalAddr(msg.LocalPort)
|
||||
if err != nil {
|
||||
log.Warning("Failed to get custom local address: %s", err)
|
||||
return
|
||||
}
|
||||
localAddr = l
|
||||
}
|
||||
|
||||
log.Debug("Dialing local server: %q", localAddr)
|
||||
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
||||
if err != nil {
|
||||
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
Join(local, remote, log)
|
||||
}
|
||||
|
||||
func (p *TCPProxy) log() logging.Logger {
|
||||
if p.Log != nil {
|
||||
return p.Log
|
||||
}
|
||||
return tpcLog
|
||||
}
|
|
@ -0,0 +1,412 @@
|
|||
package tunnel_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
)
|
||||
|
||||
func TestMultipleRequest(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
// make a request to tunnelserver, this should be tunneled to local server
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
msg := "hello" + strconv.Itoa(i)
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMultipleLatencyRequest(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
// make a request to tunnelserver, this should be tunneled to local server
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
msg := "hello" + strconv.Itoa(i)
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestReconnectClient(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
|
||||
client := tt.Clients["http"]
|
||||
|
||||
// close client, and start it again
|
||||
client.Close()
|
||||
|
||||
go client.Start()
|
||||
<-client.StartNotify()
|
||||
|
||||
msg = "helloagain"
|
||||
res, err = echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoClient(t *testing.T) {
|
||||
const expectedErr = "no client session established"
|
||||
|
||||
rec := tunneltest.NewStateRecorder()
|
||||
|
||||
tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
if err := rec.WaitTransitions(
|
||||
tunnel.ClientStarted,
|
||||
tunnel.ClientConnecting,
|
||||
tunnel.ClientConnected,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tt.ServerStateRecorder.WaitTransition(
|
||||
tunnel.ClientUnknown,
|
||||
tunnel.ClientConnected,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// close client, this is the main point of the test
|
||||
if err := tt.Clients["http"].Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := rec.WaitTransitions(
|
||||
tunnel.ClientConnected,
|
||||
tunnel.ClientDisconnected,
|
||||
tunnel.ClientClosed,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tt.ServerStateRecorder.WaitTransition(
|
||||
tunnel.ClientConnected,
|
||||
tunnel.ClientClosed,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != expectedErr {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoHost(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
noBackoff := backoff.NewConstantBackOff(time.Duration(-1))
|
||||
|
||||
unknown, err := tunnel.NewClient(&tunnel.ClientConfig{
|
||||
Identifier: "unknown",
|
||||
ServerAddr: tt.ServerAddr().String(),
|
||||
Backoff: noBackoff,
|
||||
Debug: testing.Verbose(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client error: %s", err)
|
||||
}
|
||||
unknown.Start()
|
||||
defer unknown.Close()
|
||||
|
||||
if err := tt.ServerStateRecorder.WaitTransition(
|
||||
tunnel.ClientUnknown,
|
||||
tunnel.ClientClosed,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unknown.Start()
|
||||
if err := tt.ServerStateRecorder.WaitTransition(
|
||||
tunnel.ClientClosed,
|
||||
tunnel.ClientClosed,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoLocalServer(t *testing.T) {
|
||||
const expectedErr = "no local server"
|
||||
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
// close local listener, this is the main point of the test
|
||||
tt.Listeners["http"][0].Close()
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != expectedErr {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleRequest(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleLatencyRequest(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoHTTP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoHTTP error: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleTCP(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
msg := "hello"
|
||||
res, err := echoTCP(tt, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("echoTCP error: %s", err)
|
||||
}
|
||||
|
||||
if msg != res {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleTCP(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
msg := "hello" + strconv.Itoa(i)
|
||||
res, err := echoTCP(tt, msg)
|
||||
if err != nil {
|
||||
t.Errorf("echoTCP: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMultipleLatencyTCP(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
msg := "hello" + strconv.Itoa(i)
|
||||
res, err := echoTCP(tt, msg)
|
||||
if err != nil {
|
||||
t.Errorf("echoTCP: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMultipleStreamTCP(t *testing.T) {
|
||||
tunnels := map[string]*tunneltest.Tunnel{
|
||||
"http": {
|
||||
Type: tunneltest.TypeHTTP,
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
Handler: handlerEchoHTTP,
|
||||
},
|
||||
"tcp": {
|
||||
Type: tunneltest.TypeTCP,
|
||||
ClientIdent: "http",
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
RemoteAddr: "127.0.0.1:0",
|
||||
Handler: handlerEchoTCP,
|
||||
},
|
||||
"tcp_all": {
|
||||
Type: tunneltest.TypeTCP,
|
||||
ClientIdent: "http",
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
RemoteAddr: "0.0.0.0:0",
|
||||
Handler: handlerEchoTCP,
|
||||
},
|
||||
}
|
||||
|
||||
addrs, err := tunneltest.UsableAddrs()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients := []string{"tcp"}
|
||||
for i, addr := range addrs {
|
||||
if addr.IP.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
client := fmt.Sprintf("tcp_%d", i)
|
||||
|
||||
tunnels[client] = &tunneltest.Tunnel{
|
||||
Type: tunneltest.TypeTCP,
|
||||
ClientIdent: "http",
|
||||
LocalAddr: "127.0.0.1:0",
|
||||
RemoteAddrIdent: "tcp_all",
|
||||
IP: addr.IP,
|
||||
Handler: handlerEchoTCP,
|
||||
}
|
||||
|
||||
clients = append(clients, client)
|
||||
}
|
||||
|
||||
tt, err := tunneltest.Serve(tunnels)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tt.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100/len(clients); i++ {
|
||||
wg.Add(len(clients))
|
||||
|
||||
for j, ident := range clients {
|
||||
go func(ident string, i, j int) {
|
||||
defer wg.Done()
|
||||
msg := fmt.Sprintf("hello_%d_client_%d", j, i)
|
||||
res, err := echoTCPIdent(tt, msg, ident)
|
||||
if err != nil {
|
||||
t.Errorf("echoTCP: %s", err)
|
||||
}
|
||||
|
||||
if res != msg {
|
||||
t.Errorf("got %q, want %q", res, msg)
|
||||
}
|
||||
}(ident, i, j)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package tunneltest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
)
|
||||
|
||||
var (
|
||||
recWaitTimeout = 5 * time.Second
|
||||
recBuffer = 32
|
||||
)
|
||||
|
||||
// States is a sequence of client state changes.
|
||||
type States []*tunnel.ClientStateChange
|
||||
|
||||
func (s States) String() string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
fmt.Fprintf(&buf, "[%s", s[0].String())
|
||||
|
||||
for _, s := range s[1:] {
|
||||
fmt.Fprintf(&buf, ",%s", s.String())
|
||||
}
|
||||
|
||||
buf.WriteRune(']')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// StateRecorder saves state changes pushed to StateRecorder.C().
|
||||
type StateRecorder struct {
|
||||
mu sync.Mutex
|
||||
ch chan *tunnel.ClientStateChange
|
||||
recorded []*tunnel.ClientStateChange
|
||||
offset int
|
||||
}
|
||||
|
||||
func NewStateRecorder() *StateRecorder {
|
||||
rec := &StateRecorder{
|
||||
ch: make(chan *tunnel.ClientStateChange, recBuffer),
|
||||
}
|
||||
|
||||
go rec.record()
|
||||
|
||||
return rec
|
||||
}
|
||||
|
||||
func (rec *StateRecorder) record() {
|
||||
for state := range rec.ch {
|
||||
rec.mu.Lock()
|
||||
rec.recorded = append(rec.recorded, state)
|
||||
rec.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange {
|
||||
return rec.ch
|
||||
}
|
||||
|
||||
func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error {
|
||||
from := states[0]
|
||||
for _, to := range states[1:] {
|
||||
if err := rec.WaitTransition(from, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
from = to
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error {
|
||||
timeout := time.After(recWaitTimeout)
|
||||
|
||||
var lastStates []*tunnel.ClientStateChange
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates))
|
||||
default:
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
lastStates = rec.States()[rec.offset:]
|
||||
|
||||
for i, state := range lastStates {
|
||||
if from != 0 && state.Previous != from {
|
||||
continue
|
||||
}
|
||||
|
||||
if to != 0 && state.Current != to {
|
||||
continue
|
||||
}
|
||||
|
||||
rec.offset += i
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *StateRecorder) States() []*tunnel.ClientStateChange {
|
||||
rec.mu.Lock()
|
||||
defer rec.mu.Unlock()
|
||||
|
||||
states := make([]*tunnel.ClientStateChange, len(rec.recorded))
|
||||
copy(states, rec.recorded)
|
||||
return states
|
||||
}
|
|
@ -0,0 +1,561 @@
|
|||
package tunneltest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
)
|
||||
|
||||
var debugNet = os.Getenv("DEBUGNET") == "1"
|
||||
|
||||
type dbgListener struct {
|
||||
net.Listener
|
||||
}
|
||||
|
||||
func (l dbgListener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbgConn{conn}, nil
|
||||
}
|
||||
|
||||
type dbgConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c dbgConn) Read(p []byte) (int, error) {
|
||||
n, err := c.Conn.Read(p)
|
||||
os.Stderr.Write(p)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c dbgConn) Write(p []byte) (int, error) {
|
||||
n, err := c.Conn.Write(p)
|
||||
os.Stderr.Write(p)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func logf(format string, args ...interface{}) {
|
||||
if testing.Verbose() {
|
||||
log.Printf("[tunneltest] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func nonil(err ...error) error {
|
||||
for _, e := range err {
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseHostPort(addr string) (string, int, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
n, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return host, int(n), nil
|
||||
}
|
||||
|
||||
// UsableAddrs returns all tcp addresses that we can bind a listener to.
|
||||
func UsableAddrs() ([]*net.TCPAddr, error) {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var usable []*net.TCPAddr
|
||||
for _, addr := range addrs {
|
||||
if ipNet, ok := addr.(*net.IPNet); ok {
|
||||
if !ipNet.IP.IsLinkLocalUnicast() {
|
||||
usable = append(usable, &net.TCPAddr{IP: ipNet.IP})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(usable) == 0 {
|
||||
return nil, errors.New("no usable addresses found")
|
||||
}
|
||||
|
||||
return usable, nil
|
||||
}
|
||||
|
||||
const (
|
||||
TypeHTTP = iota
|
||||
TypeTCP
|
||||
)
|
||||
|
||||
// Tunnel represents a single HTTP or TCP tunnel that can be served
|
||||
// by TunnelTest.
|
||||
type Tunnel struct {
|
||||
// Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP.
|
||||
Type int
|
||||
|
||||
// Handler is a handler to use for serving tunneled connections on
|
||||
// local server. The value of this field is required to be of type:
|
||||
//
|
||||
// - http.Handler or http.HandlerFunc for HTTP tunnels
|
||||
// - func(net.Conn) for TCP tunnels
|
||||
//
|
||||
// Required field.
|
||||
Handler interface{}
|
||||
|
||||
// LocalAddr is a network address of local server that handles
|
||||
// connections/requests with Handler.
|
||||
//
|
||||
// Optional field, takes value of "127.0.0.1:0" when empty.
|
||||
LocalAddr string
|
||||
|
||||
// ClientIdent is an identifier of a client that have already
|
||||
// registered a HTTP tunnel and have established control connection.
|
||||
//
|
||||
// If the Type is TypeTCP, instead of creating new client
|
||||
// for this TCP tunnel, we add it to an existing client
|
||||
// specified by the field.
|
||||
//
|
||||
// Optional field for TCP tunnels.
|
||||
// Ignored field for HTTP tunnels.
|
||||
ClientIdent string
|
||||
|
||||
// RemoteAddr is a network address of remote server, which accepts
|
||||
// connections on a tunnel server side.
|
||||
//
|
||||
// Required field for TCP tunnels.
|
||||
// Ignored field for HTTP tunnels.
|
||||
RemoteAddr string
|
||||
|
||||
// RemoteAddrIdent an identifier of an already existing listener,
|
||||
// that listens on multiple interfaces; if the RemoteAddrIdent is valid
|
||||
// identifier the IP field is required to be non-nil and RemoteAddr
|
||||
// is ignored.
|
||||
//
|
||||
// Optional field for TCP tunnels.
|
||||
// Ignored field for HTTP tunnels.
|
||||
RemoteAddrIdent string
|
||||
|
||||
// IP specifies an IP address value for IP-based routing for TCP tunnels.
|
||||
// For more details see inline documentation for (*tunnel.Server).AddAddr.
|
||||
//
|
||||
// Optional field for TCP tunnels.
|
||||
// Ignored field for HTTP tunnels.
|
||||
IP net.IP
|
||||
|
||||
// StateChanges listens on state transitions.
|
||||
//
|
||||
// If ClientIdent field is empty, the StateChanges will receive
|
||||
// state transition events for the newly created client.
|
||||
// Otherwise setting this field is a nop.
|
||||
StateChanges chan<- *tunnel.ClientStateChange
|
||||
}
|
||||
|
||||
type TunnelTest struct {
|
||||
Server *tunnel.Server
|
||||
ServerStateRecorder *StateRecorder
|
||||
Clients map[string]*tunnel.Client
|
||||
Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels)
|
||||
Addrs []*net.TCPAddr
|
||||
Tunnels map[string]*Tunnel
|
||||
DebugNet bool // for debugging network communication
|
||||
|
||||
mu sync.Mutex // protects Listeners
|
||||
}
|
||||
|
||||
func NewTunnelTest() (*TunnelTest, error) {
|
||||
rec := NewStateRecorder()
|
||||
|
||||
cfg := &tunnel.ServerConfig{
|
||||
StateChanges: rec.C(),
|
||||
Debug: testing.Verbose(),
|
||||
}
|
||||
s, err := tunnel.NewServer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if debugNet {
|
||||
l = dbgListener{l}
|
||||
}
|
||||
|
||||
addrs, err := UsableAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go (&http.Server{Handler: s}).Serve(l)
|
||||
|
||||
return &TunnelTest{
|
||||
Server: s,
|
||||
ServerStateRecorder: rec,
|
||||
Clients: make(map[string]*tunnel.Client),
|
||||
Listeners: map[string][2]net.Listener{"": {l, nil}},
|
||||
Addrs: addrs,
|
||||
Tunnels: make(map[string]*Tunnel),
|
||||
DebugNet: debugNet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve creates new TunnelTest that serves the given tunnels.
|
||||
//
|
||||
// If tunnels is nil, DefaultTunnels() are used instead.
|
||||
func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) {
|
||||
tt, err := NewTunnelTest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tt.Serve(tunnels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tt, nil
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) {
|
||||
// Verify tunnel dependencies for TCP tunnels.
|
||||
if t.Type == TypeTCP {
|
||||
// If tunnel specified by t.Client was not already started,
|
||||
// skip and move on.
|
||||
if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Verify the TCP tunnel whose remote endpoint listens on multiple
|
||||
// interfaces is already served.
|
||||
if t.RemoteAddrIdent != "" {
|
||||
if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP {
|
||||
return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", t.LocalAddr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err)
|
||||
}
|
||||
|
||||
if tt.DebugNet {
|
||||
l = dbgListener{l}
|
||||
}
|
||||
|
||||
localAddr := l.Addr().String()
|
||||
httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr}
|
||||
tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr}
|
||||
|
||||
cfg := &tunnel.ClientConfig{
|
||||
Identifier: ident,
|
||||
ServerAddr: tt.ServerAddr().String(),
|
||||
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
|
||||
HTTP: httpProxy.Proxy,
|
||||
TCP: tcpProxy.Proxy,
|
||||
}),
|
||||
StateChanges: t.StateChanges,
|
||||
Debug: testing.Verbose(),
|
||||
}
|
||||
|
||||
// Register tunnel:
|
||||
//
|
||||
// - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient])
|
||||
// - listen on local address and start local server (tt.Listeners[ident][0])
|
||||
// - register tunnel on tunnel.Server
|
||||
//
|
||||
switch t.Type {
|
||||
case TypeHTTP:
|
||||
// TODO(rjeczalik): refactor to separate method
|
||||
|
||||
h, ok := t.Handler.(http.Handler)
|
||||
if !ok {
|
||||
h, ok = t.Handler.(http.HandlerFunc)
|
||||
if !ok {
|
||||
fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request))
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
||||
}
|
||||
|
||||
h = http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident)
|
||||
|
||||
go (&http.Server{Handler: h}).Serve(l)
|
||||
|
||||
tt.Server.AddHost(localAddr, ident)
|
||||
|
||||
tt.mu.Lock()
|
||||
tt.Listeners[ident] = [2]net.Listener{l, nil}
|
||||
tt.mu.Unlock()
|
||||
|
||||
if err := tt.addClient(ident, cfg); err != nil {
|
||||
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
||||
}
|
||||
|
||||
logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident)
|
||||
|
||||
case TypeTCP:
|
||||
// TODO(rjeczalik): refactor to separate method
|
||||
|
||||
h, ok := t.Handler.(func(net.Conn))
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
||||
}
|
||||
|
||||
logf("serving on local %s for TCP tunnel %q", l.Addr(), ident)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed accepting conn for %q tunnel: %s", ident, err)
|
||||
return
|
||||
}
|
||||
|
||||
go h(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
var remote net.Listener
|
||||
|
||||
if t.RemoteAddrIdent != "" {
|
||||
tt.mu.Lock()
|
||||
remote = tt.Listeners[t.RemoteAddrIdent][1]
|
||||
tt.mu.Unlock()
|
||||
} else {
|
||||
remote, err = net.Listen("tcp", t.RemoteAddr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err)
|
||||
}
|
||||
}
|
||||
|
||||
// addrIdent holds identifier of client which is going to have registered
|
||||
// tunnel via (*tunnel.Server).AddAddr
|
||||
addrIdent := ident
|
||||
if t.ClientIdent != "" {
|
||||
tt.Clients[ident] = tt.Clients[t.ClientIdent]
|
||||
addrIdent = t.ClientIdent
|
||||
}
|
||||
|
||||
tt.Server.AddAddr(remote, t.IP, addrIdent)
|
||||
|
||||
tt.mu.Lock()
|
||||
tt.Listeners[ident] = [2]net.Listener{l, remote}
|
||||
tt.mu.Unlock()
|
||||
|
||||
if _, ok := tt.Clients[ident]; !ok {
|
||||
if err := tt.addClient(ident, cfg); err != nil {
|
||||
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
||||
}
|
||||
}
|
||||
|
||||
logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent)
|
||||
|
||||
default:
|
||||
return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error {
|
||||
if _, ok := tt.Clients[ident]; ok {
|
||||
return fmt.Errorf("tunnel %q is already being served", ident)
|
||||
}
|
||||
|
||||
c, err := tunnel.NewClient(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
tt.Server.OnConnect(ident, func() error {
|
||||
close(done)
|
||||
return nil
|
||||
})
|
||||
|
||||
go c.Start()
|
||||
<-c.StartNotify()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("timed out after 10s waiting on client to establish control conn")
|
||||
case <-done:
|
||||
}
|
||||
|
||||
tt.Clients[ident] = c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error {
|
||||
if len(tunnels) == 0 {
|
||||
return errors.New("no tunnels to serve")
|
||||
}
|
||||
|
||||
// Since one tunnels depends on others do 3 passes to start them
|
||||
// all, each started tunnel is removed from the tunnels map.
|
||||
// After 3 passes all of them must be started, otherwise the
|
||||
// configuration is bad:
|
||||
//
|
||||
// - first pass starts HTTP tunnels as new client tunnels
|
||||
// - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent)
|
||||
// - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent)
|
||||
//
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := tt.popServedDeps(tunnels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(tunnels) != 0 {
|
||||
unresolved := make([]string, len(tunnels))
|
||||
for ident := range tunnels {
|
||||
unresolved = append(unresolved, ident)
|
||||
}
|
||||
sort.Strings(unresolved)
|
||||
|
||||
return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error {
|
||||
for ident, t := range tunnels {
|
||||
ok, err := tt.serveSingle(ident, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
// Remove already started tunnels so they won't get started again.
|
||||
delete(tunnels, ident)
|
||||
tt.Tunnels[ident] = t
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
for _, l := range tt.Listeners {
|
||||
if l[1] == nil {
|
||||
// this listener does not belong to a TCP tunnel
|
||||
continue
|
||||
}
|
||||
|
||||
_, remotePort, err := parseHostPort(l[1].Addr().String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if port == remotePort {
|
||||
return l[0].Addr().String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no route for %d port", port)
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) ServerAddr() net.Addr {
|
||||
return tt.Listeners[""][0].Addr()
|
||||
}
|
||||
|
||||
// Addr gives server endpoint of the TCP tunnel for the given ident.
|
||||
//
|
||||
// If the tunnel does not exist or is a HTTP one, TunnelAddr return nil.
|
||||
func (tt *TunnelTest) Addr(ident string) net.Addr {
|
||||
l, ok := tt.Listeners[ident]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return l[1].Addr()
|
||||
}
|
||||
|
||||
// Request creates a HTTP request to a server endpoint of the HTTP tunnel
|
||||
// for the given ident.
|
||||
//
|
||||
// If the tunnel does not exist, Request returns nil.
|
||||
func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request {
|
||||
l, ok := tt.Listeners[ident]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw string
|
||||
if query != nil {
|
||||
raw = query.Encode()
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: tt.ServerAddr().String(),
|
||||
Path: "/",
|
||||
RawQuery: raw,
|
||||
},
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Host: l[0].Addr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *TunnelTest) Close() (err error) {
|
||||
// Close tunnel.Clients.
|
||||
clients := make(map[*tunnel.Client]struct{})
|
||||
for _, c := range tt.Clients {
|
||||
clients[c] = struct{}{}
|
||||
}
|
||||
for c := range clients {
|
||||
err = nonil(err, c.Close())
|
||||
}
|
||||
|
||||
// Stop all TCP/HTTP servers.
|
||||
listeners := make(map[net.Listener]struct{})
|
||||
for _, l := range tt.Listeners {
|
||||
for _, l := range l {
|
||||
if l != nil {
|
||||
listeners[l] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
for l := range listeners {
|
||||
err = nonil(err, l.Close())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
)
|
||||
|
||||
// async is a helper function to convert a blocking function to a function
|
||||
// returning an error. Useful for plugging function closures into select and co
|
||||
func async(fn func() error) <-chan error {
|
||||
errChan := make(chan error, 0)
|
||||
go func() {
|
||||
select {
|
||||
case errChan <- fn():
|
||||
default:
|
||||
}
|
||||
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
return errChan
|
||||
}
|
||||
|
||||
type expBackoff struct {
|
||||
mu sync.Mutex
|
||||
bk *backoff.ExponentialBackOff
|
||||
}
|
||||
|
||||
func newForeverBackoff() *expBackoff {
|
||||
eb := &expBackoff{
|
||||
bk: backoff.NewExponentialBackOff(),
|
||||
}
|
||||
eb.bk.MaxElapsedTime = 0 // never stops
|
||||
return eb
|
||||
}
|
||||
|
||||
func (eb *expBackoff) NextBackOff() time.Duration {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
|
||||
return eb.bk.NextBackOff()
|
||||
}
|
||||
|
||||
func (eb *expBackoff) Reset() {
|
||||
eb.mu.Lock()
|
||||
eb.bk.Reset()
|
||||
eb.mu.Unlock()
|
||||
}
|
||||
|
||||
type callbacks struct {
|
||||
mu sync.Mutex
|
||||
name string
|
||||
funcs map[string]func() error
|
||||
}
|
||||
|
||||
func newCallbacks(name string) *callbacks {
|
||||
return &callbacks{
|
||||
name: name,
|
||||
funcs: make(map[string]func() error),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *callbacks) add(identifier string, fn func() error) {
|
||||
c.mu.Lock()
|
||||
c.funcs[identifier] = fn
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *callbacks) pop(identifier string) (func() error, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
fn, ok := c.funcs[identifier]
|
||||
if !ok {
|
||||
return nil, nil // nop
|
||||
}
|
||||
|
||||
delete(c.funcs, identifier)
|
||||
|
||||
if fn == nil {
|
||||
return nil, fmt.Errorf("nil callback set for %q client", identifier)
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func (c *callbacks) call(identifier string) error {
|
||||
fn, err := c.pop(identifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fn == nil {
|
||||
return nil // nop
|
||||
}
|
||||
|
||||
return fn()
|
||||
}
|
||||
|
||||
// Returns server control url as a string. Reads scheme and remote address from connection.
|
||||
func controlURL(conn net.Conn) string {
|
||||
return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath)
|
||||
}
|
||||
|
||||
func scheme(conn net.Conn) (scheme string) {
|
||||
switch conn.(type) {
|
||||
case *tls.Conn:
|
||||
scheme = "https"
|
||||
default:
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,179 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
net.Listener
|
||||
*vaddrOptions
|
||||
|
||||
done int32
|
||||
|
||||
// ips keeps track of registered clients for ip-based routing;
|
||||
// when last client is deleted from the ip routing map, we stop
|
||||
// listening on connections
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
type vaddrOptions struct {
|
||||
connCh chan<- net.Conn
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
type vaddrStorage struct {
|
||||
*vaddrOptions
|
||||
|
||||
listeners map[net.Listener]*listener
|
||||
ports map[int]string // port-based routing: maps port number to identifier
|
||||
ips map[string]string // ip-based routing: maps ip address to identifier
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage {
|
||||
return &vaddrStorage{
|
||||
vaddrOptions: opts,
|
||||
listeners: make(map[net.Listener]*listener),
|
||||
ports: make(map[int]string),
|
||||
ips: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) serve() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
l.log.Error("failue listening on %q: %s", l.Addr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&l.done) != 0 {
|
||||
l.log.Debug("stopped serving %q", l.Addr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
l.connCh <- conn
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) localAddr() string {
|
||||
if addr, ok := l.Addr().(*net.TCPAddr); ok {
|
||||
if addr.IP.Equal(net.IPv4zero) {
|
||||
return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port))
|
||||
}
|
||||
}
|
||||
return l.Addr().String()
|
||||
}
|
||||
|
||||
func (l *listener) stop() {
|
||||
if atomic.CompareAndSwapInt32(&l.done, 0, 1) {
|
||||
// stop is called when no more connections should be accepted by
|
||||
// the user-provided listener; as we can't simple close the listener
|
||||
// to not break the guarantee given by the (*Server).DeleteAddr
|
||||
// method, we make a dummy connection to break out of serve loop.
|
||||
// It is safe to make a dummy connection, as either the following
|
||||
// dial will time out when the listener is busy accepting connections,
|
||||
// or will get closed immadiately after idle listeners accepts connection
|
||||
// and returns from the serve loop.
|
||||
conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) {
|
||||
vaddr.mu.Lock()
|
||||
defer vaddr.mu.Unlock()
|
||||
|
||||
lis, ok := vaddr.listeners[l]
|
||||
if !ok {
|
||||
lis = vaddr.newListener(l)
|
||||
vaddr.listeners[l] = lis
|
||||
go lis.serve()
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
lis.ips[ip.String()] = struct{}{}
|
||||
vaddr.ips[ip.String()] = ident
|
||||
} else {
|
||||
vaddr.ports[mustPort(l)] = ident
|
||||
}
|
||||
}
|
||||
|
||||
func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) {
|
||||
vaddr.mu.Lock()
|
||||
defer vaddr.mu.Unlock()
|
||||
|
||||
lis, ok := vaddr.listeners[l]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var stop bool
|
||||
|
||||
if ip != nil {
|
||||
delete(lis.ips, ip.String())
|
||||
delete(vaddr.ips, ip.String())
|
||||
|
||||
stop = len(lis.ips) == 0
|
||||
} else {
|
||||
delete(vaddr.ports, mustPort(l))
|
||||
|
||||
stop = true
|
||||
}
|
||||
|
||||
// Only stop listening for connections when listener has clients
|
||||
// registered to tunnel the connections to.
|
||||
if stop {
|
||||
lis.stop()
|
||||
delete(vaddr.listeners, l)
|
||||
}
|
||||
}
|
||||
|
||||
func (vaddr *vaddrStorage) newListener(l net.Listener) *listener {
|
||||
return &listener{
|
||||
Listener: l,
|
||||
vaddrOptions: vaddr.vaddrOptions,
|
||||
ips: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) {
|
||||
vaddr.mu.Lock()
|
||||
defer vaddr.mu.Unlock()
|
||||
|
||||
ip, port, err := parseHostPort(conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err)
|
||||
return "", false
|
||||
}
|
||||
|
||||
// First lookup if there's a ip-based route, then try port-base one.
|
||||
|
||||
if ident, ok := vaddr.ips[ip]; ok {
|
||||
return ident, true
|
||||
}
|
||||
|
||||
ident, ok := vaddr.ports[port]
|
||||
return ident, ok
|
||||
}
|
||||
|
||||
func mustPort(l net.Listener) int {
|
||||
_, port, err := parseHostPort(l.Addr().String())
|
||||
if err != nil {
|
||||
// This can happened when user passed custom type that
|
||||
// implements net.Listener, which returns ill-formed
|
||||
// net.Addr value.
|
||||
panic("ill-formed net.Addr: " + err.Error())
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type vhostStorage interface {
|
||||
// AddHost adds the given host and identifier to the storage
|
||||
AddHost(host, identifier string)
|
||||
|
||||
// DeleteHost deletes the given host
|
||||
DeleteHost(host string)
|
||||
|
||||
// GetHost returns the host name for the given identifier
|
||||
GetHost(identifier string) (string, bool)
|
||||
|
||||
// GetIdentifier returns the identifier for the given host
|
||||
GetIdentifier(host string) (string, bool)
|
||||
}
|
||||
|
||||
type virtualHost struct {
|
||||
identifier string
|
||||
}
|
||||
|
||||
// virtualHosts is used for mapping host to users example: host
|
||||
// "fs-1-fatih.kd.io" belongs to user "arslan"
|
||||
type virtualHosts struct {
|
||||
mapping map[string]*virtualHost
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// newVirtualHosts provides an in memory virtual host storage for mapping
|
||||
// virtual hosts to identifiers.
|
||||
func newVirtualHosts() *virtualHosts {
|
||||
return &virtualHosts{
|
||||
mapping: make(map[string]*virtualHost),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *virtualHosts) AddHost(host, identifier string) {
|
||||
v.Lock()
|
||||
v.mapping[host] = &virtualHost{identifier: identifier}
|
||||
v.Unlock()
|
||||
}
|
||||
|
||||
func (v *virtualHosts) DeleteHost(host string) {
|
||||
v.Lock()
|
||||
delete(v.mapping, host)
|
||||
v.Unlock()
|
||||
}
|
||||
|
||||
// GetIdentifier returns the identifier associated with the given host
|
||||
func (v *virtualHosts) GetIdentifier(host string) (string, bool) {
|
||||
v.Lock()
|
||||
ht, ok := v.mapping[host]
|
||||
v.Unlock()
|
||||
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return ht.identifier, true
|
||||
}
|
||||
|
||||
// GetHost returns the host associated with the given identifier
|
||||
func (v *virtualHosts) GetHost(identifier string) (string, bool) {
|
||||
v.Lock()
|
||||
defer v.Unlock()
|
||||
|
||||
for hostname, hst := range v.mapping {
|
||||
if hst.identifier == identifier {
|
||||
return hostname, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
package tunnel_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||
)
|
||||
|
||||
func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) {
|
||||
conn, err := websocketDial(tt, "http")
|
||||
if err != nil {
|
||||
t.Fatalf("Dial()=%s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
want := &EchoMessage{
|
||||
Value: fmt.Sprintf("message #%d", i),
|
||||
Close: i == (n - 1),
|
||||
}
|
||||
|
||||
err := conn.WriteJSON(want)
|
||||
if err != nil {
|
||||
t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err)
|
||||
continue
|
||||
}
|
||||
|
||||
got := &EchoMessage{}
|
||||
|
||||
err = conn.ReadJSON(got)
|
||||
if err != nil {
|
||||
t.Errorf("(test %s) %d: failed reading: %s", name, i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := fn(w, r); err != nil {
|
||||
t.Errorf("handler func error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocket(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil))))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testWebsocket("handlerEchoWS", 100, t, tt)
|
||||
}
|
||||
|
||||
func TestLatencyWebsocket(t *testing.T) {
|
||||
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep))))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testWebsocket("handlerLatencyEchoWS", 20, t, tt)
|
||||
}
|
Loading…
Reference in New Issue