import github.com/koding/tunnel
This commit is contained in:
parent
551017e893
commit
86be40fea0
|
@ -0,0 +1,19 @@
|
||||||
|
language: go
|
||||||
|
|
||||||
|
sudo: false
|
||||||
|
|
||||||
|
addons:
|
||||||
|
apt:
|
||||||
|
packages:
|
||||||
|
- moreutils
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.4.3
|
||||||
|
- 1.6.3
|
||||||
|
- 1.7
|
||||||
|
|
||||||
|
script:
|
||||||
|
- export GOMAXPROCS=$(nproc)
|
||||||
|
- gofmt -s -l . | ifne false
|
||||||
|
- go build ./...
|
||||||
|
- go test -race ./...
|
|
@ -0,0 +1,28 @@
|
||||||
|
Copyright (c) 2015 The Koding Authors.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
* Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
* Neither the name of Koding Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel)
|
||||||
|
|
||||||
|
Tunnel is a server/client package that enables to proxy public connections to
|
||||||
|
your local machine over a tunnel connection from the local machine to the
|
||||||
|
public server. What this means is, you can share your localhost even if it
|
||||||
|
doesn't have a Public IP or if it's not reachable from outside.
|
||||||
|
|
||||||
|
It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to
|
||||||
|
multiplex connections between server and client.
|
||||||
|
|
||||||
|
The project is under active development, please vendor it if you want to use it.
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
|
||||||
|
The tunnel package consists of two parts. The `server` and the `client`.
|
||||||
|
|
||||||
|
Server is the public facing part. It's type that satisfies the `http.Handler`.
|
||||||
|
So it's easily pluggable into existing servers.
|
||||||
|
|
||||||
|
|
||||||
|
Let assume that you setup your DNS service so all `*.example.com` domains route
|
||||||
|
to your server at the public IP `203.0.113.0`. Let us first create the server
|
||||||
|
part:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/koding/tunnel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &tunnel.ServerConfig{}
|
||||||
|
server, _ := tunnel.NewServer(cfg)
|
||||||
|
server.AddHost("sub.example.com", "1234")
|
||||||
|
http.ListenAndServe(":80", server)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Once you create the `server`, you just plug it into your server. The only
|
||||||
|
detail here is to map a virtualhost to a secret token. The secret token is the
|
||||||
|
only part that needs to be known for the client side.
|
||||||
|
|
||||||
|
Let us now create the client side part:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/koding/tunnel"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &tunnel.ClientConfig{
|
||||||
|
Identifier: "1234",
|
||||||
|
ServerAddr: "203.0.113.0:80",
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := tunnel.NewClient(cfg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `Start()` method is by default blocking. As you see you, we just passed the
|
||||||
|
server address and the secret token.
|
||||||
|
|
||||||
|
Now whenever someone hit `sub.example.com`, the request will be proxied to the
|
||||||
|
machine where client is running and hit the local server running `127.0.0.1:80`
|
||||||
|
(assuming there is one). If someone hits `sub.example.com:3000` (assume your
|
||||||
|
server is running at this port), it'll be routed to `127.0.0.1:3000`
|
||||||
|
|
||||||
|
That's it.
|
||||||
|
|
||||||
|
There are many options that can be changed, such as a static local address for
|
||||||
|
your client. Have alook at the
|
||||||
|
[documentation](http://godoc.org/github.com/koding/tunnel)
|
||||||
|
|
||||||
|
|
||||||
|
# Protocol
|
||||||
|
|
||||||
|
The server/client protocol is written in the [spec.md](spec.md) file. Please
|
||||||
|
have a look for more detail.
|
||||||
|
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The BSD 3-Clause License - see LICENSE for more details
|
|
@ -0,0 +1,565 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate stringer -type ClientState
|
||||||
|
|
||||||
|
// ErrRedialAborted is emitted on ClientClosed event, when backoff policy
|
||||||
|
// used by a client decided no more reconnection attempts must be made.
|
||||||
|
var ErrRedialAborted = errors.New("unable to restore the connection, aborting")
|
||||||
|
|
||||||
|
// ClientState represents client connection state to tunnel server.
|
||||||
|
type ClientState uint32
|
||||||
|
|
||||||
|
// ClientState enumeration.
|
||||||
|
const (
|
||||||
|
ClientUnknown ClientState = iota
|
||||||
|
ClientStarted
|
||||||
|
ClientConnecting
|
||||||
|
ClientConnected
|
||||||
|
ClientDisconnected
|
||||||
|
ClientClosed // keep it always last
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientStateChange represents single client state transition.
|
||||||
|
type ClientStateChange struct {
|
||||||
|
Identifier string
|
||||||
|
Previous ClientState
|
||||||
|
Current ClientState
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strings implements the fmt.Stringer interface.
|
||||||
|
func (cs *ClientStateChange) String() string {
|
||||||
|
if cs.Error != nil {
|
||||||
|
return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backoff defines behavior of staggering reconnection retries.
|
||||||
|
type Backoff interface {
|
||||||
|
// Next returns the duration to sleep before retrying reconnections.
|
||||||
|
// If the returned value is negative, the retry is aborted.
|
||||||
|
NextBackOff() time.Duration
|
||||||
|
|
||||||
|
// Reset is used to signal a reconnection was successful and next
|
||||||
|
// call to Next should return desired time duration for 1st reconnection
|
||||||
|
// attempt.
|
||||||
|
Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client is responsible for creating a control connection to a tunnel server,
|
||||||
|
// creating new tunnels and proxy them to tunnel server.
|
||||||
|
type Client struct {
|
||||||
|
// underlying yamux session
|
||||||
|
session *yamux.Session
|
||||||
|
|
||||||
|
// config holds the ClientConfig
|
||||||
|
config *ClientConfig
|
||||||
|
|
||||||
|
// yamuxConfig is passed to new yamux.Session's
|
||||||
|
yamuxConfig *yamux.Config
|
||||||
|
|
||||||
|
// proxy handles local server communication.
|
||||||
|
proxy ProxyFunc
|
||||||
|
|
||||||
|
// startNotify is a chanel user can get to be notified when client is
|
||||||
|
// connected to the server. The preferred way of doing this however,
|
||||||
|
// would be using StateChanges in ClientConfig where user can provide
|
||||||
|
// his own channel.
|
||||||
|
startNotify chan bool
|
||||||
|
// closed is a flag set when client calls Close() and quits.
|
||||||
|
closed bool
|
||||||
|
// closedMu guards both closed flag and startNotify channel. Since library
|
||||||
|
// owns the channel it's cleared when trying to reconnect.
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
|
||||||
|
reqWg sync.WaitGroup
|
||||||
|
ctrlWg sync.WaitGroup
|
||||||
|
|
||||||
|
state ClientState
|
||||||
|
|
||||||
|
// redialBackoff is used to reconnect in exponential backoff intervals
|
||||||
|
redialBackoff Backoff
|
||||||
|
|
||||||
|
log logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConfig defines the configuration for the Client
|
||||||
|
type ClientConfig struct {
|
||||||
|
// Identifier is the secret token that needs to be passed to the server.
|
||||||
|
// Required if FetchIdentifier is not set.
|
||||||
|
Identifier string
|
||||||
|
|
||||||
|
// FetchIdentifier can be used to fetch identifier. Required if Identifier
|
||||||
|
// is not set.
|
||||||
|
FetchIdentifier func() (string, error)
|
||||||
|
|
||||||
|
// ServerAddr defines the TCP address of the tunnel server to be connected.
|
||||||
|
// Required if FetchServerAddr is not set.
|
||||||
|
ServerAddr string
|
||||||
|
|
||||||
|
// FetchServerAddr can be used to fetch tunnel server address.
|
||||||
|
// Required if ServerAddress is not set.
|
||||||
|
FetchServerAddr func() (string, error)
|
||||||
|
|
||||||
|
// Dial provides custom transport layer for client server communication.
|
||||||
|
//
|
||||||
|
// If nil, default implementation is to return net.Dial("tcp", address).
|
||||||
|
//
|
||||||
|
// It can be used for connection monitoring, setting different timeouts or
|
||||||
|
// securing the connection.
|
||||||
|
Dial func(network, address string) (net.Conn, error)
|
||||||
|
|
||||||
|
// Proxy defines custom proxing logic. This is optional extension point
|
||||||
|
// where you can provide your local server selection or communication rules.
|
||||||
|
Proxy ProxyFunc
|
||||||
|
|
||||||
|
// StateChanges receives state transition details each time client
|
||||||
|
// connection state changes. The channel is expected to be sufficiently
|
||||||
|
// buffered to keep up with event pace.
|
||||||
|
//
|
||||||
|
// If nil, no information about state transitions are dispatched
|
||||||
|
// by the library.
|
||||||
|
StateChanges chan<- *ClientStateChange
|
||||||
|
|
||||||
|
// Backoff is used to control behavior of staggering reconnection loop.
|
||||||
|
//
|
||||||
|
// If nil, default backoff policy is used which makes a client to never
|
||||||
|
// give up on reconnection.
|
||||||
|
//
|
||||||
|
// If custom backoff is used, client will emit ErrRedialAborted set
|
||||||
|
// with ClientClosed event when no more reconnection atttemps should
|
||||||
|
// be made.
|
||||||
|
Backoff Backoff
|
||||||
|
|
||||||
|
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
||||||
|
// yamux.DefaultConfig() is used.
|
||||||
|
YamuxConfig *yamux.Config
|
||||||
|
|
||||||
|
// Log defines the logger. If nil a default logging.Logger is used.
|
||||||
|
Log logging.Logger
|
||||||
|
|
||||||
|
// Debug enables debug mode, enable only if you want to debug the server.
|
||||||
|
Debug bool
|
||||||
|
|
||||||
|
// DEPRECATED:
|
||||||
|
|
||||||
|
// LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details.
|
||||||
|
LocalAddr string
|
||||||
|
|
||||||
|
// FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details.
|
||||||
|
FetchLocalAddr func(port int) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify is used to verify the ClientConfig
|
||||||
|
func (c *ClientConfig) verify() error {
|
||||||
|
if c.ServerAddr == "" && c.FetchServerAddr == nil {
|
||||||
|
return errors.New("neither ServerAddr nor FetchServerAddr is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Identifier == "" && c.FetchIdentifier == nil {
|
||||||
|
return errors.New("neither Identifier nor FetchIdentifier is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.YamuxConfig != nil {
|
||||||
|
if err := yamux.VerifyConfig(c.YamuxConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) {
|
||||||
|
return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new tunnel that is established between the serverAddr
|
||||||
|
// and localAddr. It exits if it can't create a new control connection to the
|
||||||
|
// server. If localAddr is empty client will always try to proxy to a local
|
||||||
|
// port.
|
||||||
|
func NewClient(cfg *ClientConfig) (*Client, error) {
|
||||||
|
if err := cfg.verify(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
yamuxConfig := yamux.DefaultConfig()
|
||||||
|
if cfg.YamuxConfig != nil {
|
||||||
|
yamuxConfig = cfg.YamuxConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxy = DefaultProxy
|
||||||
|
if cfg.Proxy != nil {
|
||||||
|
proxy = cfg.Proxy
|
||||||
|
}
|
||||||
|
// DEPRECATED API SUPPORT
|
||||||
|
if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil {
|
||||||
|
var f ProxyFuncs
|
||||||
|
if cfg.LocalAddr != "" {
|
||||||
|
f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
||||||
|
f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
||||||
|
}
|
||||||
|
if cfg.FetchLocalAddr != nil {
|
||||||
|
f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy
|
||||||
|
}
|
||||||
|
proxy = Proxy(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
var bo Backoff = newForeverBackoff()
|
||||||
|
if cfg.Backoff != nil {
|
||||||
|
bo = cfg.Backoff
|
||||||
|
}
|
||||||
|
|
||||||
|
log := newLogger("tunnel-client", cfg.Debug)
|
||||||
|
if cfg.Log != nil {
|
||||||
|
log = cfg.Log
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
config: cfg,
|
||||||
|
yamuxConfig: yamuxConfig,
|
||||||
|
proxy: proxy,
|
||||||
|
startNotify: make(chan bool, 1),
|
||||||
|
redialBackoff: bo,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the client and connects to the server with the identifier.
|
||||||
|
// client.FetchIdentifier() will be used if it's not nil. It's supports
|
||||||
|
// reconnecting with exponential backoff intervals when the connection to the
|
||||||
|
// server disconnects. Call client.Close() to shutdown the client completely. A
|
||||||
|
// successful connection will cause StartNotify() to receive a value.
|
||||||
|
func (c *Client) Start() {
|
||||||
|
fetchIdent := func() (string, error) {
|
||||||
|
if c.config.FetchIdentifier != nil {
|
||||||
|
return c.config.FetchIdentifier()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.config.Identifier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fetchServerAddr := func() (string, error) {
|
||||||
|
if c.config.FetchServerAddr != nil {
|
||||||
|
return c.config.FetchServerAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.config.ServerAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.changeState(ClientStarted, nil)
|
||||||
|
|
||||||
|
c.redialBackoff.Reset()
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
prev := c.changeState(ClientConnecting, lastErr)
|
||||||
|
|
||||||
|
if c.isRetry(prev) {
|
||||||
|
dur := c.redialBackoff.NextBackOff()
|
||||||
|
if dur < 0 {
|
||||||
|
c.setClosed(true)
|
||||||
|
c.changeState(ClientClosed, ErrRedialAborted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(dur)
|
||||||
|
|
||||||
|
// exit if closed
|
||||||
|
if c.isClosed() {
|
||||||
|
c.changeState(ClientClosed, lastErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier, err := fetchIdent()
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
c.log.Critical("client fetch identifier error: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr, err := fetchServerAddr()
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
c.log.Critical("client fetch server address error: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.setClosed(false)
|
||||||
|
|
||||||
|
if err := c.connect(identifier, serverAddr); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
c.log.Debug("client connect error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exit if closed
|
||||||
|
if c.isClosed() {
|
||||||
|
c.changeState(ClientClosed, lastErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the client and shutdowns the connection to the tunnel server
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
defer c.setClosed(true)
|
||||||
|
|
||||||
|
if c.session == nil {
|
||||||
|
return errors.New("session is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until all connections are finished
|
||||||
|
waitCh := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
if err := c.session.GoAway(); err != nil {
|
||||||
|
c.log.Debug("Session go away failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.reqWg.Wait()
|
||||||
|
close(waitCh)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-waitCh:
|
||||||
|
// ok
|
||||||
|
case <-time.After(time.Second * 10):
|
||||||
|
c.log.Info("Timeout waiting for connections to finish")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.session.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClosed securely checks if client is marked as closed.
|
||||||
|
func (c *Client) isClosed() bool {
|
||||||
|
c.closedMu.RLock()
|
||||||
|
defer c.closedMu.RUnlock()
|
||||||
|
return c.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// setClosed securely marks client as closed (or not closed). If not closed
|
||||||
|
// also empty the value inside the startNotify channel by retrieving it (if any),
|
||||||
|
// so it doesn't block during connect, when the client was closed and started again,
|
||||||
|
// and startNotify was never listened to.
|
||||||
|
func (c *Client) setClosed(closed bool) {
|
||||||
|
c.closedMu.Lock()
|
||||||
|
defer c.closedMu.Unlock()
|
||||||
|
c.closed = closed
|
||||||
|
|
||||||
|
if !closed {
|
||||||
|
// clear channel
|
||||||
|
select {
|
||||||
|
case <-c.startNotify:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startNotifyIfNeeded sends ok to startNotify channel if it's listened to.
|
||||||
|
// This function is called by connect when connection was successful.
|
||||||
|
func (c *Client) startNotifyIfNeeded() {
|
||||||
|
c.closedMu.RLock()
|
||||||
|
if !c.closed {
|
||||||
|
c.log.Debug("sending ok to startNotify chan")
|
||||||
|
select {
|
||||||
|
case c.startNotify <- true:
|
||||||
|
default:
|
||||||
|
// reaching here means the client never read the signal via
|
||||||
|
// StartNotify(). This is OK, we shouldn't except it the consumer
|
||||||
|
// to read from this channel. It's optional, so we just drop the
|
||||||
|
// signal.
|
||||||
|
c.log.Debug("startNotify message was dropped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.closedMu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartNotify returns a channel that receives a single value when the client
|
||||||
|
// established a successful connection to the server.
|
||||||
|
func (c *Client) StartNotify() <-chan bool {
|
||||||
|
return c.startNotify
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) changeState(state ClientState, err error) (prev ClientState) {
|
||||||
|
prev = ClientState(atomic.LoadUint32((*uint32)(&c.state)))
|
||||||
|
|
||||||
|
if c.config.StateChanges != nil {
|
||||||
|
change := &ClientStateChange{
|
||||||
|
Identifier: c.config.Identifier,
|
||||||
|
Previous: ClientState(prev),
|
||||||
|
Current: state,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.config.StateChanges <- change:
|
||||||
|
default:
|
||||||
|
c.log.Warning("Dropping state change due to slow reader: %s", change)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state))
|
||||||
|
|
||||||
|
return prev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) isRetry(state ClientState) bool {
|
||||||
|
return state != ClientStarted && state != ClientClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) connect(identifier, serverAddr string) error {
|
||||||
|
c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier)
|
||||||
|
|
||||||
|
conn, err := c.dial(serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteURL := controlURL(conn)
|
||||||
|
c.log.Debug("CONNECT to %q", remoteURL)
|
||||||
|
req, err := http.NewRequest("CONNECT", remoteURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set(proto.ClientIdentifierHeader, identifier)
|
||||||
|
|
||||||
|
c.log.Debug("Writing request to TCP: %+v", req)
|
||||||
|
|
||||||
|
if err := req.Write(conn); err != nil {
|
||||||
|
return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Debug("Reading response from TCP")
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected {
|
||||||
|
out, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ctrlWg.Wait() // wait until previous listenControl observes disconnection
|
||||||
|
|
||||||
|
c.session, err = yamux.Client(conn, c.yamuxConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("session initialization failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stream net.Conn
|
||||||
|
openStream := func() error {
|
||||||
|
// this is blocking until client opens a session to us
|
||||||
|
stream, err = c.session.Open()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we don't receive anything from the server, we'll timeout
|
||||||
|
select {
|
||||||
|
case err := <-async(openStream):
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("waiting for session to open failed: %s", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second * 10):
|
||||||
|
if stream != nil {
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
return errors.New("timeout opening session")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil {
|
||||||
|
return fmt.Errorf("writing handshake request failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, len(proto.HandshakeResponse))
|
||||||
|
if _, err := stream.Read(buf); err != nil {
|
||||||
|
return fmt.Errorf("reading handshake response failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf) != proto.HandshakeResponse {
|
||||||
|
return fmt.Errorf("invalid handshake response, received: %s", string(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := newControl(stream)
|
||||||
|
c.log.Debug("client has started successfully")
|
||||||
|
c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff
|
||||||
|
|
||||||
|
c.startNotifyIfNeeded()
|
||||||
|
|
||||||
|
return c.listenControl(ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) dial(serverAddr string) (net.Conn, error) {
|
||||||
|
if c.config.Dial != nil {
|
||||||
|
return c.config.Dial("tcp", serverAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.Dial("tcp", serverAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) listenControl(ct *control) error {
|
||||||
|
c.ctrlWg.Add(1)
|
||||||
|
defer c.ctrlWg.Done()
|
||||||
|
|
||||||
|
c.changeState(ClientConnected, nil)
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg proto.ControlMessage
|
||||||
|
if err := ct.dec.Decode(&msg); err != nil {
|
||||||
|
c.reqWg.Wait() // wait until all requests are finished
|
||||||
|
c.session.GoAway()
|
||||||
|
c.session.Close()
|
||||||
|
c.changeState(ClientDisconnected, err)
|
||||||
|
|
||||||
|
return fmt.Errorf("failure decoding control message: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Debug("Received control msg %+v", msg)
|
||||||
|
c.log.Debug("Opening a new stream from server session")
|
||||||
|
|
||||||
|
remote, err := c.session.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isHTTP := msg.Protocol == proto.HTTP
|
||||||
|
if isHTTP {
|
||||||
|
c.reqWg.Add(1)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
c.proxy(remote, &msg)
|
||||||
|
if isHTTP {
|
||||||
|
c.reqWg.Done()
|
||||||
|
}
|
||||||
|
remote.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
// Code generated by "stringer -type ClientState"; DO NOT EDIT
|
||||||
|
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed"
|
||||||
|
|
||||||
|
var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87}
|
||||||
|
|
||||||
|
func (i ClientState) String() string {
|
||||||
|
if i >= ClientState(len(_ClientState_index)-1) {
|
||||||
|
return fmt.Sprintf("ClientState(%d)", i)
|
||||||
|
}
|
||||||
|
return _ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]]
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errControlClosed = errors.New("control connection is closed")
|
||||||
|
|
||||||
|
type control struct {
|
||||||
|
// enc and dec are responsible for encoding and decoding json values forth
|
||||||
|
// and back
|
||||||
|
enc *json.Encoder
|
||||||
|
dec *json.Decoder
|
||||||
|
|
||||||
|
// underlying connection responsible for encoder and decoder
|
||||||
|
nc net.Conn
|
||||||
|
|
||||||
|
// identifier associated with this control
|
||||||
|
identifier string
|
||||||
|
|
||||||
|
mu sync.Mutex // guards the following
|
||||||
|
closed bool // if Close() and quits
|
||||||
|
}
|
||||||
|
|
||||||
|
func newControl(nc net.Conn) *control {
|
||||||
|
c := &control{
|
||||||
|
enc: json.NewEncoder(nc),
|
||||||
|
dec: json.NewDecoder(nc),
|
||||||
|
nc: nc,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *control) send(v interface{}) error {
|
||||||
|
if c.enc == nil {
|
||||||
|
return errors.New("encoder is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return errControlClosed
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
return c.enc.Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *control) recv(v interface{}) error {
|
||||||
|
if c.dec == nil {
|
||||||
|
return errors.New("decoder is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return errControlClosed
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
return c.dec.Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *control) Close() error {
|
||||||
|
if c.nc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.closed = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
return c.nc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type controls struct {
|
||||||
|
sync.Mutex
|
||||||
|
controls map[string]*control
|
||||||
|
}
|
||||||
|
|
||||||
|
func newControls() *controls {
|
||||||
|
return &controls{
|
||||||
|
controls: make(map[string]*control),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controls) getControl(identifier string) (*control, bool) {
|
||||||
|
c.Lock()
|
||||||
|
control, ok := c.controls[identifier]
|
||||||
|
c.Unlock()
|
||||||
|
return control, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controls) addControl(identifier string, control *control) {
|
||||||
|
control.identifier = identifier
|
||||||
|
|
||||||
|
c.Lock()
|
||||||
|
c.controls[identifier] = control
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controls) deleteControl(identifier string) {
|
||||||
|
c.Lock()
|
||||||
|
delete(c.controls, identifier)
|
||||||
|
c.Unlock()
|
||||||
|
}
|
|
@ -0,0 +1,263 @@
|
||||||
|
package tunnel_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rand.Seed(time.Now().UnixNano() + int64(os.Getpid()))
|
||||||
|
}
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
type EchoMessage struct {
|
||||||
|
Value string `json:"value,omitempty"`
|
||||||
|
Close bool `json:"close,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout = 10 * time.Second
|
||||||
|
|
||||||
|
var dialer = &websocket.Dialer{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
HandshakeTimeout: timeout,
|
||||||
|
NetDial: func(_, addr string) (net.Conn, error) {
|
||||||
|
return net.DialTimeout("tcp4", addr, timeout)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
||||||
|
req := tt.Request("http", url.Values{"echo": []string{echo}})
|
||||||
|
if req == nil {
|
||||||
|
return "", fmt.Errorf(`tunnel "http" does not exist`)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Close = rand.Int()%2 == 0
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
p, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(bytes.TrimSpace(p)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
||||||
|
return echoTCPIdent(tt, echo, "tcp")
|
||||||
|
}
|
||||||
|
|
||||||
|
func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) {
|
||||||
|
addr := tt.Addr(ident)
|
||||||
|
if addr == nil {
|
||||||
|
return "", fmt.Errorf("tunnel %q does not exist", ident)
|
||||||
|
}
|
||||||
|
s := addr.String()
|
||||||
|
ip := tt.Tunnels[ident].IP
|
||||||
|
if ip != nil {
|
||||||
|
_, port, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s = net.JoinHostPort(ip.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := dialTCP(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.out <- echo
|
||||||
|
|
||||||
|
select {
|
||||||
|
case reply := <-c.in:
|
||||||
|
return reply, nil
|
||||||
|
case <-time.After(tcpTimeout):
|
||||||
|
return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) {
|
||||||
|
req := tt.Request(ident, nil)
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("no client found for ident %q", ident)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := http.Header{"Host": {req.Host}}
|
||||||
|
wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr())
|
||||||
|
|
||||||
|
conn, _, err := dialer.Dial(wsurl, h)
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleep() {
|
||||||
|
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (e error) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if e == nil {
|
||||||
|
e = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if sleepFn != nil {
|
||||||
|
sleepFn()
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg EchoMessage
|
||||||
|
err := conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ReadJSON error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sleepFn != nil {
|
||||||
|
sleepFn()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.WriteJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("WriteJSON error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Close {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
io.WriteString(w, r.URL.Query().Get("echo"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sleep()
|
||||||
|
handlerEchoHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerEchoTCP(conn net.Conn) {
|
||||||
|
io.Copy(conn, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerLatencyEchoTCP(conn net.Conn) {
|
||||||
|
sleep()
|
||||||
|
handlerEchoTCP(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
type tcpClient struct {
|
||||||
|
conn net.Conn
|
||||||
|
scanner *bufio.Scanner
|
||||||
|
in chan string
|
||||||
|
out chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpClient) loop() {
|
||||||
|
for out := range c.out {
|
||||||
|
if _, err := fmt.Fprintln(c.conn, out); err != nil {
|
||||||
|
log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.scanner.Scan() {
|
||||||
|
log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.in <- c.scanner.Text()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpClient) Close() error {
|
||||||
|
close(c.out)
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialTCP(addr string) (*tcpClient, error) {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, tcpTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &tcpClient{
|
||||||
|
conn: conn,
|
||||||
|
scanner: bufio.NewScanner(conn),
|
||||||
|
in: make(chan string, 1),
|
||||||
|
out: make(chan string, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.loop()
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel {
|
||||||
|
return singleRecHTTP(handler, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
||||||
|
return map[string]*tunneltest.Tunnel{
|
||||||
|
"http": {
|
||||||
|
Type: tunneltest.TypeHTTP,
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
Handler: handler,
|
||||||
|
StateChanges: stateChanges,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel {
|
||||||
|
return singleRecTCP(handler, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
||||||
|
return map[string]*tunneltest.Tunnel{
|
||||||
|
"http": {
|
||||||
|
Type: tunneltest.TypeHTTP,
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
Handler: handlerEchoHTTP,
|
||||||
|
StateChanges: stateChanges,
|
||||||
|
},
|
||||||
|
"tcp": {
|
||||||
|
Type: tunneltest.TypeTCP,
|
||||||
|
ClientIdent: "http",
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
RemoteAddr: "127.0.0.1:0",
|
||||||
|
Handler: handler,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
httpLog = logging.NewLogger("http")
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPProxy forwards HTTP traffic.
|
||||||
|
//
|
||||||
|
// When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort
|
||||||
|
// where incomingPort is control message LocalPort.
|
||||||
|
// Usually this is tunnel server's public exposed Port.
|
||||||
|
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
||||||
|
// FetchLocalAddr takes precedence over LocalAddr.
|
||||||
|
//
|
||||||
|
// When connection to local server cannot be established proxy responds with http error message.
|
||||||
|
type HTTPProxy struct {
|
||||||
|
// LocalAddr defines the TCP address of the local server.
|
||||||
|
// This is optional if you want to specify a single TCP address.
|
||||||
|
LocalAddr string
|
||||||
|
// FetchLocalAddr is used for looking up TCP address of the server.
|
||||||
|
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
||||||
|
FetchLocalAddr func(port int) (string, error)
|
||||||
|
// ErrorResp is custom response send to tunnel server when client cannot
|
||||||
|
// establish connection to local server. If not set a default "no local server"
|
||||||
|
// response is sent.
|
||||||
|
ErrorResp *http.Response
|
||||||
|
// Log is a custom logger that can be used for the proxy.
|
||||||
|
// If not set a "http" logger is used.
|
||||||
|
Log logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy is a ProxyFunc.
|
||||||
|
func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
||||||
|
if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS {
|
||||||
|
panic("Proxy mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
var log = p.log()
|
||||||
|
|
||||||
|
var port = msg.LocalPort
|
||||||
|
if port == 0 {
|
||||||
|
port = 80
|
||||||
|
}
|
||||||
|
|
||||||
|
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
if p.LocalAddr != "" {
|
||||||
|
localAddr = p.LocalAddr
|
||||||
|
} else if p.FetchLocalAddr != nil {
|
||||||
|
l, err := p.FetchLocalAddr(msg.LocalPort)
|
||||||
|
if err != nil {
|
||||||
|
log.Warning("Failed to get custom local address: %s", err)
|
||||||
|
p.sendError(remote)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localAddr = l
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Dialing local server %q", localAddr)
|
||||||
|
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
||||||
|
p.sendError(remote)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Join(local, remote, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *HTTPProxy) sendError(remote net.Conn) {
|
||||||
|
var w = noLocalServer()
|
||||||
|
if p.ErrorResp != nil {
|
||||||
|
w = p.ErrorResp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
w.Write(buf)
|
||||||
|
if _, err := io.Copy(remote, buf); err != nil {
|
||||||
|
var log = p.log()
|
||||||
|
log.Debug("Copy in-mem response error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remote.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func noLocalServer() *http.Response {
|
||||||
|
body := bytes.NewBufferString("no local server")
|
||||||
|
return &http.Response{
|
||||||
|
Status: http.StatusText(http.StatusServiceUnavailable),
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Body: ioutil.NopCloser(body),
|
||||||
|
ContentLength: int64(body.Len()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *HTTPProxy) log() logging.Logger {
|
||||||
|
if p.Log != nil {
|
||||||
|
return p.Log
|
||||||
|
}
|
||||||
|
return httpLog
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package proto
|
||||||
|
|
||||||
|
// ControlMessage is sent from server to client to establish tunneled connection.
|
||||||
|
type ControlMessage struct {
|
||||||
|
Action Action `json:"action"`
|
||||||
|
Protocol Type `json:"transportProtocol"`
|
||||||
|
LocalPort int `json:"localPort"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Action represents type of ControlMsg request.
|
||||||
|
type Action int
|
||||||
|
|
||||||
|
// ControlMessage actions.
|
||||||
|
const (
|
||||||
|
RequestClientSession Action = iota + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type represents tunneled connection type.
|
||||||
|
type Type int
|
||||||
|
|
||||||
|
// ControlMessage protocols.
|
||||||
|
const (
|
||||||
|
HTTP Type = iota + 1
|
||||||
|
TCP
|
||||||
|
WS
|
||||||
|
)
|
|
@ -0,0 +1,19 @@
|
||||||
|
// Package proto defines tunnel client server communication protocol.
|
||||||
|
package proto
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ControlPath is http.Handler url path for control connection.
|
||||||
|
ControlPath = "/_controlPath/"
|
||||||
|
|
||||||
|
// ClientIdentifierHeader is header carrying information about tunnel identifier.
|
||||||
|
ClientIdentifierHeader = "X-KTunnel-Identifier"
|
||||||
|
|
||||||
|
// control messages
|
||||||
|
|
||||||
|
// Connected is message sent by server to client when control connection was established.
|
||||||
|
Connected = "200 Connected to Tunnel"
|
||||||
|
// HandshakeRequest is hello message sent by client to server.
|
||||||
|
HandshakeRequest = "controlHandshake"
|
||||||
|
// HandshakeResponse is response to HandshakeRequest sent by server to client.
|
||||||
|
HandshakeResponse = "controlOk"
|
||||||
|
)
|
|
@ -0,0 +1,101 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
|
||||||
|
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultProxyFuncs holds global default proxy functions for all transport protocols.
|
||||||
|
DefaultProxyFuncs = ProxyFuncs{
|
||||||
|
HTTP: new(HTTPProxy).Proxy,
|
||||||
|
TCP: new(TCPProxy).Proxy,
|
||||||
|
WS: new(HTTPProxy).Proxy,
|
||||||
|
}
|
||||||
|
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs.
|
||||||
|
DefaultProxy = Proxy(ProxyFuncs{})
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyFuncs is a collection of ProxyFunc.
|
||||||
|
type ProxyFuncs struct {
|
||||||
|
// HTTP is custom implementation of HTTP proxing.
|
||||||
|
HTTP ProxyFunc
|
||||||
|
// TCP is custom implementation of TCP proxing.
|
||||||
|
TCP ProxyFunc
|
||||||
|
// WS is custom implementation of web socket proxing.
|
||||||
|
WS ProxyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs.
|
||||||
|
func Proxy(p ProxyFuncs) ProxyFunc {
|
||||||
|
return func(remote net.Conn, msg *proto.ControlMessage) {
|
||||||
|
var f ProxyFunc
|
||||||
|
switch msg.Protocol {
|
||||||
|
case proto.HTTP:
|
||||||
|
f = DefaultProxyFuncs.HTTP
|
||||||
|
if p.HTTP != nil {
|
||||||
|
f = p.HTTP
|
||||||
|
}
|
||||||
|
case proto.TCP:
|
||||||
|
f = DefaultProxyFuncs.TCP
|
||||||
|
if p.TCP != nil {
|
||||||
|
f = p.TCP
|
||||||
|
}
|
||||||
|
case proto.WS:
|
||||||
|
f = DefaultProxyFuncs.WS
|
||||||
|
if p.WS != nil {
|
||||||
|
f = p.WS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if f == nil {
|
||||||
|
logging.Error("Could not determine proxy function for %v", msg)
|
||||||
|
remote.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
f(remote, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join copies data between local and remote connections.
|
||||||
|
// It reads from one connection and writes to the other.
|
||||||
|
// It's a building block for ProxyFunc implementations.
|
||||||
|
func Join(local, remote net.Conn, log logging.Logger) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
transfer := func(side string, dst, src net.Conn) {
|
||||||
|
log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
||||||
|
|
||||||
|
n, err := io.Copy(dst, src)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("%s: copy error: %s", side, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := src.Close(); err != nil {
|
||||||
|
log.Debug("%s: close error: %s", side, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// not for yamux streams, but for client to local server connections
|
||||||
|
if d, ok := dst.(*net.TCPConn); ok {
|
||||||
|
if err := d.CloseWrite(); err != nil {
|
||||||
|
log.Debug("%s: closeWrite error: %s", side, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
go transfer("remote to local", local, remote)
|
||||||
|
go transfer("local to remote", remote, local)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
|
@ -0,0 +1,755 @@
|
||||||
|
// Package tunnel is a server/client package that enables to proxy public
|
||||||
|
// connections to your local machine over a tunnel connection from the local
|
||||||
|
// machine to the public server.
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNoClientSession = errors.New("no client session established")
|
||||||
|
defaultTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server is responsible for proxying public connections to the client over a
|
||||||
|
// tunnel connection. It also listens to control messages from the client.
|
||||||
|
type Server struct {
|
||||||
|
// pending contains the channel that is associated with each new tunnel request.
|
||||||
|
pending map[string]chan net.Conn
|
||||||
|
// pendingMu protects the pending map.
|
||||||
|
pendingMu sync.Mutex
|
||||||
|
|
||||||
|
// sessions contains a session per virtual host.
|
||||||
|
// Sessions provides multiplexing over one connection.
|
||||||
|
sessions map[string]*yamux.Session
|
||||||
|
// sessionsMu protects sessions.
|
||||||
|
sessionsMu sync.Mutex
|
||||||
|
|
||||||
|
// controls contains the control connection from the client to the server.
|
||||||
|
controls *controls
|
||||||
|
|
||||||
|
// virtualHosts is used to map public hosts to remote clients.
|
||||||
|
virtualHosts vhostStorage
|
||||||
|
|
||||||
|
// virtualAddrs.
|
||||||
|
virtualAddrs *vaddrStorage
|
||||||
|
|
||||||
|
// connCh is used to publish accepted connections for tcp tunnels.
|
||||||
|
connCh chan net.Conn
|
||||||
|
|
||||||
|
// onConnectCallbacks contains client callbacks called when control
|
||||||
|
// session is established for a client with given identifier.
|
||||||
|
onConnectCallbacks *callbacks
|
||||||
|
|
||||||
|
// onDisconnectCallbacks contains client callbacks called when control
|
||||||
|
// session is closed for a client with given identifier.
|
||||||
|
onDisconnectCallbacks *callbacks
|
||||||
|
|
||||||
|
// states represents current clients' connections state.
|
||||||
|
states map[string]ClientState
|
||||||
|
// statesMu protects states.
|
||||||
|
statesMu sync.RWMutex
|
||||||
|
// stateCh notifies receiver about client state changes.
|
||||||
|
stateCh chan<- *ClientStateChange
|
||||||
|
|
||||||
|
// httpDirector is provided by ServerConfig, if not nil decorates http requests
|
||||||
|
// before forwarding them to client.
|
||||||
|
httpDirector func(*http.Request)
|
||||||
|
|
||||||
|
// yamuxConfig is passed to new yamux.Session's
|
||||||
|
yamuxConfig *yamux.Config
|
||||||
|
|
||||||
|
log logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig defines the configuration for the Server
|
||||||
|
type ServerConfig struct {
|
||||||
|
// StateChanges receives state transition details each time client
|
||||||
|
// connection state changes. The channel is expected to be sufficiently
|
||||||
|
// buffered to keep up with event pace.
|
||||||
|
//
|
||||||
|
// If nil, no information about state transitions are dispatched
|
||||||
|
// by the library.
|
||||||
|
StateChanges chan<- *ClientStateChange
|
||||||
|
|
||||||
|
// Director is a function that modifies HTTP request into a new HTTP request
|
||||||
|
// before sending to client. If nil no modifications are done.
|
||||||
|
Director func(*http.Request)
|
||||||
|
|
||||||
|
// Debug enables debug mode, enable only if you want to debug the server
|
||||||
|
Debug bool
|
||||||
|
|
||||||
|
// Log defines the logger. If nil a default logging.Logger is used.
|
||||||
|
Log logging.Logger
|
||||||
|
|
||||||
|
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
||||||
|
// yamux.DefaultConfig() is used.
|
||||||
|
YamuxConfig *yamux.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new Server. The defaults are used if config is nil.
|
||||||
|
func NewServer(cfg *ServerConfig) (*Server, error) {
|
||||||
|
yamuxConfig := yamux.DefaultConfig()
|
||||||
|
if cfg.YamuxConfig != nil {
|
||||||
|
if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
yamuxConfig = cfg.YamuxConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
log := newLogger("tunnel-server", cfg.Debug)
|
||||||
|
if cfg.Log != nil {
|
||||||
|
log = cfg.Log
|
||||||
|
}
|
||||||
|
|
||||||
|
connCh := make(chan net.Conn)
|
||||||
|
|
||||||
|
opts := &vaddrOptions{
|
||||||
|
connCh: connCh,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
pending: make(map[string]chan net.Conn),
|
||||||
|
sessions: make(map[string]*yamux.Session),
|
||||||
|
onConnectCallbacks: newCallbacks("OnConnect"),
|
||||||
|
onDisconnectCallbacks: newCallbacks("OnDisconnect"),
|
||||||
|
virtualHosts: newVirtualHosts(),
|
||||||
|
virtualAddrs: newVirtualAddrs(opts),
|
||||||
|
controls: newControls(),
|
||||||
|
states: make(map[string]ClientState),
|
||||||
|
stateCh: cfg.StateChanges,
|
||||||
|
httpDirector: cfg.Director,
|
||||||
|
yamuxConfig: yamuxConfig,
|
||||||
|
connCh: connCh,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.serveTCP()
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP is a tunnel that creates an http/websocket tunnel between a
|
||||||
|
// public connection and the client connection.
|
||||||
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// if the user didn't add the control and tunnel handler manually, we'll
|
||||||
|
// going to infer and call the respective path handlers.
|
||||||
|
switch path.Clean(r.URL.Path) + "/" {
|
||||||
|
case proto.ControlPath:
|
||||||
|
s.checkConnect(s.controlHandler).ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.handleHTTP(w, r); err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily
|
||||||
|
s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err)
|
||||||
|
}
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHTTP handles a single HTTP request
|
||||||
|
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
s.log.Debug("HandleHTTP request:")
|
||||||
|
s.log.Debug("%v", r)
|
||||||
|
|
||||||
|
if s.httpDirector != nil {
|
||||||
|
s.httpDirector(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPort := strings.ToLower(r.Host)
|
||||||
|
if hostPort == "" {
|
||||||
|
return errors.New("request host is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if someone hits foo.example.com:8080, this should be proxied to
|
||||||
|
// localhost:8080, so send the port to the client so it knows how to proxy
|
||||||
|
// correctly. If no port is available, it's up to client how to interpret it
|
||||||
|
host, port, err := parseHostPort(hostPort)
|
||||||
|
if err != nil {
|
||||||
|
// no need to return, just continue lazily, port will be 0, which in
|
||||||
|
// our case will be proxied to client's local servers port 80
|
||||||
|
s.log.Debug("No port available for %q, sending port 80 to client", hostPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the identifier associated with this host
|
||||||
|
identifier, ok := s.getIdentifier(hostPort)
|
||||||
|
if !ok {
|
||||||
|
// fallback to host
|
||||||
|
identifier, ok = s.getIdentifier(host)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("no virtual host available for %q", hostPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isWebsocketConn(r) {
|
||||||
|
s.log.Debug("handling websocket connection")
|
||||||
|
|
||||||
|
return s.handleWSConn(w, r, identifier, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := s.dial(identifier, proto.HTTP, port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
s.log.Debug("Closing stream")
|
||||||
|
stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := r.Write(stream); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debug("Session opened to client, writing request to client")
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read from tunnel: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if resp.Body != nil {
|
||||||
|
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
||||||
|
s.log.Error("resp.Body Close error: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.log.Debug("Response received, writing back to public connection: %+v", resp)
|
||||||
|
|
||||||
|
copyHeader(w.Header(), resp.Header)
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||||
|
if err == io.ErrUnexpectedEOF {
|
||||||
|
s.log.Debug("Client closed the connection, couldn't copy response")
|
||||||
|
} else {
|
||||||
|
s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) serveTCP() {
|
||||||
|
for conn := range s.connCh {
|
||||||
|
go s.serveTCPConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) serveTCPConn(conn net.Conn) {
|
||||||
|
err := s.handleTCPConn(conn)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("webserver doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hijack not possible: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := s.dial(ident, proto.WS, port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Write(stream); err != nil {
|
||||||
|
err = errors.New("unable to write upgrade request: " + err.Error())
|
||||||
|
return nonil(err, stream.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.New("unable to read upgrade response: " + err.Error())
|
||||||
|
return nonil(err, stream.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := resp.Write(conn); err != nil {
|
||||||
|
err = errors.New("unable to write upgrade response: " + err.Error())
|
||||||
|
return nonil(err, stream.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go s.proxy(&wg, conn, stream)
|
||||||
|
go s.proxy(&wg, stream, conn)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return nonil(stream.Close(), conn.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTCPConn(conn net.Conn) error {
|
||||||
|
ident, ok := s.virtualAddrs.getIdent(conn)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("no virtual address available for %s", conn.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, port, err := parseHostPort(conn.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := s.dial(ident, proto.TCP, port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go s.proxy(&wg, conn, stream)
|
||||||
|
go s.proxy(&wg, stream, conn)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return nonil(stream.Close(), conn.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
||||||
|
n, err := io.Copy(dst, src)
|
||||||
|
s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) {
|
||||||
|
control, ok := s.getControl(identifier)
|
||||||
|
if !ok {
|
||||||
|
return nil, errNoClientSession
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := s.getSession(identifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := proto.ControlMessage{
|
||||||
|
Action: proto.RequestClientSession,
|
||||||
|
Protocol: p,
|
||||||
|
LocalPort: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debug("Sending control msg %+v", msg)
|
||||||
|
|
||||||
|
// ask client to open a session to us, so we can accept it
|
||||||
|
if err := control.send(msg); err != nil {
|
||||||
|
// we might have several issues here, either the stream is closed, or
|
||||||
|
// the session is going be shut down, the underlying connection might
|
||||||
|
// be broken. In all cases, it's not reliable anymore having a client
|
||||||
|
// session.
|
||||||
|
control.Close()
|
||||||
|
s.deleteControl(identifier)
|
||||||
|
return nil, errNoClientSession
|
||||||
|
}
|
||||||
|
|
||||||
|
var stream net.Conn
|
||||||
|
acceptStream := func() error {
|
||||||
|
stream, err = session.Accept()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we don't receive anything from the client, we'll timeout
|
||||||
|
s.log.Debug("Waiting for session accept")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-async(acceptStream):
|
||||||
|
return stream, err
|
||||||
|
case <-time.After(defaultTimeout):
|
||||||
|
return nil, errors.New("timeout getting session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// controlHandler is used to capture incoming tunnel connect requests into raw
|
||||||
|
// tunnel TCP connections.
|
||||||
|
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
|
||||||
|
identifier := r.Header.Get(proto.ClientIdentifierHeader)
|
||||||
|
_, ok := s.getHost(identifier)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct, ok := s.getControl(identifier)
|
||||||
|
if ok {
|
||||||
|
ct.Close()
|
||||||
|
s.deleteControl(identifier)
|
||||||
|
s.deleteSession(identifier)
|
||||||
|
s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier)
|
||||||
|
return fmt.Errorf("control conn for %s already exist. \n", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debug("Tunnel with identifier %s", identifier)
|
||||||
|
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("webserver doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hijack not possible: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil {
|
||||||
|
return fmt.Errorf("error writing response: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
return fmt.Errorf("error setting connection deadline: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debug("Creating control session")
|
||||||
|
session, err := yamux.Server(conn, s.yamuxConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.addSession(identifier, session)
|
||||||
|
|
||||||
|
var stream net.Conn
|
||||||
|
|
||||||
|
// close and delete the session/stream if something goes wrong
|
||||||
|
defer func() {
|
||||||
|
if ctErr != nil {
|
||||||
|
if stream != nil {
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
s.deleteSession(identifier)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
acceptStream := func() error {
|
||||||
|
stream, err = session.Accept()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we don't receive anything from the client, we'll timeout
|
||||||
|
select {
|
||||||
|
case err := <-async(acceptStream):
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second * 10):
|
||||||
|
return errors.New("timeout getting session")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Debug("Initiating handshake protocol")
|
||||||
|
buf := make([]byte, len(proto.HandshakeRequest))
|
||||||
|
if _, err := stream.Read(buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf) != proto.HandshakeRequest {
|
||||||
|
return fmt.Errorf("handshake aborted. got: %s", string(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup control stream and start to listen to messages
|
||||||
|
ct = newControl(stream)
|
||||||
|
s.addControl(identifier, ct)
|
||||||
|
go s.listenControl(ct)
|
||||||
|
|
||||||
|
s.log.Debug("Control connection is setup")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenControl listens to messages coming from the client.
|
||||||
|
func (s *Server) listenControl(ct *control) {
|
||||||
|
s.onConnect(ct.identifier)
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg map[string]interface{}
|
||||||
|
err := ct.dec.Decode(&msg)
|
||||||
|
if err != nil {
|
||||||
|
host, _ := s.getHost(ct.identifier)
|
||||||
|
s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier)
|
||||||
|
|
||||||
|
// close client connection so it reconnects again
|
||||||
|
ct.Close()
|
||||||
|
|
||||||
|
// don't forget to cleanup anything
|
||||||
|
s.deleteControl(ct.identifier)
|
||||||
|
s.deleteSession(ct.identifier)
|
||||||
|
|
||||||
|
s.onDisconnect(ct.identifier, err)
|
||||||
|
|
||||||
|
if err != io.EOF {
|
||||||
|
s.log.Error("decode err: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// right now we don't do anything with the messages, but because the
|
||||||
|
// underlying connection needs to establihsed, we know when we have
|
||||||
|
// disconnection(above), so we can cleanup the connection.
|
||||||
|
s.log.Debug("msg: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnect invokes a callback for client with given identifier,
|
||||||
|
// when it establishes a control session.
|
||||||
|
// After a client is connected, the associated function
|
||||||
|
// is also removed and needs to be added again.
|
||||||
|
func (s *Server) OnConnect(identifier string, fn func() error) {
|
||||||
|
s.onConnectCallbacks.add(identifier, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// onConnect sends notifications to listeners (registered in onConnectCallbacks
|
||||||
|
// or stateChanges chanel readers) when client connects.
|
||||||
|
func (s *Server) onConnect(identifier string) {
|
||||||
|
if err := s.onConnectCallbacks.call(identifier); err != nil {
|
||||||
|
s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.changeState(identifier, ClientConnected, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDisconnect calls the function when the client connected with the
|
||||||
|
// associated identifier disconnects from the server.
|
||||||
|
// After a client is disconnected, the associated function
|
||||||
|
// is also removed and needs to be added again.
|
||||||
|
func (s *Server) OnDisconnect(identifier string, fn func() error) {
|
||||||
|
s.onDisconnectCallbacks.add(identifier, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks
|
||||||
|
// or stateChanges chanel readers) when client disconnects.
|
||||||
|
func (s *Server) onDisconnect(identifier string, err error) {
|
||||||
|
if err := s.onDisconnectCallbacks.call(identifier); err != nil {
|
||||||
|
s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.changeState(identifier, ClientClosed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) {
|
||||||
|
s.statesMu.Lock()
|
||||||
|
defer s.statesMu.Unlock()
|
||||||
|
|
||||||
|
prev = s.states[identifier]
|
||||||
|
s.states[identifier] = state
|
||||||
|
|
||||||
|
if s.stateCh != nil {
|
||||||
|
change := &ClientStateChange{
|
||||||
|
Identifier: identifier,
|
||||||
|
Previous: prev,
|
||||||
|
Current: state,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.stateCh <- change:
|
||||||
|
default:
|
||||||
|
s.log.Warning("Dropping state change due to slow reader: %s", change)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return prev
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHost adds the given virtual host and maps it to the identifier.
|
||||||
|
func (s *Server) AddHost(host, identifier string) {
|
||||||
|
s.virtualHosts.AddHost(host, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteHost deletes the given virtual host. Once removed any request to this
|
||||||
|
// host is denied.
|
||||||
|
func (s *Server) DeleteHost(host string) {
|
||||||
|
s.virtualHosts.DeleteHost(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAddr starts accepting connections on listener l, routing every connection
|
||||||
|
// to a tunnel client given by the identifier.
|
||||||
|
//
|
||||||
|
// When ip parameter is nil, all connections accepted from the listener are
|
||||||
|
// routed to the tunnel client specified by the identifier (port-based routing).
|
||||||
|
//
|
||||||
|
// When ip parameter is non-nil, only those connections are routed whose local
|
||||||
|
// address matches the specified ip (ip-based routing).
|
||||||
|
//
|
||||||
|
// If l listens on multiple interfaces it's desirable to call AddAddr multiple
|
||||||
|
// times with the same l value but different ip one.
|
||||||
|
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
|
||||||
|
s.virtualAddrs.Add(l, ip, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAddr stops listening for connections on the given listener.
|
||||||
|
//
|
||||||
|
// Upon return no more connections will be tunneled, but as the method does not
|
||||||
|
// close the listener, so any ongoing connection won't get interrupted.
|
||||||
|
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
|
||||||
|
s.virtualAddrs.Delete(l, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getIdentifier(host string) (string, bool) {
|
||||||
|
identifier, ok := s.virtualHosts.GetIdentifier(host)
|
||||||
|
return identifier, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getHost(identifier string) (string, bool) {
|
||||||
|
host, ok := s.virtualHosts.GetHost(identifier)
|
||||||
|
return host, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addControl(identifier string, conn *control) {
|
||||||
|
s.controls.addControl(identifier, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getControl(identifier string) (*control, bool) {
|
||||||
|
return s.controls.getControl(identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) deleteControl(identifier string) {
|
||||||
|
s.controls.deleteControl(identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
|
||||||
|
s.sessionsMu.Lock()
|
||||||
|
session, ok := s.sessions[identifier]
|
||||||
|
s.sessionsMu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addSession(identifier string, session *yamux.Session) {
|
||||||
|
s.sessionsMu.Lock()
|
||||||
|
s.sessions[identifier] = session
|
||||||
|
s.sessionsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) deleteSession(identifier string) {
|
||||||
|
s.sessionsMu.Lock()
|
||||||
|
defer s.sessionsMu.Unlock()
|
||||||
|
|
||||||
|
session, ok := s.sessions[identifier]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return // nothing to delete
|
||||||
|
}
|
||||||
|
|
||||||
|
if session != nil {
|
||||||
|
session.GoAway() // don't accept any new connection
|
||||||
|
session.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.sessions, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHeader(dst, src http.Header) {
|
||||||
|
for k, v := range src {
|
||||||
|
vv := make([]string, len(v))
|
||||||
|
copy(vv, v)
|
||||||
|
dst[k] = vv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkConnect checks whether the incoming request is HTTP CONNECT method.
|
||||||
|
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "CONNECT" {
|
||||||
|
http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fn(w, r); err != nil {
|
||||||
|
s.log.Error("Handler err: %v", err.Error())
|
||||||
|
|
||||||
|
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
|
||||||
|
s.onDisconnect(identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(w, err.Error(), 502)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHostPort(addr string) (string, int, error) {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.ParseUint(port, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, int(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWebsocketConn(r *http.Request) bool {
|
||||||
|
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
|
||||||
|
headerContains(r.Header["Upgrade"], "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go
|
||||||
|
func headerContains(header []string, value string) bool {
|
||||||
|
for _, h := range header {
|
||||||
|
for _, v := range strings.Split(h, ",") {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(v), value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func nonil(err ...error) error {
|
||||||
|
for _, e := range err {
|
||||||
|
if e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLogger(name string, debug bool) logging.Logger {
|
||||||
|
log := logging.NewLogger(name)
|
||||||
|
logHandler := logging.NewWriterHandler(os.Stderr)
|
||||||
|
logHandler.Colorize = true
|
||||||
|
log.SetHandler(logHandler)
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
log.SetLevel(logging.DEBUG)
|
||||||
|
logHandler.SetLevel(logging.DEBUG)
|
||||||
|
}
|
||||||
|
|
||||||
|
return log
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Specification
|
||||||
|
|
||||||
|
# Naming conventions
|
||||||
|
|
||||||
|
* `server` is listening to public connection and is responsible of routing
|
||||||
|
public HTTP requests to clients.
|
||||||
|
* `client` is a long running process, connected to a server and running on a local machine.
|
||||||
|
* `virtualHost` is a virtual domain that maps a domain to a single client. i.e:
|
||||||
|
`arslan.koding.io` is a virtualhost which is mapped to my `client` running on
|
||||||
|
my local machine.
|
||||||
|
* `identifier` is a secret token, which is not meant to be shared with others.
|
||||||
|
An identifier is responsible of mapping a virtualhost to a client.
|
||||||
|
* `session` is a single TCP connection which uses the library `yamux`. A
|
||||||
|
session can be created either via `yamux.Server()` or `yamux.Client`
|
||||||
|
* `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed
|
||||||
|
over the `session`. A session can have hundreds of thousands streams
|
||||||
|
* `control connection` is a single `stream` which is used to communicate and
|
||||||
|
handle messaging between server and client. It uses a custom protocol which
|
||||||
|
is JSON encoded.
|
||||||
|
* `tunnel connection` is a single `stream` which is used to proxy public HTTP
|
||||||
|
requests from the `server` to the `client` and vice versa. A single `tunnel`
|
||||||
|
connection is created for every single HTTP requests.
|
||||||
|
* `public connection` is a connection from a remote machine to the `server`
|
||||||
|
* `ControlHandler` is a http.Handler which listens to requests coming to
|
||||||
|
`/_controlPath_/`. It's used to setup the initial `session` connection from
|
||||||
|
`client` to `server`. And creates the `control connection` from this session.
|
||||||
|
server and client, and also for all additional new tunnel. It literally
|
||||||
|
captures the incoming HTTP request and hijacks it and converts it into RAW TCP,
|
||||||
|
which then is used as the foundation for all yamux `sessions.`
|
||||||
|
|
||||||
|
|
||||||
|
# Server
|
||||||
|
1. Server is created with `NewServer()` which returns `*Server`, a `http.Handler`
|
||||||
|
compatible type. Plug into any HTTP server you want. The root path `"/"` is
|
||||||
|
recommended to listen and proxy any tunnels. It also listens to any request
|
||||||
|
coming to `ControlHandler`
|
||||||
|
2. Tunneling is based on virtual hosts. A virtual hosts is identified with an
|
||||||
|
unique identifier. This identifier is the only piece that both client and
|
||||||
|
server needs to known ahead. Think of it as a secret token.
|
||||||
|
3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This
|
||||||
|
step needs to be done from the server itself. This can be could manually or
|
||||||
|
via custom auth based HTTP handlers, such as "/addhost", which adds
|
||||||
|
virtualhosts and returns the `identifier` to the requester (in our case `client`)
|
||||||
|
4. A DNS record and it's subdomains needs to point to a `server`, so it can
|
||||||
|
handle virtual hosts, i.e: `*.example.com` is routed to a server, which can
|
||||||
|
handle `foo.example.com`, `bar.example.com`, etc..
|
||||||
|
|
||||||
|
|
||||||
|
# Client
|
||||||
|
|
||||||
|
1. Client is created with `NewClient(serverAddr, localAddr)` which returns a
|
||||||
|
`*Client`. Here `serverAddr` is the TCP address to the server. `localAddr`
|
||||||
|
is the server in which all public requests are forwarded to. It's optional if
|
||||||
|
you want it to be done dynamically
|
||||||
|
2. Once a client is created, it starts with `client.Start(identifier)`. Here
|
||||||
|
`identifier` is needed upfront. This method creates the initial TCP
|
||||||
|
connection to the server. It sends the identifier back to the server. This
|
||||||
|
TCP connection is used as the foundation for `yamux.Client()`. Once a yamux
|
||||||
|
session is established, we are able to use this single connection to have
|
||||||
|
multiple streams, which are multiplexed over this one connection. A `control
|
||||||
|
connection` is created and client starts to listen it. `client.Start` is
|
||||||
|
blocking.
|
||||||
|
|
||||||
|
# Control Handshake
|
||||||
|
|
||||||
|
1. Client sends a `handshakeRequest` over the `control connection` stream
|
||||||
|
2. The server sends back a `handshakeResponse` to the client over the `control connection` stream
|
||||||
|
3. Once the client receives the `handshakeResponse` from the server, it starts
|
||||||
|
to listen from the `control connection` stream.
|
||||||
|
4. A `control connection` is json.Encoder/Decoder both for server and client
|
||||||
|
|
||||||
|
|
||||||
|
# Tunnel creation
|
||||||
|
1. When the server receives a public connection, it checks the HTTP host
|
||||||
|
headers and retrieves the corresponding identifier from the given host.
|
||||||
|
2. The server retrieves the `control connection` which was associated with this
|
||||||
|
`identifier` and sends a `ControlMsg` message with the action
|
||||||
|
`RequestClientSession`. This message is in the form of:
|
||||||
|
|
||||||
|
type ControlMsg struct {
|
||||||
|
Action Action `json:"action"`
|
||||||
|
Protocol TransportProtocol `json:"transportProtocol"`
|
||||||
|
LocalPort string `json:"localPort"`
|
||||||
|
}
|
||||||
|
|
||||||
|
Here the `LocalPort` is read from the HTTP Host header. If absent a zero
|
||||||
|
port is sent and client maps it to the local server running at port 80, unless
|
||||||
|
the `localAddr` is specified in `client.Start()` method. `Protocol` is
|
||||||
|
reserved for future features.
|
||||||
|
3. The server immediately starts to listen(accept) to a new `stream`. This is
|
||||||
|
blocking and it waits there.
|
||||||
|
4. When the client receives the `RequestClientSession` message, it opens a new
|
||||||
|
`virtual` TCP connection, a `stream` to the server.
|
||||||
|
5. The server which was waiting for a new stream in step 3, establish the stream.
|
||||||
|
6. The server copies the request over the stream to the client.
|
||||||
|
7. The client copies the request coming from the server to the local server and
|
||||||
|
copies back the result to the server
|
||||||
|
8. The server reads the response coming from the client and returns back it to
|
||||||
|
the public connection requester
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
tpcLog = logging.NewLogger("tcp")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPProxy forwards TCP streams.
|
||||||
|
//
|
||||||
|
// If port-based routing is used, LocalAddr or FetchLocalAddr field is required
|
||||||
|
// for tunneling to function properly.
|
||||||
|
// Otherwise you'll be forwarding traffic to random ports and this is usually not desired.
|
||||||
|
//
|
||||||
|
// If IP-based routing is used then tunnel server connection request is
|
||||||
|
// proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort.
|
||||||
|
// Usually this is tunnel server's public exposed Port.
|
||||||
|
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
||||||
|
// FetchLocalAddr takes precedence over LocalAddr.
|
||||||
|
type TCPProxy struct {
|
||||||
|
// LocalAddr defines the TCP address of the local server.
|
||||||
|
// This is optional if you want to specify a single TCP address.
|
||||||
|
LocalAddr string
|
||||||
|
// FetchLocalAddr is used for looking up TCP address of the server.
|
||||||
|
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
||||||
|
FetchLocalAddr func(port int) (string, error)
|
||||||
|
// Log is a custom logger that can be used for the proxy.
|
||||||
|
// If not set a "tcp" logger is used.
|
||||||
|
Log logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy is a ProxyFunc.
|
||||||
|
func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
||||||
|
if msg.Protocol != proto.TCP {
|
||||||
|
panic("Proxy mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
var log = p.log()
|
||||||
|
|
||||||
|
var port = msg.LocalPort
|
||||||
|
if port == 0 {
|
||||||
|
log.Warning("TCP proxy to port 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
if p.LocalAddr != "" {
|
||||||
|
localAddr = p.LocalAddr
|
||||||
|
} else if p.FetchLocalAddr != nil {
|
||||||
|
l, err := p.FetchLocalAddr(msg.LocalPort)
|
||||||
|
if err != nil {
|
||||||
|
log.Warning("Failed to get custom local address: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localAddr = l
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Dialing local server: %q", localAddr)
|
||||||
|
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Join(local, remote, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TCPProxy) log() logging.Logger {
|
||||||
|
if p.Log != nil {
|
||||||
|
return p.Log
|
||||||
|
}
|
||||||
|
return tpcLog
|
||||||
|
}
|
|
@ -0,0 +1,412 @@
|
||||||
|
package tunnel_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel"
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMultipleRequest(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
// make a request to tunnelserver, this should be tunneled to local server
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := "hello" + strconv.Itoa(i)
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleLatencyRequest(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
// make a request to tunnelserver, this should be tunneled to local server
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := "hello" + strconv.Itoa(i)
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectClient(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := tt.Clients["http"]
|
||||||
|
|
||||||
|
// close client, and start it again
|
||||||
|
client.Close()
|
||||||
|
|
||||||
|
go client.Start()
|
||||||
|
<-client.StartNotify()
|
||||||
|
|
||||||
|
msg = "helloagain"
|
||||||
|
res, err = echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoClient(t *testing.T) {
|
||||||
|
const expectedErr = "no client session established"
|
||||||
|
|
||||||
|
rec := tunneltest.NewStateRecorder()
|
||||||
|
|
||||||
|
tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
if err := rec.WaitTransitions(
|
||||||
|
tunnel.ClientStarted,
|
||||||
|
tunnel.ClientConnecting,
|
||||||
|
tunnel.ClientConnected,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tt.ServerStateRecorder.WaitTransition(
|
||||||
|
tunnel.ClientUnknown,
|
||||||
|
tunnel.ClientConnected,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// close client, this is the main point of the test
|
||||||
|
if err := tt.Clients["http"].Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rec.WaitTransitions(
|
||||||
|
tunnel.ClientConnected,
|
||||||
|
tunnel.ClientDisconnected,
|
||||||
|
tunnel.ClientClosed,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tt.ServerStateRecorder.WaitTransition(
|
||||||
|
tunnel.ClientConnected,
|
||||||
|
tunnel.ClientClosed,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != expectedErr {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoHost(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
noBackoff := backoff.NewConstantBackOff(time.Duration(-1))
|
||||||
|
|
||||||
|
unknown, err := tunnel.NewClient(&tunnel.ClientConfig{
|
||||||
|
Identifier: "unknown",
|
||||||
|
ServerAddr: tt.ServerAddr().String(),
|
||||||
|
Backoff: noBackoff,
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client error: %s", err)
|
||||||
|
}
|
||||||
|
unknown.Start()
|
||||||
|
defer unknown.Close()
|
||||||
|
|
||||||
|
if err := tt.ServerStateRecorder.WaitTransition(
|
||||||
|
tunnel.ClientUnknown,
|
||||||
|
tunnel.ClientClosed,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unknown.Start()
|
||||||
|
if err := tt.ServerStateRecorder.WaitTransition(
|
||||||
|
tunnel.ClientClosed,
|
||||||
|
tunnel.ClientClosed,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoLocalServer(t *testing.T) {
|
||||||
|
const expectedErr = "no local server"
|
||||||
|
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
// close local listener, this is the main point of the test
|
||||||
|
tt.Listeners["http"][0].Close()
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != expectedErr {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleRequest(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleLatencyRequest(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoHTTP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoHTTP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleTCP(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
msg := "hello"
|
||||||
|
res, err := echoTCP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("echoTCP error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg != res {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleTCP(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := "hello" + strconv.Itoa(i)
|
||||||
|
res, err := echoTCP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("echoTCP: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleLatencyTCP(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := "hello" + strconv.Itoa(i)
|
||||||
|
res, err := echoTCP(tt, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("echoTCP: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleStreamTCP(t *testing.T) {
|
||||||
|
tunnels := map[string]*tunneltest.Tunnel{
|
||||||
|
"http": {
|
||||||
|
Type: tunneltest.TypeHTTP,
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
Handler: handlerEchoHTTP,
|
||||||
|
},
|
||||||
|
"tcp": {
|
||||||
|
Type: tunneltest.TypeTCP,
|
||||||
|
ClientIdent: "http",
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
RemoteAddr: "127.0.0.1:0",
|
||||||
|
Handler: handlerEchoTCP,
|
||||||
|
},
|
||||||
|
"tcp_all": {
|
||||||
|
Type: tunneltest.TypeTCP,
|
||||||
|
ClientIdent: "http",
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
RemoteAddr: "0.0.0.0:0",
|
||||||
|
Handler: handlerEchoTCP,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := tunneltest.UsableAddrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clients := []string{"tcp"}
|
||||||
|
for i, addr := range addrs {
|
||||||
|
if addr.IP.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
client := fmt.Sprintf("tcp_%d", i)
|
||||||
|
|
||||||
|
tunnels[client] = &tunneltest.Tunnel{
|
||||||
|
Type: tunneltest.TypeTCP,
|
||||||
|
ClientIdent: "http",
|
||||||
|
LocalAddr: "127.0.0.1:0",
|
||||||
|
RemoteAddrIdent: "tcp_all",
|
||||||
|
IP: addr.IP,
|
||||||
|
Handler: handlerEchoTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
clients = append(clients, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
tt, err := tunneltest.Serve(tunnels)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tt.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100/len(clients); i++ {
|
||||||
|
wg.Add(len(clients))
|
||||||
|
|
||||||
|
for j, ident := range clients {
|
||||||
|
go func(ident string, i, j int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := fmt.Sprintf("hello_%d_client_%d", j, i)
|
||||||
|
res, err := echoTCPIdent(tt, msg, ident)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("echoTCP: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != msg {
|
||||||
|
t.Errorf("got %q, want %q", res, msg)
|
||||||
|
}
|
||||||
|
}(ident, i, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
package tunneltest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
recWaitTimeout = 5 * time.Second
|
||||||
|
recBuffer = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// States is a sequence of client state changes.
|
||||||
|
type States []*tunnel.ClientStateChange
|
||||||
|
|
||||||
|
func (s States) String() string {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
fmt.Fprintf(&buf, "[%s", s[0].String())
|
||||||
|
|
||||||
|
for _, s := range s[1:] {
|
||||||
|
fmt.Fprintf(&buf, ",%s", s.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(']')
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateRecorder saves state changes pushed to StateRecorder.C().
|
||||||
|
type StateRecorder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
ch chan *tunnel.ClientStateChange
|
||||||
|
recorded []*tunnel.ClientStateChange
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStateRecorder() *StateRecorder {
|
||||||
|
rec := &StateRecorder{
|
||||||
|
ch: make(chan *tunnel.ClientStateChange, recBuffer),
|
||||||
|
}
|
||||||
|
|
||||||
|
go rec.record()
|
||||||
|
|
||||||
|
return rec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *StateRecorder) record() {
|
||||||
|
for state := range rec.ch {
|
||||||
|
rec.mu.Lock()
|
||||||
|
rec.recorded = append(rec.recorded, state)
|
||||||
|
rec.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange {
|
||||||
|
return rec.ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error {
|
||||||
|
from := states[0]
|
||||||
|
for _, to := range states[1:] {
|
||||||
|
if err := rec.WaitTransition(from, to); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
from = to
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error {
|
||||||
|
timeout := time.After(recWaitTimeout)
|
||||||
|
|
||||||
|
var lastStates []*tunnel.ClientStateChange
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates))
|
||||||
|
default:
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
lastStates = rec.States()[rec.offset:]
|
||||||
|
|
||||||
|
for i, state := range lastStates {
|
||||||
|
if from != 0 && state.Previous != from {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if to != 0 && state.Current != to {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rec.offset += i
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *StateRecorder) States() []*tunnel.ClientStateChange {
|
||||||
|
rec.mu.Lock()
|
||||||
|
defer rec.mu.Unlock()
|
||||||
|
|
||||||
|
states := make([]*tunnel.ClientStateChange, len(rec.recorded))
|
||||||
|
copy(states, rec.recorded)
|
||||||
|
return states
|
||||||
|
}
|
|
@ -0,0 +1,561 @@
|
||||||
|
package tunneltest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel"
|
||||||
|
)
|
||||||
|
|
||||||
|
var debugNet = os.Getenv("DEBUGNET") == "1"
|
||||||
|
|
||||||
|
type dbgListener struct {
|
||||||
|
net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l dbgListener) Accept() (net.Conn, error) {
|
||||||
|
conn, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dbgConn{conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dbgConn struct {
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c dbgConn) Read(p []byte) (int, error) {
|
||||||
|
n, err := c.Conn.Read(p)
|
||||||
|
os.Stderr.Write(p)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c dbgConn) Write(p []byte) (int, error) {
|
||||||
|
n, err := c.Conn.Write(p)
|
||||||
|
os.Stderr.Write(p)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func logf(format string, args ...interface{}) {
|
||||||
|
if testing.Verbose() {
|
||||||
|
log.Printf("[tunneltest] "+format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func nonil(err ...error) error {
|
||||||
|
for _, e := range err {
|
||||||
|
if e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHostPort(addr string) (string, int, error) {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.ParseUint(port, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, int(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsableAddrs returns all tcp addresses that we can bind a listener to.
|
||||||
|
func UsableAddrs() ([]*net.TCPAddr, error) {
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var usable []*net.TCPAddr
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if ipNet, ok := addr.(*net.IPNet); ok {
|
||||||
|
if !ipNet.IP.IsLinkLocalUnicast() {
|
||||||
|
usable = append(usable, &net.TCPAddr{IP: ipNet.IP})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(usable) == 0 {
|
||||||
|
return nil, errors.New("no usable addresses found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeHTTP = iota
|
||||||
|
TypeTCP
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tunnel represents a single HTTP or TCP tunnel that can be served
|
||||||
|
// by TunnelTest.
|
||||||
|
type Tunnel struct {
|
||||||
|
// Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP.
|
||||||
|
Type int
|
||||||
|
|
||||||
|
// Handler is a handler to use for serving tunneled connections on
|
||||||
|
// local server. The value of this field is required to be of type:
|
||||||
|
//
|
||||||
|
// - http.Handler or http.HandlerFunc for HTTP tunnels
|
||||||
|
// - func(net.Conn) for TCP tunnels
|
||||||
|
//
|
||||||
|
// Required field.
|
||||||
|
Handler interface{}
|
||||||
|
|
||||||
|
// LocalAddr is a network address of local server that handles
|
||||||
|
// connections/requests with Handler.
|
||||||
|
//
|
||||||
|
// Optional field, takes value of "127.0.0.1:0" when empty.
|
||||||
|
LocalAddr string
|
||||||
|
|
||||||
|
// ClientIdent is an identifier of a client that have already
|
||||||
|
// registered a HTTP tunnel and have established control connection.
|
||||||
|
//
|
||||||
|
// If the Type is TypeTCP, instead of creating new client
|
||||||
|
// for this TCP tunnel, we add it to an existing client
|
||||||
|
// specified by the field.
|
||||||
|
//
|
||||||
|
// Optional field for TCP tunnels.
|
||||||
|
// Ignored field for HTTP tunnels.
|
||||||
|
ClientIdent string
|
||||||
|
|
||||||
|
// RemoteAddr is a network address of remote server, which accepts
|
||||||
|
// connections on a tunnel server side.
|
||||||
|
//
|
||||||
|
// Required field for TCP tunnels.
|
||||||
|
// Ignored field for HTTP tunnels.
|
||||||
|
RemoteAddr string
|
||||||
|
|
||||||
|
// RemoteAddrIdent an identifier of an already existing listener,
|
||||||
|
// that listens on multiple interfaces; if the RemoteAddrIdent is valid
|
||||||
|
// identifier the IP field is required to be non-nil and RemoteAddr
|
||||||
|
// is ignored.
|
||||||
|
//
|
||||||
|
// Optional field for TCP tunnels.
|
||||||
|
// Ignored field for HTTP tunnels.
|
||||||
|
RemoteAddrIdent string
|
||||||
|
|
||||||
|
// IP specifies an IP address value for IP-based routing for TCP tunnels.
|
||||||
|
// For more details see inline documentation for (*tunnel.Server).AddAddr.
|
||||||
|
//
|
||||||
|
// Optional field for TCP tunnels.
|
||||||
|
// Ignored field for HTTP tunnels.
|
||||||
|
IP net.IP
|
||||||
|
|
||||||
|
// StateChanges listens on state transitions.
|
||||||
|
//
|
||||||
|
// If ClientIdent field is empty, the StateChanges will receive
|
||||||
|
// state transition events for the newly created client.
|
||||||
|
// Otherwise setting this field is a nop.
|
||||||
|
StateChanges chan<- *tunnel.ClientStateChange
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelTest struct {
|
||||||
|
Server *tunnel.Server
|
||||||
|
ServerStateRecorder *StateRecorder
|
||||||
|
Clients map[string]*tunnel.Client
|
||||||
|
Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels)
|
||||||
|
Addrs []*net.TCPAddr
|
||||||
|
Tunnels map[string]*Tunnel
|
||||||
|
DebugNet bool // for debugging network communication
|
||||||
|
|
||||||
|
mu sync.Mutex // protects Listeners
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTunnelTest() (*TunnelTest, error) {
|
||||||
|
rec := NewStateRecorder()
|
||||||
|
|
||||||
|
cfg := &tunnel.ServerConfig{
|
||||||
|
StateChanges: rec.C(),
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
}
|
||||||
|
s, err := tunnel.NewServer(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugNet {
|
||||||
|
l = dbgListener{l}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := UsableAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go (&http.Server{Handler: s}).Serve(l)
|
||||||
|
|
||||||
|
return &TunnelTest{
|
||||||
|
Server: s,
|
||||||
|
ServerStateRecorder: rec,
|
||||||
|
Clients: make(map[string]*tunnel.Client),
|
||||||
|
Listeners: map[string][2]net.Listener{"": {l, nil}},
|
||||||
|
Addrs: addrs,
|
||||||
|
Tunnels: make(map[string]*Tunnel),
|
||||||
|
DebugNet: debugNet,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve creates new TunnelTest that serves the given tunnels.
|
||||||
|
//
|
||||||
|
// If tunnels is nil, DefaultTunnels() are used instead.
|
||||||
|
func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) {
|
||||||
|
tt, err := NewTunnelTest()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tt.Serve(tunnels); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) {
|
||||||
|
// Verify tunnel dependencies for TCP tunnels.
|
||||||
|
if t.Type == TypeTCP {
|
||||||
|
// If tunnel specified by t.Client was not already started,
|
||||||
|
// skip and move on.
|
||||||
|
if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TCP tunnel whose remote endpoint listens on multiple
|
||||||
|
// interfaces is already served.
|
||||||
|
if t.RemoteAddrIdent != "" {
|
||||||
|
if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP {
|
||||||
|
return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", t.LocalAddr)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.DebugNet {
|
||||||
|
l = dbgListener{l}
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := l.Addr().String()
|
||||||
|
httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr}
|
||||||
|
tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr}
|
||||||
|
|
||||||
|
cfg := &tunnel.ClientConfig{
|
||||||
|
Identifier: ident,
|
||||||
|
ServerAddr: tt.ServerAddr().String(),
|
||||||
|
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
|
||||||
|
HTTP: httpProxy.Proxy,
|
||||||
|
TCP: tcpProxy.Proxy,
|
||||||
|
}),
|
||||||
|
StateChanges: t.StateChanges,
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register tunnel:
|
||||||
|
//
|
||||||
|
// - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient])
|
||||||
|
// - listen on local address and start local server (tt.Listeners[ident][0])
|
||||||
|
// - register tunnel on tunnel.Server
|
||||||
|
//
|
||||||
|
switch t.Type {
|
||||||
|
case TypeHTTP:
|
||||||
|
// TODO(rjeczalik): refactor to separate method
|
||||||
|
|
||||||
|
h, ok := t.Handler.(http.Handler)
|
||||||
|
if !ok {
|
||||||
|
h, ok = t.Handler.(http.HandlerFunc)
|
||||||
|
if !ok {
|
||||||
|
fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request))
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
h = http.HandlerFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident)
|
||||||
|
|
||||||
|
go (&http.Server{Handler: h}).Serve(l)
|
||||||
|
|
||||||
|
tt.Server.AddHost(localAddr, ident)
|
||||||
|
|
||||||
|
tt.mu.Lock()
|
||||||
|
tt.Listeners[ident] = [2]net.Listener{l, nil}
|
||||||
|
tt.mu.Unlock()
|
||||||
|
|
||||||
|
if err := tt.addClient(ident, cfg); err != nil {
|
||||||
|
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident)
|
||||||
|
|
||||||
|
case TypeTCP:
|
||||||
|
// TODO(rjeczalik): refactor to separate method
|
||||||
|
|
||||||
|
h, ok := t.Handler.(func(net.Conn))
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
logf("serving on local %s for TCP tunnel %q", l.Addr(), ident)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed accepting conn for %q tunnel: %s", ident, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var remote net.Listener
|
||||||
|
|
||||||
|
if t.RemoteAddrIdent != "" {
|
||||||
|
tt.mu.Lock()
|
||||||
|
remote = tt.Listeners[t.RemoteAddrIdent][1]
|
||||||
|
tt.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
remote, err = net.Listen("tcp", t.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrIdent holds identifier of client which is going to have registered
|
||||||
|
// tunnel via (*tunnel.Server).AddAddr
|
||||||
|
addrIdent := ident
|
||||||
|
if t.ClientIdent != "" {
|
||||||
|
tt.Clients[ident] = tt.Clients[t.ClientIdent]
|
||||||
|
addrIdent = t.ClientIdent
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.Server.AddAddr(remote, t.IP, addrIdent)
|
||||||
|
|
||||||
|
tt.mu.Lock()
|
||||||
|
tt.Listeners[ident] = [2]net.Listener{l, remote}
|
||||||
|
tt.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := tt.Clients[ident]; !ok {
|
||||||
|
if err := tt.addClient(ident, cfg); err != nil {
|
||||||
|
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error {
|
||||||
|
if _, ok := tt.Clients[ident]; ok {
|
||||||
|
return fmt.Errorf("tunnel %q is already being served", ident)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := tunnel.NewClient(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
tt.Server.OnConnect(ident, func() error {
|
||||||
|
close(done)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
go c.Start()
|
||||||
|
<-c.StartNotify()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
return errors.New("timed out after 10s waiting on client to establish control conn")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.Clients[ident] = c
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error {
|
||||||
|
if len(tunnels) == 0 {
|
||||||
|
return errors.New("no tunnels to serve")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since one tunnels depends on others do 3 passes to start them
|
||||||
|
// all, each started tunnel is removed from the tunnels map.
|
||||||
|
// After 3 passes all of them must be started, otherwise the
|
||||||
|
// configuration is bad:
|
||||||
|
//
|
||||||
|
// - first pass starts HTTP tunnels as new client tunnels
|
||||||
|
// - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent)
|
||||||
|
// - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent)
|
||||||
|
//
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if err := tt.popServedDeps(tunnels); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tunnels) != 0 {
|
||||||
|
unresolved := make([]string, len(tunnels))
|
||||||
|
for ident := range tunnels {
|
||||||
|
unresolved = append(unresolved, ident)
|
||||||
|
}
|
||||||
|
sort.Strings(unresolved)
|
||||||
|
|
||||||
|
return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error {
|
||||||
|
for ident, t := range tunnels {
|
||||||
|
ok, err := tt.serveSingle(ident, t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
// Remove already started tunnels so they won't get started again.
|
||||||
|
delete(tunnels, ident)
|
||||||
|
tt.Tunnels[ident] = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) {
|
||||||
|
tt.mu.Lock()
|
||||||
|
defer tt.mu.Unlock()
|
||||||
|
|
||||||
|
for _, l := range tt.Listeners {
|
||||||
|
if l[1] == nil {
|
||||||
|
// this listener does not belong to a TCP tunnel
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, remotePort, err := parseHostPort(l[1].Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == remotePort {
|
||||||
|
return l[0].Addr().String(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no route for %d port", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) ServerAddr() net.Addr {
|
||||||
|
return tt.Listeners[""][0].Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr gives server endpoint of the TCP tunnel for the given ident.
|
||||||
|
//
|
||||||
|
// If the tunnel does not exist or is a HTTP one, TunnelAddr return nil.
|
||||||
|
func (tt *TunnelTest) Addr(ident string) net.Addr {
|
||||||
|
l, ok := tt.Listeners[ident]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return l[1].Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request creates a HTTP request to a server endpoint of the HTTP tunnel
|
||||||
|
// for the given ident.
|
||||||
|
//
|
||||||
|
// If the tunnel does not exist, Request returns nil.
|
||||||
|
func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request {
|
||||||
|
l, ok := tt.Listeners[ident]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw string
|
||||||
|
if query != nil {
|
||||||
|
raw = query.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Request{
|
||||||
|
Method: "GET",
|
||||||
|
URL: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: tt.ServerAddr().String(),
|
||||||
|
Path: "/",
|
||||||
|
RawQuery: raw,
|
||||||
|
},
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Host: l[0].Addr().String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tt *TunnelTest) Close() (err error) {
|
||||||
|
// Close tunnel.Clients.
|
||||||
|
clients := make(map[*tunnel.Client]struct{})
|
||||||
|
for _, c := range tt.Clients {
|
||||||
|
clients[c] = struct{}{}
|
||||||
|
}
|
||||||
|
for c := range clients {
|
||||||
|
err = nonil(err, c.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop all TCP/HTTP servers.
|
||||||
|
listeners := make(map[net.Listener]struct{})
|
||||||
|
for _, l := range tt.Listeners {
|
||||||
|
for _, l := range l {
|
||||||
|
if l != nil {
|
||||||
|
listeners[l] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for l := range listeners {
|
||||||
|
err = nonil(err, l.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff"
|
||||||
|
)
|
||||||
|
|
||||||
|
// async is a helper function to convert a blocking function to a function
|
||||||
|
// returning an error. Useful for plugging function closures into select and co
|
||||||
|
func async(fn func() error) <-chan error {
|
||||||
|
errChan := make(chan error, 0)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case errChan <- fn():
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
close(errChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
type expBackoff struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
bk *backoff.ExponentialBackOff
|
||||||
|
}
|
||||||
|
|
||||||
|
func newForeverBackoff() *expBackoff {
|
||||||
|
eb := &expBackoff{
|
||||||
|
bk: backoff.NewExponentialBackOff(),
|
||||||
|
}
|
||||||
|
eb.bk.MaxElapsedTime = 0 // never stops
|
||||||
|
return eb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eb *expBackoff) NextBackOff() time.Duration {
|
||||||
|
eb.mu.Lock()
|
||||||
|
defer eb.mu.Unlock()
|
||||||
|
|
||||||
|
return eb.bk.NextBackOff()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eb *expBackoff) Reset() {
|
||||||
|
eb.mu.Lock()
|
||||||
|
eb.bk.Reset()
|
||||||
|
eb.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbacks struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
name string
|
||||||
|
funcs map[string]func() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCallbacks(name string) *callbacks {
|
||||||
|
return &callbacks{
|
||||||
|
name: name,
|
||||||
|
funcs: make(map[string]func() error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callbacks) add(identifier string, fn func() error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.funcs[identifier] = fn
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callbacks) pop(identifier string) (func() error, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
fn, ok := c.funcs[identifier]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil // nop
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(c.funcs, identifier)
|
||||||
|
|
||||||
|
if fn == nil {
|
||||||
|
return nil, fmt.Errorf("nil callback set for %q client", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callbacks) call(identifier string) error {
|
||||||
|
fn, err := c.pop(identifier)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if fn == nil {
|
||||||
|
return nil // nop
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns server control url as a string. Reads scheme and remote address from connection.
|
||||||
|
func controlURL(conn net.Conn) string {
|
||||||
|
return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func scheme(conn net.Conn) (scheme string) {
|
||||||
|
switch conn.(type) {
|
||||||
|
case *tls.Conn:
|
||||||
|
scheme = "https"
|
||||||
|
default:
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,179 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/koding/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type listener struct {
|
||||||
|
net.Listener
|
||||||
|
*vaddrOptions
|
||||||
|
|
||||||
|
done int32
|
||||||
|
|
||||||
|
// ips keeps track of registered clients for ip-based routing;
|
||||||
|
// when last client is deleted from the ip routing map, we stop
|
||||||
|
// listening on connections
|
||||||
|
ips map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type vaddrOptions struct {
|
||||||
|
connCh chan<- net.Conn
|
||||||
|
log logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type vaddrStorage struct {
|
||||||
|
*vaddrOptions
|
||||||
|
|
||||||
|
listeners map[net.Listener]*listener
|
||||||
|
ports map[int]string // port-based routing: maps port number to identifier
|
||||||
|
ips map[string]string // ip-based routing: maps ip address to identifier
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage {
|
||||||
|
return &vaddrStorage{
|
||||||
|
vaddrOptions: opts,
|
||||||
|
listeners: make(map[net.Listener]*listener),
|
||||||
|
ports: make(map[int]string),
|
||||||
|
ips: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *listener) serve() {
|
||||||
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
l.log.Error("failue listening on %q: %s", l.Addr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&l.done) != 0 {
|
||||||
|
l.log.Debug("stopped serving %q", l.Addr())
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.connCh <- conn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *listener) localAddr() string {
|
||||||
|
if addr, ok := l.Addr().(*net.TCPAddr); ok {
|
||||||
|
if addr.IP.Equal(net.IPv4zero) {
|
||||||
|
return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *listener) stop() {
|
||||||
|
if atomic.CompareAndSwapInt32(&l.done, 0, 1) {
|
||||||
|
// stop is called when no more connections should be accepted by
|
||||||
|
// the user-provided listener; as we can't simple close the listener
|
||||||
|
// to not break the guarantee given by the (*Server).DeleteAddr
|
||||||
|
// method, we make a dummy connection to break out of serve loop.
|
||||||
|
// It is safe to make a dummy connection, as either the following
|
||||||
|
// dial will time out when the listener is busy accepting connections,
|
||||||
|
// or will get closed immadiately after idle listeners accepts connection
|
||||||
|
// and returns from the serve loop.
|
||||||
|
conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout)
|
||||||
|
if err == nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) {
|
||||||
|
vaddr.mu.Lock()
|
||||||
|
defer vaddr.mu.Unlock()
|
||||||
|
|
||||||
|
lis, ok := vaddr.listeners[l]
|
||||||
|
if !ok {
|
||||||
|
lis = vaddr.newListener(l)
|
||||||
|
vaddr.listeners[l] = lis
|
||||||
|
go lis.serve()
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip != nil {
|
||||||
|
lis.ips[ip.String()] = struct{}{}
|
||||||
|
vaddr.ips[ip.String()] = ident
|
||||||
|
} else {
|
||||||
|
vaddr.ports[mustPort(l)] = ident
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) {
|
||||||
|
vaddr.mu.Lock()
|
||||||
|
defer vaddr.mu.Unlock()
|
||||||
|
|
||||||
|
lis, ok := vaddr.listeners[l]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var stop bool
|
||||||
|
|
||||||
|
if ip != nil {
|
||||||
|
delete(lis.ips, ip.String())
|
||||||
|
delete(vaddr.ips, ip.String())
|
||||||
|
|
||||||
|
stop = len(lis.ips) == 0
|
||||||
|
} else {
|
||||||
|
delete(vaddr.ports, mustPort(l))
|
||||||
|
|
||||||
|
stop = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only stop listening for connections when listener has clients
|
||||||
|
// registered to tunnel the connections to.
|
||||||
|
if stop {
|
||||||
|
lis.stop()
|
||||||
|
delete(vaddr.listeners, l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vaddr *vaddrStorage) newListener(l net.Listener) *listener {
|
||||||
|
return &listener{
|
||||||
|
Listener: l,
|
||||||
|
vaddrOptions: vaddr.vaddrOptions,
|
||||||
|
ips: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) {
|
||||||
|
vaddr.mu.Lock()
|
||||||
|
defer vaddr.mu.Unlock()
|
||||||
|
|
||||||
|
ip, port, err := parseHostPort(conn.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// First lookup if there's a ip-based route, then try port-base one.
|
||||||
|
|
||||||
|
if ident, ok := vaddr.ips[ip]; ok {
|
||||||
|
return ident, true
|
||||||
|
}
|
||||||
|
|
||||||
|
ident, ok := vaddr.ports[port]
|
||||||
|
return ident, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustPort(l net.Listener) int {
|
||||||
|
_, port, err := parseHostPort(l.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
// This can happened when user passed custom type that
|
||||||
|
// implements net.Listener, which returns ill-formed
|
||||||
|
// net.Addr value.
|
||||||
|
panic("ill-formed net.Addr: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return port
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type vhostStorage interface {
|
||||||
|
// AddHost adds the given host and identifier to the storage
|
||||||
|
AddHost(host, identifier string)
|
||||||
|
|
||||||
|
// DeleteHost deletes the given host
|
||||||
|
DeleteHost(host string)
|
||||||
|
|
||||||
|
// GetHost returns the host name for the given identifier
|
||||||
|
GetHost(identifier string) (string, bool)
|
||||||
|
|
||||||
|
// GetIdentifier returns the identifier for the given host
|
||||||
|
GetIdentifier(host string) (string, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type virtualHost struct {
|
||||||
|
identifier string
|
||||||
|
}
|
||||||
|
|
||||||
|
// virtualHosts is used for mapping host to users example: host
|
||||||
|
// "fs-1-fatih.kd.io" belongs to user "arslan"
|
||||||
|
type virtualHosts struct {
|
||||||
|
mapping map[string]*virtualHost
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// newVirtualHosts provides an in memory virtual host storage for mapping
|
||||||
|
// virtual hosts to identifiers.
|
||||||
|
func newVirtualHosts() *virtualHosts {
|
||||||
|
return &virtualHosts{
|
||||||
|
mapping: make(map[string]*virtualHost),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtualHosts) AddHost(host, identifier string) {
|
||||||
|
v.Lock()
|
||||||
|
v.mapping[host] = &virtualHost{identifier: identifier}
|
||||||
|
v.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtualHosts) DeleteHost(host string) {
|
||||||
|
v.Lock()
|
||||||
|
delete(v.mapping, host)
|
||||||
|
v.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentifier returns the identifier associated with the given host
|
||||||
|
func (v *virtualHosts) GetIdentifier(host string) (string, bool) {
|
||||||
|
v.Lock()
|
||||||
|
ht, ok := v.mapping[host]
|
||||||
|
v.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ht.identifier, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHost returns the host associated with the given identifier
|
||||||
|
func (v *virtualHosts) GetHost(identifier string) (string, bool) {
|
||||||
|
v.Lock()
|
||||||
|
defer v.Unlock()
|
||||||
|
|
||||||
|
for hostname, hst := range v.mapping {
|
||||||
|
if hst.identifier == identifier {
|
||||||
|
return hostname, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
package tunnel_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) {
|
||||||
|
conn, err := websocketDial(tt, "http")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial()=%s", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
want := &EchoMessage{
|
||||||
|
Value: fmt.Sprintf("message #%d", i),
|
||||||
|
Close: i == (n - 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conn.WriteJSON(want)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
got := &EchoMessage{}
|
||||||
|
|
||||||
|
err = conn.ReadJSON(got)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("(test %s) %d: failed reading: %s", name, i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := fn(w, r); err != nil {
|
||||||
|
t.Errorf("handler func error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocket(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testWebsocket("handlerEchoWS", 100, t, tt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLatencyWebsocket(t *testing.T) {
|
||||||
|
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testWebsocket("handlerLatencyEchoWS", 20, t, tt)
|
||||||
|
}
|
Loading…
Reference in New Issue