remove tunnel
This commit is contained in:
parent
eac3883186
commit
4ea2bd35be
|
@ -1,19 +0,0 @@
|
||||||
language: go
|
|
||||||
|
|
||||||
sudo: false
|
|
||||||
|
|
||||||
addons:
|
|
||||||
apt:
|
|
||||||
packages:
|
|
||||||
- moreutils
|
|
||||||
|
|
||||||
go:
|
|
||||||
- 1.4.3
|
|
||||||
- 1.6.3
|
|
||||||
- 1.7
|
|
||||||
|
|
||||||
script:
|
|
||||||
- export GOMAXPROCS=$(nproc)
|
|
||||||
- gofmt -s -l . | ifne false
|
|
||||||
- go build ./...
|
|
||||||
- go test -race ./...
|
|
|
@ -1,28 +0,0 @@
|
||||||
Copyright (c) 2015 The Koding Authors.
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
* Redistributions of source code must retain the above copyright notice, this
|
|
||||||
list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
* Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
this list of conditions and the following disclaimer in the documentation
|
|
||||||
and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
* Neither the name of Koding Inc. nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
|
@ -1,91 +0,0 @@
|
||||||
# Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel)
|
|
||||||
|
|
||||||
Tunnel is a server/client package that enables to proxy public connections to
|
|
||||||
your local machine over a tunnel connection from the local machine to the
|
|
||||||
public server. What this means is, you can share your localhost even if it
|
|
||||||
doesn't have a Public IP or if it's not reachable from outside.
|
|
||||||
|
|
||||||
It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to
|
|
||||||
multiplex connections between server and client.
|
|
||||||
|
|
||||||
The project is under active development, please vendor it if you want to use it.
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
|
|
||||||
The tunnel package consists of two parts. The `server` and the `client`.
|
|
||||||
|
|
||||||
Server is the public facing part. It's type that satisfies the `http.Handler`.
|
|
||||||
So it's easily pluggable into existing servers.
|
|
||||||
|
|
||||||
|
|
||||||
Let assume that you setup your DNS service so all `*.example.com` domains route
|
|
||||||
to your server at the public IP `203.0.113.0`. Let us first create the server
|
|
||||||
part:
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/koding/tunnel"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
cfg := &tunnel.ServerConfig{}
|
|
||||||
server, _ := tunnel.NewServer(cfg)
|
|
||||||
server.AddHost("sub.example.com", "1234")
|
|
||||||
http.ListenAndServe(":80", server)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Once you create the `server`, you just plug it into your server. The only
|
|
||||||
detail here is to map a virtualhost to a secret token. The secret token is the
|
|
||||||
only part that needs to be known for the client side.
|
|
||||||
|
|
||||||
Let us now create the client side part:
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import "github.com/koding/tunnel"
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
cfg := &tunnel.ClientConfig{
|
|
||||||
Identifier: "1234",
|
|
||||||
ServerAddr: "203.0.113.0:80",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := tunnel.NewClient(cfg)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client.Start()
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
The `Start()` method is by default blocking. As you see you, we just passed the
|
|
||||||
server address and the secret token.
|
|
||||||
|
|
||||||
Now whenever someone hit `sub.example.com`, the request will be proxied to the
|
|
||||||
machine where client is running and hit the local server running `127.0.0.1:80`
|
|
||||||
(assuming there is one). If someone hits `sub.example.com:3000` (assume your
|
|
||||||
server is running at this port), it'll be routed to `127.0.0.1:3000`
|
|
||||||
|
|
||||||
That's it.
|
|
||||||
|
|
||||||
There are many options that can be changed, such as a static local address for
|
|
||||||
your client. Have alook at the
|
|
||||||
[documentation](http://godoc.org/github.com/koding/tunnel)
|
|
||||||
|
|
||||||
|
|
||||||
# Protocol
|
|
||||||
|
|
||||||
The server/client protocol is written in the [spec.md](spec.md) file. Please
|
|
||||||
have a look for more detail.
|
|
||||||
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
The BSD 3-Clause License - see LICENSE for more details
|
|
|
@ -1,564 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
"github.com/hashicorp/yamux"
|
|
||||||
"github.com/koding/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:generate stringer -type ClientState
|
|
||||||
|
|
||||||
// ErrRedialAborted is emitted on ClientClosed event, when backoff policy
|
|
||||||
// used by a client decided no more reconnection attempts must be made.
|
|
||||||
var ErrRedialAborted = errors.New("unable to restore the connection, aborting")
|
|
||||||
|
|
||||||
// ClientState represents client connection state to tunnel server.
|
|
||||||
type ClientState uint32
|
|
||||||
|
|
||||||
// ClientState enumeration.
|
|
||||||
const (
|
|
||||||
ClientUnknown ClientState = iota
|
|
||||||
ClientStarted
|
|
||||||
ClientConnecting
|
|
||||||
ClientConnected
|
|
||||||
ClientDisconnected
|
|
||||||
ClientClosed // keep it always last
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClientStateChange represents single client state transition.
|
|
||||||
type ClientStateChange struct {
|
|
||||||
Identifier string
|
|
||||||
Previous ClientState
|
|
||||||
Current ClientState
|
|
||||||
Error error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strings implements the fmt.Stringer interface.
|
|
||||||
func (cs *ClientStateChange) String() string {
|
|
||||||
if cs.Error != nil {
|
|
||||||
return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backoff defines behavior of staggering reconnection retries.
|
|
||||||
type Backoff interface {
|
|
||||||
// Next returns the duration to sleep before retrying reconnections.
|
|
||||||
// If the returned value is negative, the retry is aborted.
|
|
||||||
NextBackOff() time.Duration
|
|
||||||
|
|
||||||
// Reset is used to signal a reconnection was successful and next
|
|
||||||
// call to Next should return desired time duration for 1st reconnection
|
|
||||||
// attempt.
|
|
||||||
Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client is responsible for creating a control connection to a tunnel server,
|
|
||||||
// creating new tunnels and proxy them to tunnel server.
|
|
||||||
type Client struct {
|
|
||||||
// underlying yamux session
|
|
||||||
session *yamux.Session
|
|
||||||
|
|
||||||
// config holds the ClientConfig
|
|
||||||
config *ClientConfig
|
|
||||||
|
|
||||||
// yamuxConfig is passed to new yamux.Session's
|
|
||||||
yamuxConfig *yamux.Config
|
|
||||||
|
|
||||||
// proxy handles local server communication.
|
|
||||||
proxy ProxyFunc
|
|
||||||
|
|
||||||
// startNotify is a chanel user can get to be notified when client is
|
|
||||||
// connected to the server. The preferred way of doing this however,
|
|
||||||
// would be using StateChanges in ClientConfig where user can provide
|
|
||||||
// his own channel.
|
|
||||||
startNotify chan bool
|
|
||||||
// closed is a flag set when client calls Close() and quits.
|
|
||||||
closed bool
|
|
||||||
// closedMu guards both closed flag and startNotify channel. Since library
|
|
||||||
// owns the channel it's cleared when trying to reconnect.
|
|
||||||
closedMu sync.RWMutex
|
|
||||||
|
|
||||||
reqWg sync.WaitGroup
|
|
||||||
ctrlWg sync.WaitGroup
|
|
||||||
|
|
||||||
state ClientState
|
|
||||||
|
|
||||||
// redialBackoff is used to reconnect in exponential backoff intervals
|
|
||||||
redialBackoff Backoff
|
|
||||||
|
|
||||||
log logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientConfig defines the configuration for the Client
|
|
||||||
type ClientConfig struct {
|
|
||||||
// Identifier is the secret token that needs to be passed to the server.
|
|
||||||
// Required if FetchIdentifier is not set.
|
|
||||||
Identifier string
|
|
||||||
|
|
||||||
// FetchIdentifier can be used to fetch identifier. Required if Identifier
|
|
||||||
// is not set.
|
|
||||||
FetchIdentifier func() (string, error)
|
|
||||||
|
|
||||||
// ServerAddr defines the TCP address of the tunnel server to be connected.
|
|
||||||
// Required if FetchServerAddr is not set.
|
|
||||||
ServerAddr string
|
|
||||||
|
|
||||||
// FetchServerAddr can be used to fetch tunnel server address.
|
|
||||||
// Required if ServerAddress is not set.
|
|
||||||
FetchServerAddr func() (string, error)
|
|
||||||
|
|
||||||
// Dial provides custom transport layer for client server communication.
|
|
||||||
//
|
|
||||||
// If nil, default implementation is to return net.Dial("tcp", address).
|
|
||||||
//
|
|
||||||
// It can be used for connection monitoring, setting different timeouts or
|
|
||||||
// securing the connection.
|
|
||||||
Dial func(network, address string) (net.Conn, error)
|
|
||||||
|
|
||||||
// Proxy defines custom proxing logic. This is optional extension point
|
|
||||||
// where you can provide your local server selection or communication rules.
|
|
||||||
Proxy ProxyFunc
|
|
||||||
|
|
||||||
// StateChanges receives state transition details each time client
|
|
||||||
// connection state changes. The channel is expected to be sufficiently
|
|
||||||
// buffered to keep up with event pace.
|
|
||||||
//
|
|
||||||
// If nil, no information about state transitions are dispatched
|
|
||||||
// by the library.
|
|
||||||
StateChanges chan<- *ClientStateChange
|
|
||||||
|
|
||||||
// Backoff is used to control behavior of staggering reconnection loop.
|
|
||||||
//
|
|
||||||
// If nil, default backoff policy is used which makes a client to never
|
|
||||||
// give up on reconnection.
|
|
||||||
//
|
|
||||||
// If custom backoff is used, client will emit ErrRedialAborted set
|
|
||||||
// with ClientClosed event when no more reconnection atttemps should
|
|
||||||
// be made.
|
|
||||||
Backoff Backoff
|
|
||||||
|
|
||||||
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
|
||||||
// yamux.DefaultConfig() is used.
|
|
||||||
YamuxConfig *yamux.Config
|
|
||||||
|
|
||||||
// Log defines the logger. If nil a default logging.Logger is used.
|
|
||||||
Log logging.Logger
|
|
||||||
|
|
||||||
// Debug enables debug mode, enable only if you want to debug the server.
|
|
||||||
Debug bool
|
|
||||||
|
|
||||||
// DEPRECATED:
|
|
||||||
|
|
||||||
// LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details.
|
|
||||||
LocalAddr string
|
|
||||||
|
|
||||||
// FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details.
|
|
||||||
FetchLocalAddr func(port int) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify is used to verify the ClientConfig
|
|
||||||
func (c *ClientConfig) verify() error {
|
|
||||||
if c.ServerAddr == "" && c.FetchServerAddr == nil {
|
|
||||||
return errors.New("neither ServerAddr nor FetchServerAddr is set")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Identifier == "" && c.FetchIdentifier == nil {
|
|
||||||
return errors.New("neither Identifier nor FetchIdentifier is set")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.YamuxConfig != nil {
|
|
||||||
if err := yamux.VerifyConfig(c.YamuxConfig); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) {
|
|
||||||
return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new tunnel that is established between the serverAddr
|
|
||||||
// and localAddr. It exits if it can't create a new control connection to the
|
|
||||||
// server. If localAddr is empty client will always try to proxy to a local
|
|
||||||
// port.
|
|
||||||
func NewClient(cfg *ClientConfig) (*Client, error) {
|
|
||||||
if err := cfg.verify(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
yamuxConfig := yamux.DefaultConfig()
|
|
||||||
if cfg.YamuxConfig != nil {
|
|
||||||
yamuxConfig = cfg.YamuxConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
var proxy = DefaultProxy
|
|
||||||
if cfg.Proxy != nil {
|
|
||||||
proxy = cfg.Proxy
|
|
||||||
}
|
|
||||||
// DEPRECATED API SUPPORT
|
|
||||||
if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil {
|
|
||||||
var f ProxyFuncs
|
|
||||||
if cfg.LocalAddr != "" {
|
|
||||||
f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
|
||||||
f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
|
||||||
}
|
|
||||||
if cfg.FetchLocalAddr != nil {
|
|
||||||
f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy
|
|
||||||
}
|
|
||||||
proxy = Proxy(f)
|
|
||||||
}
|
|
||||||
|
|
||||||
var bo Backoff = newForeverBackoff()
|
|
||||||
if cfg.Backoff != nil {
|
|
||||||
bo = cfg.Backoff
|
|
||||||
}
|
|
||||||
|
|
||||||
log := newLogger("tunnel-client", cfg.Debug)
|
|
||||||
if cfg.Log != nil {
|
|
||||||
log = cfg.Log
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &Client{
|
|
||||||
config: cfg,
|
|
||||||
yamuxConfig: yamuxConfig,
|
|
||||||
proxy: proxy,
|
|
||||||
startNotify: make(chan bool, 1),
|
|
||||||
redialBackoff: bo,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the client and connects to the server with the identifier.
|
|
||||||
// client.FetchIdentifier() will be used if it's not nil. It's supports
|
|
||||||
// reconnecting with exponential backoff intervals when the connection to the
|
|
||||||
// server disconnects. Call client.Close() to shutdown the client completely. A
|
|
||||||
// successful connection will cause StartNotify() to receive a value.
|
|
||||||
func (c *Client) Start() {
|
|
||||||
fetchIdent := func() (string, error) {
|
|
||||||
if c.config.FetchIdentifier != nil {
|
|
||||||
return c.config.FetchIdentifier()
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.config.Identifier, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fetchServerAddr := func() (string, error) {
|
|
||||||
if c.config.FetchServerAddr != nil {
|
|
||||||
return c.config.FetchServerAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.config.ServerAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.changeState(ClientStarted, nil)
|
|
||||||
|
|
||||||
c.redialBackoff.Reset()
|
|
||||||
var lastErr error
|
|
||||||
for {
|
|
||||||
prev := c.changeState(ClientConnecting, lastErr)
|
|
||||||
|
|
||||||
if c.isRetry(prev) {
|
|
||||||
dur := c.redialBackoff.NextBackOff()
|
|
||||||
if dur < 0 {
|
|
||||||
c.setClosed(true)
|
|
||||||
c.changeState(ClientClosed, ErrRedialAborted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(dur)
|
|
||||||
|
|
||||||
// exit if closed
|
|
||||||
if c.isClosed() {
|
|
||||||
c.changeState(ClientClosed, lastErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
identifier, err := fetchIdent()
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
c.log.Critical("client fetch identifier error: %s", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr, err := fetchServerAddr()
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
c.log.Critical("client fetch server address error: %s", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.setClosed(false)
|
|
||||||
|
|
||||||
if err := c.connect(identifier, serverAddr); err != nil {
|
|
||||||
lastErr = err
|
|
||||||
c.log.Debug("client connect error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// exit if closed
|
|
||||||
if c.isClosed() {
|
|
||||||
c.changeState(ClientClosed, lastErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the client and shutdowns the connection to the tunnel server
|
|
||||||
func (c *Client) Close() error {
|
|
||||||
defer c.setClosed(true)
|
|
||||||
|
|
||||||
if c.session == nil {
|
|
||||||
return errors.New("session is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until all connections are finished
|
|
||||||
waitCh := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
if err := c.session.GoAway(); err != nil {
|
|
||||||
c.log.Debug("Session go away failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.reqWg.Wait()
|
|
||||||
close(waitCh)
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-waitCh:
|
|
||||||
// ok
|
|
||||||
case <-time.After(time.Second * 10):
|
|
||||||
c.log.Info("Timeout waiting for connections to finish")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.session.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isClosed securely checks if client is marked as closed.
|
|
||||||
func (c *Client) isClosed() bool {
|
|
||||||
c.closedMu.RLock()
|
|
||||||
defer c.closedMu.RUnlock()
|
|
||||||
return c.closed
|
|
||||||
}
|
|
||||||
|
|
||||||
// setClosed securely marks client as closed (or not closed). If not closed
|
|
||||||
// also empty the value inside the startNotify channel by retrieving it (if any),
|
|
||||||
// so it doesn't block during connect, when the client was closed and started again,
|
|
||||||
// and startNotify was never listened to.
|
|
||||||
func (c *Client) setClosed(closed bool) {
|
|
||||||
c.closedMu.Lock()
|
|
||||||
defer c.closedMu.Unlock()
|
|
||||||
c.closed = closed
|
|
||||||
|
|
||||||
if !closed {
|
|
||||||
// clear channel
|
|
||||||
select {
|
|
||||||
case <-c.startNotify:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// startNotifyIfNeeded sends ok to startNotify channel if it's listened to.
|
|
||||||
// This function is called by connect when connection was successful.
|
|
||||||
func (c *Client) startNotifyIfNeeded() {
|
|
||||||
c.closedMu.RLock()
|
|
||||||
if !c.closed {
|
|
||||||
c.log.Debug("sending ok to startNotify chan")
|
|
||||||
select {
|
|
||||||
case c.startNotify <- true:
|
|
||||||
default:
|
|
||||||
// reaching here means the client never read the signal via
|
|
||||||
// StartNotify(). This is OK, we shouldn't except it the consumer
|
|
||||||
// to read from this channel. It's optional, so we just drop the
|
|
||||||
// signal.
|
|
||||||
c.log.Debug("startNotify message was dropped")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.closedMu.RUnlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartNotify returns a channel that receives a single value when the client
|
|
||||||
// established a successful connection to the server.
|
|
||||||
func (c *Client) StartNotify() <-chan bool {
|
|
||||||
return c.startNotify
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) changeState(state ClientState, err error) (prev ClientState) {
|
|
||||||
prev = ClientState(atomic.LoadUint32((*uint32)(&c.state)))
|
|
||||||
|
|
||||||
if c.config.StateChanges != nil {
|
|
||||||
change := &ClientStateChange{
|
|
||||||
Identifier: c.config.Identifier,
|
|
||||||
Previous: ClientState(prev),
|
|
||||||
Current: state,
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case c.config.StateChanges <- change:
|
|
||||||
default:
|
|
||||||
c.log.Warning("Dropping state change due to slow reader: %s", change)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state))
|
|
||||||
|
|
||||||
return prev
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) isRetry(state ClientState) bool {
|
|
||||||
return state != ClientStarted && state != ClientClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) connect(identifier, serverAddr string) error {
|
|
||||||
c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier)
|
|
||||||
|
|
||||||
conn, err := c.dial(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteURL := controlURL(conn)
|
|
||||||
c.log.Debug("CONNECT to %q", remoteURL)
|
|
||||||
req, err := http.NewRequest("CONNECT", remoteURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set(proto.ClientIdentifierHeader, identifier)
|
|
||||||
|
|
||||||
c.log.Debug("Writing request to TCP: %+v", req)
|
|
||||||
|
|
||||||
if err := req.Write(conn); err != nil {
|
|
||||||
return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.log.Debug("Reading response from TCP")
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected {
|
|
||||||
out, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ctrlWg.Wait() // wait until previous listenControl observes disconnection
|
|
||||||
|
|
||||||
c.session, err = yamux.Client(conn, c.yamuxConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("session initialization failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var stream net.Conn
|
|
||||||
openStream := func() error {
|
|
||||||
// this is blocking until client opens a session to us
|
|
||||||
stream, err = c.session.Open()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we don't receive anything from the server, we'll timeout
|
|
||||||
select {
|
|
||||||
case err := <-async(openStream):
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("waiting for session to open failed: %s", err)
|
|
||||||
}
|
|
||||||
case <-time.After(time.Second * 10):
|
|
||||||
if stream != nil {
|
|
||||||
stream.Close()
|
|
||||||
}
|
|
||||||
return errors.New("timeout opening session")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil {
|
|
||||||
return fmt.Errorf("writing handshake request failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, len(proto.HandshakeResponse))
|
|
||||||
if _, err := stream.Read(buf); err != nil {
|
|
||||||
return fmt.Errorf("reading handshake response failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(buf) != proto.HandshakeResponse {
|
|
||||||
return fmt.Errorf("invalid handshake response, received: %s", string(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
ct := newControl(stream)
|
|
||||||
c.log.Debug("client has started successfully")
|
|
||||||
c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff
|
|
||||||
|
|
||||||
c.startNotifyIfNeeded()
|
|
||||||
|
|
||||||
return c.listenControl(ct)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) dial(serverAddr string) (net.Conn, error) {
|
|
||||||
if c.config.Dial != nil {
|
|
||||||
return c.config.Dial("tcp", serverAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.Dial("tcp", serverAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) listenControl(ct *control) error {
|
|
||||||
c.ctrlWg.Add(1)
|
|
||||||
defer c.ctrlWg.Done()
|
|
||||||
|
|
||||||
c.changeState(ClientConnected, nil)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var msg proto.ControlMessage
|
|
||||||
if err := ct.dec.Decode(&msg); err != nil {
|
|
||||||
c.reqWg.Wait() // wait until all requests are finished
|
|
||||||
c.session.GoAway()
|
|
||||||
c.session.Close()
|
|
||||||
c.changeState(ClientDisconnected, err)
|
|
||||||
|
|
||||||
return fmt.Errorf("failure decoding control message: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.log.Debug("Received control msg %+v", msg)
|
|
||||||
c.log.Debug("Opening a new stream from server session")
|
|
||||||
|
|
||||||
remote, err := c.session.Open()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
isHTTP := msg.Protocol == proto.HTTP
|
|
||||||
if isHTTP {
|
|
||||||
c.reqWg.Add(1)
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
c.proxy(remote, &msg)
|
|
||||||
if isHTTP {
|
|
||||||
c.reqWg.Done()
|
|
||||||
}
|
|
||||||
remote.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,16 +0,0 @@
|
||||||
// Code generated by "stringer -type ClientState"; DO NOT EDIT
|
|
||||||
|
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed"
|
|
||||||
|
|
||||||
var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87}
|
|
||||||
|
|
||||||
func (i ClientState) String() string {
|
|
||||||
if i >= ClientState(len(_ClientState_index)-1) {
|
|
||||||
return fmt.Sprintf("ClientState(%d)", i)
|
|
||||||
}
|
|
||||||
return _ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]]
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errControlClosed = errors.New("control connection is closed")
|
|
||||||
|
|
||||||
type control struct {
|
|
||||||
// enc and dec are responsible for encoding and decoding json values forth
|
|
||||||
// and back
|
|
||||||
enc *json.Encoder
|
|
||||||
dec *json.Decoder
|
|
||||||
|
|
||||||
// underlying connection responsible for encoder and decoder
|
|
||||||
nc net.Conn
|
|
||||||
|
|
||||||
// identifier associated with this control
|
|
||||||
identifier string
|
|
||||||
|
|
||||||
mu sync.Mutex // guards the following
|
|
||||||
closed bool // if Close() and quits
|
|
||||||
}
|
|
||||||
|
|
||||||
func newControl(nc net.Conn) *control {
|
|
||||||
c := &control{
|
|
||||||
enc: json.NewEncoder(nc),
|
|
||||||
dec: json.NewDecoder(nc),
|
|
||||||
nc: nc,
|
|
||||||
}
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *control) send(v interface{}) error {
|
|
||||||
if c.enc == nil {
|
|
||||||
return errors.New("encoder is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.closed {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return errControlClosed
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
return c.enc.Encode(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *control) recv(v interface{}) error {
|
|
||||||
if c.dec == nil {
|
|
||||||
return errors.New("decoder is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.closed {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return errControlClosed
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
return c.dec.Decode(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *control) Close() error {
|
|
||||||
if c.nc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
c.closed = true
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
return c.nc.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
type controls struct {
|
|
||||||
sync.Mutex
|
|
||||||
controls map[string]*control
|
|
||||||
}
|
|
||||||
|
|
||||||
func newControls() *controls {
|
|
||||||
return &controls{
|
|
||||||
controls: make(map[string]*control),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *controls) getControl(identifier string) (*control, bool) {
|
|
||||||
c.Lock()
|
|
||||||
control, ok := c.controls[identifier]
|
|
||||||
c.Unlock()
|
|
||||||
return control, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *controls) addControl(identifier string, control *control) {
|
|
||||||
control.identifier = identifier
|
|
||||||
|
|
||||||
c.Lock()
|
|
||||||
c.controls[identifier] = control
|
|
||||||
c.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *controls) deleteControl(identifier string) {
|
|
||||||
c.Lock()
|
|
||||||
delete(c.controls, identifier)
|
|
||||||
c.Unlock()
|
|
||||||
}
|
|
|
@ -1,262 +0,0 @@
|
||||||
package tunnel_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel"
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
rand.Seed(time.Now().UnixNano() + int64(os.Getpid()))
|
|
||||||
}
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
}
|
|
||||||
|
|
||||||
type EchoMessage struct {
|
|
||||||
Value string `json:"value,omitempty"`
|
|
||||||
Close bool `json:"close,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeout = 10 * time.Second
|
|
||||||
|
|
||||||
var dialer = &websocket.Dialer{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
HandshakeTimeout: timeout,
|
|
||||||
NetDial: func(_, addr string) (net.Conn, error) {
|
|
||||||
return net.DialTimeout("tcp4", addr, timeout)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
|
||||||
req := tt.Request("http", url.Values{"echo": []string{echo}})
|
|
||||||
if req == nil {
|
|
||||||
return "", fmt.Errorf(`tunnel "http" does not exist`)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Close = rand.Int()%2 == 0
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
p, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(bytes.TrimSpace(p)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) {
|
|
||||||
return echoTCPIdent(tt, echo, "tcp")
|
|
||||||
}
|
|
||||||
|
|
||||||
func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) {
|
|
||||||
addr := tt.Addr(ident)
|
|
||||||
if addr == nil {
|
|
||||||
return "", fmt.Errorf("tunnel %q does not exist", ident)
|
|
||||||
}
|
|
||||||
s := addr.String()
|
|
||||||
ip := tt.Tunnels[ident].IP
|
|
||||||
if ip != nil {
|
|
||||||
_, port, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
s = net.JoinHostPort(ip.String(), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := dialTCP(s)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.out <- echo
|
|
||||||
|
|
||||||
select {
|
|
||||||
case reply := <-c.in:
|
|
||||||
return reply, nil
|
|
||||||
case <-time.After(tcpTimeout):
|
|
||||||
return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) {
|
|
||||||
req := tt.Request(ident, nil)
|
|
||||||
if req == nil {
|
|
||||||
return nil, fmt.Errorf("no client found for ident %q", ident)
|
|
||||||
}
|
|
||||||
|
|
||||||
h := http.Header{"Host": {req.Host}}
|
|
||||||
wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr())
|
|
||||||
|
|
||||||
conn, _, err := dialer.Dial(wsurl, h)
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func sleep() {
|
|
||||||
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) (e error) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if e == nil {
|
|
||||||
e = err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if sleepFn != nil {
|
|
||||||
sleepFn()
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
var msg EchoMessage
|
|
||||||
err := conn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ReadJSON error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sleepFn != nil {
|
|
||||||
sleepFn()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.WriteJSON(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("WriteJSON error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.Close {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
io.WriteString(w, r.URL.Query().Get("echo"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
sleep()
|
|
||||||
handlerEchoHTTP(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handlerEchoTCP(conn net.Conn) {
|
|
||||||
io.Copy(conn, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handlerLatencyEchoTCP(conn net.Conn) {
|
|
||||||
sleep()
|
|
||||||
handlerEchoTCP(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
var tcpTimeout = 10 * time.Second
|
|
||||||
|
|
||||||
type tcpClient struct {
|
|
||||||
conn net.Conn
|
|
||||||
scanner *bufio.Scanner
|
|
||||||
in chan string
|
|
||||||
out chan string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tcpClient) loop() {
|
|
||||||
for out := range c.out {
|
|
||||||
if _, err := fmt.Fprintln(c.conn, out); err != nil {
|
|
||||||
log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.scanner.Scan() {
|
|
||||||
log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.in <- c.scanner.Text()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tcpClient) Close() error {
|
|
||||||
close(c.out)
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialTCP(addr string) (*tcpClient, error) {
|
|
||||||
conn, err := net.DialTimeout("tcp", addr, tcpTimeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &tcpClient{
|
|
||||||
conn: conn,
|
|
||||||
scanner: bufio.NewScanner(conn),
|
|
||||||
in: make(chan string, 1),
|
|
||||||
out: make(chan string, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
go c.loop()
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel {
|
|
||||||
return singleRecHTTP(handler, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
|
||||||
return map[string]*tunneltest.Tunnel{
|
|
||||||
"http": {
|
|
||||||
Type: tunneltest.TypeHTTP,
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
Handler: handler,
|
|
||||||
StateChanges: stateChanges,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel {
|
|
||||||
return singleRecTCP(handler, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel {
|
|
||||||
return map[string]*tunneltest.Tunnel{
|
|
||||||
"http": {
|
|
||||||
Type: tunneltest.TypeHTTP,
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
Handler: handlerEchoHTTP,
|
|
||||||
StateChanges: stateChanges,
|
|
||||||
},
|
|
||||||
"tcp": {
|
|
||||||
Type: tunneltest.TypeTCP,
|
|
||||||
ClientIdent: "http",
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
RemoteAddr: "127.0.0.1:0",
|
|
||||||
Handler: handler,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,115 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/koding/logging"
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
httpLog = logging.NewLogger("http")
|
|
||||||
)
|
|
||||||
|
|
||||||
// HTTPProxy forwards HTTP traffic.
|
|
||||||
//
|
|
||||||
// When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort
|
|
||||||
// where incomingPort is control message LocalPort.
|
|
||||||
// Usually this is tunnel server's public exposed Port.
|
|
||||||
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
|
||||||
// FetchLocalAddr takes precedence over LocalAddr.
|
|
||||||
//
|
|
||||||
// When connection to local server cannot be established proxy responds with http error message.
|
|
||||||
type HTTPProxy struct {
|
|
||||||
// LocalAddr defines the TCP address of the local server.
|
|
||||||
// This is optional if you want to specify a single TCP address.
|
|
||||||
LocalAddr string
|
|
||||||
// FetchLocalAddr is used for looking up TCP address of the server.
|
|
||||||
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
|
||||||
FetchLocalAddr func(port int) (string, error)
|
|
||||||
// ErrorResp is custom response send to tunnel server when client cannot
|
|
||||||
// establish connection to local server. If not set a default "no local server"
|
|
||||||
// response is sent.
|
|
||||||
ErrorResp *http.Response
|
|
||||||
// Log is a custom logger that can be used for the proxy.
|
|
||||||
// If not set a "http" logger is used.
|
|
||||||
Log logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy is a ProxyFunc.
|
|
||||||
func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
|
||||||
if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS {
|
|
||||||
panic("Proxy mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
var log = p.log()
|
|
||||||
|
|
||||||
var port = msg.LocalPort
|
|
||||||
if port == 0 {
|
|
||||||
port = 80
|
|
||||||
}
|
|
||||||
|
|
||||||
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
|
||||||
if p.LocalAddr != "" {
|
|
||||||
localAddr = p.LocalAddr
|
|
||||||
} else if p.FetchLocalAddr != nil {
|
|
||||||
l, err := p.FetchLocalAddr(msg.LocalPort)
|
|
||||||
if err != nil {
|
|
||||||
log.Warning("Failed to get custom local address: %s", err)
|
|
||||||
p.sendError(remote)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
localAddr = l
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Dialing local server %q", localAddr)
|
|
||||||
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
|
||||||
p.sendError(remote)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Join(local, remote, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPProxy) sendError(remote net.Conn) {
|
|
||||||
var w = noLocalServer()
|
|
||||||
if p.ErrorResp != nil {
|
|
||||||
w = p.ErrorResp
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
w.Write(buf)
|
|
||||||
if _, err := io.Copy(remote, buf); err != nil {
|
|
||||||
var log = p.log()
|
|
||||||
log.Debug("Copy in-mem response error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
remote.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func noLocalServer() *http.Response {
|
|
||||||
body := bytes.NewBufferString("no local server")
|
|
||||||
return &http.Response{
|
|
||||||
Status: http.StatusText(http.StatusServiceUnavailable),
|
|
||||||
StatusCode: http.StatusServiceUnavailable,
|
|
||||||
Proto: "HTTP/1.1",
|
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Body: ioutil.NopCloser(body),
|
|
||||||
ContentLength: int64(body.Len()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPProxy) log() logging.Logger {
|
|
||||||
if p.Log != nil {
|
|
||||||
return p.Log
|
|
||||||
}
|
|
||||||
return httpLog
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
package proto
|
|
||||||
|
|
||||||
// ControlMessage is sent from server to client to establish tunneled connection.
|
|
||||||
type ControlMessage struct {
|
|
||||||
Action Action `json:"action"`
|
|
||||||
Protocol Type `json:"transportProtocol"`
|
|
||||||
LocalPort int `json:"localPort"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Action represents type of ControlMsg request.
|
|
||||||
type Action int
|
|
||||||
|
|
||||||
// ControlMessage actions.
|
|
||||||
const (
|
|
||||||
RequestClientSession Action = iota + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// Type represents tunneled connection type.
|
|
||||||
type Type int
|
|
||||||
|
|
||||||
// ControlMessage protocols.
|
|
||||||
const (
|
|
||||||
HTTP Type = iota + 1
|
|
||||||
TCP
|
|
||||||
WS
|
|
||||||
)
|
|
|
@ -1,19 +0,0 @@
|
||||||
// Package proto defines tunnel client server communication protocol.
|
|
||||||
package proto
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ControlPath is http.Handler url path for control connection.
|
|
||||||
ControlPath = "/_controlPath/"
|
|
||||||
|
|
||||||
// ClientIdentifierHeader is header carrying information about tunnel identifier.
|
|
||||||
ClientIdentifierHeader = "X-KTunnel-Identifier"
|
|
||||||
|
|
||||||
// control messages
|
|
||||||
|
|
||||||
// Connected is message sent by server to client when control connection was established.
|
|
||||||
Connected = "200 Connected to Tunnel"
|
|
||||||
// HandshakeRequest is hello message sent by client to server.
|
|
||||||
HandshakeRequest = "controlHandshake"
|
|
||||||
// HandshakeResponse is response to HandshakeRequest sent by server to client.
|
|
||||||
HandshakeResponse = "controlOk"
|
|
||||||
)
|
|
|
@ -1,101 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/koding/logging"
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
|
|
||||||
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// DefaultProxyFuncs holds global default proxy functions for all transport protocols.
|
|
||||||
DefaultProxyFuncs = ProxyFuncs{
|
|
||||||
HTTP: new(HTTPProxy).Proxy,
|
|
||||||
TCP: new(TCPProxy).Proxy,
|
|
||||||
WS: new(HTTPProxy).Proxy,
|
|
||||||
}
|
|
||||||
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs.
|
|
||||||
DefaultProxy = Proxy(ProxyFuncs{})
|
|
||||||
)
|
|
||||||
|
|
||||||
// ProxyFuncs is a collection of ProxyFunc.
|
|
||||||
type ProxyFuncs struct {
|
|
||||||
// HTTP is custom implementation of HTTP proxing.
|
|
||||||
HTTP ProxyFunc
|
|
||||||
// TCP is custom implementation of TCP proxing.
|
|
||||||
TCP ProxyFunc
|
|
||||||
// WS is custom implementation of web socket proxing.
|
|
||||||
WS ProxyFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs.
|
|
||||||
func Proxy(p ProxyFuncs) ProxyFunc {
|
|
||||||
return func(remote net.Conn, msg *proto.ControlMessage) {
|
|
||||||
var f ProxyFunc
|
|
||||||
switch msg.Protocol {
|
|
||||||
case proto.HTTP:
|
|
||||||
f = DefaultProxyFuncs.HTTP
|
|
||||||
if p.HTTP != nil {
|
|
||||||
f = p.HTTP
|
|
||||||
}
|
|
||||||
case proto.TCP:
|
|
||||||
f = DefaultProxyFuncs.TCP
|
|
||||||
if p.TCP != nil {
|
|
||||||
f = p.TCP
|
|
||||||
}
|
|
||||||
case proto.WS:
|
|
||||||
f = DefaultProxyFuncs.WS
|
|
||||||
if p.WS != nil {
|
|
||||||
f = p.WS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if f == nil {
|
|
||||||
logging.Error("Could not determine proxy function for %v", msg)
|
|
||||||
remote.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
f(remote, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join copies data between local and remote connections.
|
|
||||||
// It reads from one connection and writes to the other.
|
|
||||||
// It's a building block for ProxyFunc implementations.
|
|
||||||
func Join(local, remote net.Conn, log logging.Logger) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
transfer := func(side string, dst, src net.Conn) {
|
|
||||||
log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
|
||||||
|
|
||||||
n, err := io.Copy(dst, src)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("%s: copy error: %s", side, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := src.Close(); err != nil {
|
|
||||||
log.Debug("%s: close error: %s", side, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// not for yamux streams, but for client to local server connections
|
|
||||||
if d, ok := dst.(*net.TCPConn); ok {
|
|
||||||
if err := d.CloseWrite(); err != nil {
|
|
||||||
log.Debug("%s: closeWrite error: %s", side, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n)
|
|
||||||
}
|
|
||||||
|
|
||||||
go transfer("remote to local", local, remote)
|
|
||||||
go transfer("local to remote", remote, local)
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
|
@ -1,754 +0,0 @@
|
||||||
// Package tunnel is a server/client package that enables to proxy public
|
|
||||||
// connections to your local machine over a tunnel connection from the local
|
|
||||||
// machine to the public server.
|
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
"github.com/hashicorp/yamux"
|
|
||||||
"github.com/koding/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errNoClientSession = errors.New("no client session established")
|
|
||||||
defaultTimeout = 10 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is responsible for proxying public connections to the client over a
|
|
||||||
// tunnel connection. It also listens to control messages from the client.
|
|
||||||
type Server struct {
|
|
||||||
// pending contains the channel that is associated with each new tunnel request.
|
|
||||||
pending map[string]chan net.Conn
|
|
||||||
// pendingMu protects the pending map.
|
|
||||||
pendingMu sync.Mutex
|
|
||||||
|
|
||||||
// sessions contains a session per virtual host.
|
|
||||||
// Sessions provides multiplexing over one connection.
|
|
||||||
sessions map[string]*yamux.Session
|
|
||||||
// sessionsMu protects sessions.
|
|
||||||
sessionsMu sync.Mutex
|
|
||||||
|
|
||||||
// controls contains the control connection from the client to the server.
|
|
||||||
controls *controls
|
|
||||||
|
|
||||||
// virtualHosts is used to map public hosts to remote clients.
|
|
||||||
virtualHosts vhostStorage
|
|
||||||
|
|
||||||
// virtualAddrs.
|
|
||||||
virtualAddrs *vaddrStorage
|
|
||||||
|
|
||||||
// connCh is used to publish accepted connections for tcp tunnels.
|
|
||||||
connCh chan net.Conn
|
|
||||||
|
|
||||||
// onConnectCallbacks contains client callbacks called when control
|
|
||||||
// session is established for a client with given identifier.
|
|
||||||
onConnectCallbacks *callbacks
|
|
||||||
|
|
||||||
// onDisconnectCallbacks contains client callbacks called when control
|
|
||||||
// session is closed for a client with given identifier.
|
|
||||||
onDisconnectCallbacks *callbacks
|
|
||||||
|
|
||||||
// states represents current clients' connections state.
|
|
||||||
states map[string]ClientState
|
|
||||||
// statesMu protects states.
|
|
||||||
statesMu sync.RWMutex
|
|
||||||
// stateCh notifies receiver about client state changes.
|
|
||||||
stateCh chan<- *ClientStateChange
|
|
||||||
|
|
||||||
// httpDirector is provided by ServerConfig, if not nil decorates http requests
|
|
||||||
// before forwarding them to client.
|
|
||||||
httpDirector func(*http.Request)
|
|
||||||
|
|
||||||
// yamuxConfig is passed to new yamux.Session's
|
|
||||||
yamuxConfig *yamux.Config
|
|
||||||
|
|
||||||
log logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConfig defines the configuration for the Server
|
|
||||||
type ServerConfig struct {
|
|
||||||
// StateChanges receives state transition details each time client
|
|
||||||
// connection state changes. The channel is expected to be sufficiently
|
|
||||||
// buffered to keep up with event pace.
|
|
||||||
//
|
|
||||||
// If nil, no information about state transitions are dispatched
|
|
||||||
// by the library.
|
|
||||||
StateChanges chan<- *ClientStateChange
|
|
||||||
|
|
||||||
// Director is a function that modifies HTTP request into a new HTTP request
|
|
||||||
// before sending to client. If nil no modifications are done.
|
|
||||||
Director func(*http.Request)
|
|
||||||
|
|
||||||
// Debug enables debug mode, enable only if you want to debug the server
|
|
||||||
Debug bool
|
|
||||||
|
|
||||||
// Log defines the logger. If nil a default logging.Logger is used.
|
|
||||||
Log logging.Logger
|
|
||||||
|
|
||||||
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
|
|
||||||
// yamux.DefaultConfig() is used.
|
|
||||||
YamuxConfig *yamux.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServer creates a new Server. The defaults are used if config is nil.
|
|
||||||
func NewServer(cfg *ServerConfig) (*Server, error) {
|
|
||||||
yamuxConfig := yamux.DefaultConfig()
|
|
||||||
if cfg.YamuxConfig != nil {
|
|
||||||
if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
yamuxConfig = cfg.YamuxConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
log := newLogger("tunnel-server", cfg.Debug)
|
|
||||||
if cfg.Log != nil {
|
|
||||||
log = cfg.Log
|
|
||||||
}
|
|
||||||
|
|
||||||
connCh := make(chan net.Conn)
|
|
||||||
|
|
||||||
opts := &vaddrOptions{
|
|
||||||
connCh: connCh,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Server{
|
|
||||||
pending: make(map[string]chan net.Conn),
|
|
||||||
sessions: make(map[string]*yamux.Session),
|
|
||||||
onConnectCallbacks: newCallbacks("OnConnect"),
|
|
||||||
onDisconnectCallbacks: newCallbacks("OnDisconnect"),
|
|
||||||
virtualHosts: newVirtualHosts(),
|
|
||||||
virtualAddrs: newVirtualAddrs(opts),
|
|
||||||
controls: newControls(),
|
|
||||||
states: make(map[string]ClientState),
|
|
||||||
stateCh: cfg.StateChanges,
|
|
||||||
httpDirector: cfg.Director,
|
|
||||||
yamuxConfig: yamuxConfig,
|
|
||||||
connCh: connCh,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.serveTCP()
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP is a tunnel that creates an http/websocket tunnel between a
|
|
||||||
// public connection and the client connection.
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// if the user didn't add the control and tunnel handler manually, we'll
|
|
||||||
// going to infer and call the respective path handlers.
|
|
||||||
switch path.Clean(r.URL.Path) + "/" {
|
|
||||||
case proto.ControlPath:
|
|
||||||
s.checkConnect(s.controlHandler).ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.handleHTTP(w, r); err != nil {
|
|
||||||
if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily
|
|
||||||
s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err)
|
|
||||||
}
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleHTTP handles a single HTTP request
|
|
||||||
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
s.log.Debug("HandleHTTP request:")
|
|
||||||
s.log.Debug("%v", r)
|
|
||||||
|
|
||||||
if s.httpDirector != nil {
|
|
||||||
s.httpDirector(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostPort := strings.ToLower(r.Host)
|
|
||||||
if hostPort == "" {
|
|
||||||
return errors.New("request host is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// if someone hits foo.example.com:8080, this should be proxied to
|
|
||||||
// localhost:8080, so send the port to the client so it knows how to proxy
|
|
||||||
// correctly. If no port is available, it's up to client how to interpret it
|
|
||||||
host, port, err := parseHostPort(hostPort)
|
|
||||||
if err != nil {
|
|
||||||
// no need to return, just continue lazily, port will be 0, which in
|
|
||||||
// our case will be proxied to client's local servers port 80
|
|
||||||
s.log.Debug("No port available for %q, sending port 80 to client", hostPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the identifier associated with this host
|
|
||||||
identifier, ok := s.getIdentifier(hostPort)
|
|
||||||
if !ok {
|
|
||||||
// fallback to host
|
|
||||||
identifier, ok = s.getIdentifier(host)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no virtual host available for %q", hostPort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isWebsocketConn(r) {
|
|
||||||
s.log.Debug("handling websocket connection")
|
|
||||||
|
|
||||||
return s.handleWSConn(w, r, identifier, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := s.dial(identifier, proto.HTTP, port)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
s.log.Debug("Closing stream")
|
|
||||||
stream.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := r.Write(stream); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log.Debug("Session opened to client, writing request to client")
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read from tunnel: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if resp.Body != nil {
|
|
||||||
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
|
||||||
s.log.Error("resp.Body Close error: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.log.Debug("Response received, writing back to public connection: %+v", resp)
|
|
||||||
|
|
||||||
copyHeader(w.Header(), resp.Header)
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
||||||
if err == io.ErrUnexpectedEOF {
|
|
||||||
s.log.Debug("Client closed the connection, couldn't copy response")
|
|
||||||
} else {
|
|
||||||
s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) serveTCP() {
|
|
||||||
for conn := range s.connCh {
|
|
||||||
go s.serveTCPConn(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) serveTCPConn(conn net.Conn) {
|
|
||||||
err := s.handleTCPConn(conn)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error {
|
|
||||||
hj, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("webserver doesn't support hijacking")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, _, err := hj.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hijack not possible: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := s.dial(ident, proto.WS, port)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.Write(stream); err != nil {
|
|
||||||
err = errors.New("unable to write upgrade request: " + err.Error())
|
|
||||||
return nonil(err, stream.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.New("unable to read upgrade response: " + err.Error())
|
|
||||||
return nonil(err, stream.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := resp.Write(conn); err != nil {
|
|
||||||
err = errors.New("unable to write upgrade response: " + err.Error())
|
|
||||||
return nonil(err, stream.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go s.proxy(&wg, conn, stream)
|
|
||||||
go s.proxy(&wg, stream, conn)
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
return nonil(stream.Close(), conn.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleTCPConn(conn net.Conn) error {
|
|
||||||
ident, ok := s.virtualAddrs.getIdent(conn)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no virtual address available for %s", conn.LocalAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
_, port, err := parseHostPort(conn.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := s.dial(ident, proto.TCP, port)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go s.proxy(&wg, conn, stream)
|
|
||||||
go s.proxy(&wg, stream, conn)
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
return nonil(stream.Close(), conn.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
|
||||||
n, err := io.Copy(dst, src)
|
|
||||||
s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) {
|
|
||||||
control, ok := s.getControl(identifier)
|
|
||||||
if !ok {
|
|
||||||
return nil, errNoClientSession
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := s.getSession(identifier)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := proto.ControlMessage{
|
|
||||||
Action: proto.RequestClientSession,
|
|
||||||
Protocol: p,
|
|
||||||
LocalPort: port,
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log.Debug("Sending control msg %+v", msg)
|
|
||||||
|
|
||||||
// ask client to open a session to us, so we can accept it
|
|
||||||
if err := control.send(msg); err != nil {
|
|
||||||
// we might have several issues here, either the stream is closed, or
|
|
||||||
// the session is going be shut down, the underlying connection might
|
|
||||||
// be broken. In all cases, it's not reliable anymore having a client
|
|
||||||
// session.
|
|
||||||
control.Close()
|
|
||||||
s.deleteControl(identifier)
|
|
||||||
return nil, errNoClientSession
|
|
||||||
}
|
|
||||||
|
|
||||||
var stream net.Conn
|
|
||||||
acceptStream := func() error {
|
|
||||||
stream, err = session.Accept()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we don't receive anything from the client, we'll timeout
|
|
||||||
s.log.Debug("Waiting for session accept")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-async(acceptStream):
|
|
||||||
return stream, err
|
|
||||||
case <-time.After(defaultTimeout):
|
|
||||||
return nil, errors.New("timeout getting session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// controlHandler is used to capture incoming tunnel connect requests into raw
|
|
||||||
// tunnel TCP connections.
|
|
||||||
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
|
|
||||||
identifier := r.Header.Get(proto.ClientIdentifierHeader)
|
|
||||||
_, ok := s.getHost(identifier)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
ct, ok := s.getControl(identifier)
|
|
||||||
if ok {
|
|
||||||
ct.Close()
|
|
||||||
s.deleteControl(identifier)
|
|
||||||
s.deleteSession(identifier)
|
|
||||||
s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier)
|
|
||||||
return fmt.Errorf("control conn for %s already exist. \n", identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log.Debug("Tunnel with identifier %s", identifier)
|
|
||||||
|
|
||||||
hj, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("webserver doesn't support hijacking")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, _, err := hj.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hijack not possible: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil {
|
|
||||||
return fmt.Errorf("error writing response: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
|
||||||
return fmt.Errorf("error setting connection deadline: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log.Debug("Creating control session")
|
|
||||||
session, err := yamux.Server(conn, s.yamuxConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.addSession(identifier, session)
|
|
||||||
|
|
||||||
var stream net.Conn
|
|
||||||
|
|
||||||
// close and delete the session/stream if something goes wrong
|
|
||||||
defer func() {
|
|
||||||
if ctErr != nil {
|
|
||||||
if stream != nil {
|
|
||||||
stream.Close()
|
|
||||||
}
|
|
||||||
s.deleteSession(identifier)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
acceptStream := func() error {
|
|
||||||
stream, err = session.Accept()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we don't receive anything from the client, we'll timeout
|
|
||||||
select {
|
|
||||||
case err := <-async(acceptStream):
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case <-time.After(time.Second * 10):
|
|
||||||
return errors.New("timeout getting session")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log.Debug("Initiating handshake protocol")
|
|
||||||
buf := make([]byte, len(proto.HandshakeRequest))
|
|
||||||
if _, err := stream.Read(buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(buf) != proto.HandshakeRequest {
|
|
||||||
return fmt.Errorf("handshake aborted. got: %s", string(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup control stream and start to listen to messages
|
|
||||||
ct = newControl(stream)
|
|
||||||
s.addControl(identifier, ct)
|
|
||||||
go s.listenControl(ct)
|
|
||||||
|
|
||||||
s.log.Debug("Control connection is setup")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// listenControl listens to messages coming from the client.
|
|
||||||
func (s *Server) listenControl(ct *control) {
|
|
||||||
s.onConnect(ct.identifier)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var msg map[string]interface{}
|
|
||||||
err := ct.dec.Decode(&msg)
|
|
||||||
if err != nil {
|
|
||||||
host, _ := s.getHost(ct.identifier)
|
|
||||||
s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier)
|
|
||||||
|
|
||||||
// close client connection so it reconnects again
|
|
||||||
ct.Close()
|
|
||||||
|
|
||||||
// don't forget to cleanup anything
|
|
||||||
s.deleteControl(ct.identifier)
|
|
||||||
s.deleteSession(ct.identifier)
|
|
||||||
|
|
||||||
s.onDisconnect(ct.identifier, err)
|
|
||||||
|
|
||||||
if err != io.EOF {
|
|
||||||
s.log.Error("decode err: %s", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// right now we don't do anything with the messages, but because the
|
|
||||||
// underlying connection needs to establihsed, we know when we have
|
|
||||||
// disconnection(above), so we can cleanup the connection.
|
|
||||||
s.log.Debug("msg: %s", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnConnect invokes a callback for client with given identifier,
|
|
||||||
// when it establishes a control session.
|
|
||||||
// After a client is connected, the associated function
|
|
||||||
// is also removed and needs to be added again.
|
|
||||||
func (s *Server) OnConnect(identifier string, fn func() error) {
|
|
||||||
s.onConnectCallbacks.add(identifier, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// onConnect sends notifications to listeners (registered in onConnectCallbacks
|
|
||||||
// or stateChanges chanel readers) when client connects.
|
|
||||||
func (s *Server) onConnect(identifier string) {
|
|
||||||
if err := s.onConnectCallbacks.call(identifier); err != nil {
|
|
||||||
s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.changeState(identifier, ClientConnected, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDisconnect calls the function when the client connected with the
|
|
||||||
// associated identifier disconnects from the server.
|
|
||||||
// After a client is disconnected, the associated function
|
|
||||||
// is also removed and needs to be added again.
|
|
||||||
func (s *Server) OnDisconnect(identifier string, fn func() error) {
|
|
||||||
s.onDisconnectCallbacks.add(identifier, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks
|
|
||||||
// or stateChanges chanel readers) when client disconnects.
|
|
||||||
func (s *Server) onDisconnect(identifier string, err error) {
|
|
||||||
if err := s.onDisconnectCallbacks.call(identifier); err != nil {
|
|
||||||
s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.changeState(identifier, ClientClosed, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) {
|
|
||||||
s.statesMu.Lock()
|
|
||||||
defer s.statesMu.Unlock()
|
|
||||||
|
|
||||||
prev = s.states[identifier]
|
|
||||||
s.states[identifier] = state
|
|
||||||
|
|
||||||
if s.stateCh != nil {
|
|
||||||
change := &ClientStateChange{
|
|
||||||
Identifier: identifier,
|
|
||||||
Previous: prev,
|
|
||||||
Current: state,
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case s.stateCh <- change:
|
|
||||||
default:
|
|
||||||
s.log.Warning("Dropping state change due to slow reader: %s", change)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return prev
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHost adds the given virtual host and maps it to the identifier.
|
|
||||||
func (s *Server) AddHost(host, identifier string) {
|
|
||||||
s.virtualHosts.AddHost(host, identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteHost deletes the given virtual host. Once removed any request to this
|
|
||||||
// host is denied.
|
|
||||||
func (s *Server) DeleteHost(host string) {
|
|
||||||
s.virtualHosts.DeleteHost(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAddr starts accepting connections on listener l, routing every connection
|
|
||||||
// to a tunnel client given by the identifier.
|
|
||||||
//
|
|
||||||
// When ip parameter is nil, all connections accepted from the listener are
|
|
||||||
// routed to the tunnel client specified by the identifier (port-based routing).
|
|
||||||
//
|
|
||||||
// When ip parameter is non-nil, only those connections are routed whose local
|
|
||||||
// address matches the specified ip (ip-based routing).
|
|
||||||
//
|
|
||||||
// If l listens on multiple interfaces it's desirable to call AddAddr multiple
|
|
||||||
// times with the same l value but different ip one.
|
|
||||||
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
|
|
||||||
s.virtualAddrs.Add(l, ip, identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAddr stops listening for connections on the given listener.
|
|
||||||
//
|
|
||||||
// Upon return no more connections will be tunneled, but as the method does not
|
|
||||||
// close the listener, so any ongoing connection won't get interrupted.
|
|
||||||
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
|
|
||||||
s.virtualAddrs.Delete(l, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getIdentifier(host string) (string, bool) {
|
|
||||||
identifier, ok := s.virtualHosts.GetIdentifier(host)
|
|
||||||
return identifier, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getHost(identifier string) (string, bool) {
|
|
||||||
host, ok := s.virtualHosts.GetHost(identifier)
|
|
||||||
return host, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) addControl(identifier string, conn *control) {
|
|
||||||
s.controls.addControl(identifier, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getControl(identifier string) (*control, bool) {
|
|
||||||
return s.controls.getControl(identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) deleteControl(identifier string) {
|
|
||||||
s.controls.deleteControl(identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
|
|
||||||
s.sessionsMu.Lock()
|
|
||||||
session, ok := s.sessions[identifier]
|
|
||||||
s.sessionsMu.Unlock()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) addSession(identifier string, session *yamux.Session) {
|
|
||||||
s.sessionsMu.Lock()
|
|
||||||
s.sessions[identifier] = session
|
|
||||||
s.sessionsMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) deleteSession(identifier string) {
|
|
||||||
s.sessionsMu.Lock()
|
|
||||||
defer s.sessionsMu.Unlock()
|
|
||||||
|
|
||||||
session, ok := s.sessions[identifier]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return // nothing to delete
|
|
||||||
}
|
|
||||||
|
|
||||||
if session != nil {
|
|
||||||
session.GoAway() // don't accept any new connection
|
|
||||||
session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.sessions, identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
|
||||||
for k, v := range src {
|
|
||||||
vv := make([]string, len(v))
|
|
||||||
copy(vv, v)
|
|
||||||
dst[k] = vv
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkConnect checks whether the incoming request is HTTP CONNECT method.
|
|
||||||
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != "CONNECT" {
|
|
||||||
http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := fn(w, r); err != nil {
|
|
||||||
s.log.Error("Handler err: %v", err.Error())
|
|
||||||
|
|
||||||
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
|
|
||||||
s.onDisconnect(identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Error(w, err.Error(), 502)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseHostPort(addr string) (string, int, error) {
|
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := strconv.ParseUint(port, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, int(n), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWebsocketConn(r *http.Request) bool {
|
|
||||||
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
|
|
||||||
headerContains(r.Header["Upgrade"], "websocket")
|
|
||||||
}
|
|
||||||
|
|
||||||
// headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go
|
|
||||||
func headerContains(header []string, value string) bool {
|
|
||||||
for _, h := range header {
|
|
||||||
for _, v := range strings.Split(h, ",") {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(v), value) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func nonil(err ...error) error {
|
|
||||||
for _, e := range err {
|
|
||||||
if e != nil {
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLogger(name string, debug bool) logging.Logger {
|
|
||||||
log := logging.NewLogger(name)
|
|
||||||
logHandler := logging.NewWriterHandler(os.Stderr)
|
|
||||||
logHandler.Colorize = true
|
|
||||||
log.SetHandler(logHandler)
|
|
||||||
|
|
||||||
if debug {
|
|
||||||
log.SetLevel(logging.DEBUG)
|
|
||||||
logHandler.SetLevel(logging.DEBUG)
|
|
||||||
}
|
|
||||||
|
|
||||||
return log
|
|
||||||
}
|
|
|
@ -1,100 +0,0 @@
|
||||||
# Specification
|
|
||||||
|
|
||||||
# Naming conventions
|
|
||||||
|
|
||||||
* `server` is listening to public connection and is responsible of routing
|
|
||||||
public HTTP requests to clients.
|
|
||||||
* `client` is a long running process, connected to a server and running on a local machine.
|
|
||||||
* `virtualHost` is a virtual domain that maps a domain to a single client. i.e:
|
|
||||||
`arslan.koding.io` is a virtualhost which is mapped to my `client` running on
|
|
||||||
my local machine.
|
|
||||||
* `identifier` is a secret token, which is not meant to be shared with others.
|
|
||||||
An identifier is responsible of mapping a virtualhost to a client.
|
|
||||||
* `session` is a single TCP connection which uses the library `yamux`. A
|
|
||||||
session can be created either via `yamux.Server()` or `yamux.Client`
|
|
||||||
* `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed
|
|
||||||
over the `session`. A session can have hundreds of thousands streams
|
|
||||||
* `control connection` is a single `stream` which is used to communicate and
|
|
||||||
handle messaging between server and client. It uses a custom protocol which
|
|
||||||
is JSON encoded.
|
|
||||||
* `tunnel connection` is a single `stream` which is used to proxy public HTTP
|
|
||||||
requests from the `server` to the `client` and vice versa. A single `tunnel`
|
|
||||||
connection is created for every single HTTP requests.
|
|
||||||
* `public connection` is a connection from a remote machine to the `server`
|
|
||||||
* `ControlHandler` is a http.Handler which listens to requests coming to
|
|
||||||
`/_controlPath_/`. It's used to setup the initial `session` connection from
|
|
||||||
`client` to `server`. And creates the `control connection` from this session.
|
|
||||||
server and client, and also for all additional new tunnel. It literally
|
|
||||||
captures the incoming HTTP request and hijacks it and converts it into RAW TCP,
|
|
||||||
which then is used as the foundation for all yamux `sessions.`
|
|
||||||
|
|
||||||
|
|
||||||
# Server
|
|
||||||
1. Server is created with `NewServer()` which returns `*Server`, a `http.Handler`
|
|
||||||
compatible type. Plug into any HTTP server you want. The root path `"/"` is
|
|
||||||
recommended to listen and proxy any tunnels. It also listens to any request
|
|
||||||
coming to `ControlHandler`
|
|
||||||
2. Tunneling is based on virtual hosts. A virtual hosts is identified with an
|
|
||||||
unique identifier. This identifier is the only piece that both client and
|
|
||||||
server needs to known ahead. Think of it as a secret token.
|
|
||||||
3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This
|
|
||||||
step needs to be done from the server itself. This can be could manually or
|
|
||||||
via custom auth based HTTP handlers, such as "/addhost", which adds
|
|
||||||
virtualhosts and returns the `identifier` to the requester (in our case `client`)
|
|
||||||
4. A DNS record and it's subdomains needs to point to a `server`, so it can
|
|
||||||
handle virtual hosts, i.e: `*.example.com` is routed to a server, which can
|
|
||||||
handle `foo.example.com`, `bar.example.com`, etc..
|
|
||||||
|
|
||||||
|
|
||||||
# Client
|
|
||||||
|
|
||||||
1. Client is created with `NewClient(serverAddr, localAddr)` which returns a
|
|
||||||
`*Client`. Here `serverAddr` is the TCP address to the server. `localAddr`
|
|
||||||
is the server in which all public requests are forwarded to. It's optional if
|
|
||||||
you want it to be done dynamically
|
|
||||||
2. Once a client is created, it starts with `client.Start(identifier)`. Here
|
|
||||||
`identifier` is needed upfront. This method creates the initial TCP
|
|
||||||
connection to the server. It sends the identifier back to the server. This
|
|
||||||
TCP connection is used as the foundation for `yamux.Client()`. Once a yamux
|
|
||||||
session is established, we are able to use this single connection to have
|
|
||||||
multiple streams, which are multiplexed over this one connection. A `control
|
|
||||||
connection` is created and client starts to listen it. `client.Start` is
|
|
||||||
blocking.
|
|
||||||
|
|
||||||
# Control Handshake
|
|
||||||
|
|
||||||
1. Client sends a `handshakeRequest` over the `control connection` stream
|
|
||||||
2. The server sends back a `handshakeResponse` to the client over the `control connection` stream
|
|
||||||
3. Once the client receives the `handshakeResponse` from the server, it starts
|
|
||||||
to listen from the `control connection` stream.
|
|
||||||
4. A `control connection` is json.Encoder/Decoder both for server and client
|
|
||||||
|
|
||||||
|
|
||||||
# Tunnel creation
|
|
||||||
1. When the server receives a public connection, it checks the HTTP host
|
|
||||||
headers and retrieves the corresponding identifier from the given host.
|
|
||||||
2. The server retrieves the `control connection` which was associated with this
|
|
||||||
`identifier` and sends a `ControlMsg` message with the action
|
|
||||||
`RequestClientSession`. This message is in the form of:
|
|
||||||
|
|
||||||
type ControlMsg struct {
|
|
||||||
Action Action `json:"action"`
|
|
||||||
Protocol TransportProtocol `json:"transportProtocol"`
|
|
||||||
LocalPort string `json:"localPort"`
|
|
||||||
}
|
|
||||||
|
|
||||||
Here the `LocalPort` is read from the HTTP Host header. If absent a zero
|
|
||||||
port is sent and client maps it to the local server running at port 80, unless
|
|
||||||
the `localAddr` is specified in `client.Start()` method. `Protocol` is
|
|
||||||
reserved for future features.
|
|
||||||
3. The server immediately starts to listen(accept) to a new `stream`. This is
|
|
||||||
blocking and it waits there.
|
|
||||||
4. When the client receives the `RequestClientSession` message, it opens a new
|
|
||||||
`virtual` TCP connection, a `stream` to the server.
|
|
||||||
5. The server which was waiting for a new stream in step 3, establish the stream.
|
|
||||||
6. The server copies the request over the stream to the client.
|
|
||||||
7. The client copies the request coming from the server to the local server and
|
|
||||||
copies back the result to the server
|
|
||||||
8. The server reads the response coming from the client and returns back it to
|
|
||||||
the public connection requester
|
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
"github.com/koding/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
tpcLog = logging.NewLogger("tcp")
|
|
||||||
)
|
|
||||||
|
|
||||||
// TCPProxy forwards TCP streams.
|
|
||||||
//
|
|
||||||
// If port-based routing is used, LocalAddr or FetchLocalAddr field is required
|
|
||||||
// for tunneling to function properly.
|
|
||||||
// Otherwise you'll be forwarding traffic to random ports and this is usually not desired.
|
|
||||||
//
|
|
||||||
// If IP-based routing is used then tunnel server connection request is
|
|
||||||
// proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort.
|
|
||||||
// Usually this is tunnel server's public exposed Port.
|
|
||||||
// This behaviour can be changed by setting LocalAddr or FetchLocalAddr.
|
|
||||||
// FetchLocalAddr takes precedence over LocalAddr.
|
|
||||||
type TCPProxy struct {
|
|
||||||
// LocalAddr defines the TCP address of the local server.
|
|
||||||
// This is optional if you want to specify a single TCP address.
|
|
||||||
LocalAddr string
|
|
||||||
// FetchLocalAddr is used for looking up TCP address of the server.
|
|
||||||
// This is optional if you want to specify a dynamic TCP address based on incommig port.
|
|
||||||
FetchLocalAddr func(port int) (string, error)
|
|
||||||
// Log is a custom logger that can be used for the proxy.
|
|
||||||
// If not set a "tcp" logger is used.
|
|
||||||
Log logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy is a ProxyFunc.
|
|
||||||
func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) {
|
|
||||||
if msg.Protocol != proto.TCP {
|
|
||||||
panic("Proxy mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
var log = p.log()
|
|
||||||
|
|
||||||
var port = msg.LocalPort
|
|
||||||
if port == 0 {
|
|
||||||
log.Warning("TCP proxy to port 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
var localAddr = fmt.Sprintf("127.0.0.1:%d", port)
|
|
||||||
if p.LocalAddr != "" {
|
|
||||||
localAddr = p.LocalAddr
|
|
||||||
} else if p.FetchLocalAddr != nil {
|
|
||||||
l, err := p.FetchLocalAddr(msg.LocalPort)
|
|
||||||
if err != nil {
|
|
||||||
log.Warning("Failed to get custom local address: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
localAddr = l
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Dialing local server: %q", localAddr)
|
|
||||||
local, err := net.DialTimeout("tcp", localAddr, defaultTimeout)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Dialing local server %q failed: %s", localAddr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Join(local, remote, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TCPProxy) log() logging.Logger {
|
|
||||||
if p.Log != nil {
|
|
||||||
return p.Log
|
|
||||||
}
|
|
||||||
return tpcLog
|
|
||||||
}
|
|
|
@ -1,411 +0,0 @@
|
||||||
package tunnel_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel"
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
|
||||||
"github.com/cenkalti/backoff"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMultipleRequest(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
// make a request to tunnelserver, this should be tunneled to local server
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
msg := "hello" + strconv.Itoa(i)
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleLatencyRequest(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
// make a request to tunnelserver, this should be tunneled to local server
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
msg := "hello" + strconv.Itoa(i)
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReconnectClient(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := tt.Clients["http"]
|
|
||||||
|
|
||||||
// close client, and start it again
|
|
||||||
client.Close()
|
|
||||||
|
|
||||||
go client.Start()
|
|
||||||
<-client.StartNotify()
|
|
||||||
|
|
||||||
msg = "helloagain"
|
|
||||||
res, err = echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoClient(t *testing.T) {
|
|
||||||
const expectedErr = "no client session established"
|
|
||||||
|
|
||||||
rec := tunneltest.NewStateRecorder()
|
|
||||||
|
|
||||||
tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C()))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
if err := rec.WaitTransitions(
|
|
||||||
tunnel.ClientStarted,
|
|
||||||
tunnel.ClientConnecting,
|
|
||||||
tunnel.ClientConnected,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tt.ServerStateRecorder.WaitTransition(
|
|
||||||
tunnel.ClientUnknown,
|
|
||||||
tunnel.ClientConnected,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// close client, this is the main point of the test
|
|
||||||
if err := tt.Clients["http"].Close(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rec.WaitTransitions(
|
|
||||||
tunnel.ClientConnected,
|
|
||||||
tunnel.ClientDisconnected,
|
|
||||||
tunnel.ClientClosed,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tt.ServerStateRecorder.WaitTransition(
|
|
||||||
tunnel.ClientConnected,
|
|
||||||
tunnel.ClientClosed,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != expectedErr {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoHost(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
noBackoff := backoff.NewConstantBackOff(time.Duration(-1))
|
|
||||||
|
|
||||||
unknown, err := tunnel.NewClient(&tunnel.ClientConfig{
|
|
||||||
Identifier: "unknown",
|
|
||||||
ServerAddr: tt.ServerAddr().String(),
|
|
||||||
Backoff: noBackoff,
|
|
||||||
Debug: testing.Verbose(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("client error: %s", err)
|
|
||||||
}
|
|
||||||
unknown.Start()
|
|
||||||
defer unknown.Close()
|
|
||||||
|
|
||||||
if err := tt.ServerStateRecorder.WaitTransition(
|
|
||||||
tunnel.ClientUnknown,
|
|
||||||
tunnel.ClientClosed,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
unknown.Start()
|
|
||||||
if err := tt.ServerStateRecorder.WaitTransition(
|
|
||||||
tunnel.ClientClosed,
|
|
||||||
tunnel.ClientClosed,
|
|
||||||
); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoLocalServer(t *testing.T) {
|
|
||||||
const expectedErr = "no local server"
|
|
||||||
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
// close local listener, this is the main point of the test
|
|
||||||
tt.Listeners["http"][0].Close()
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != expectedErr {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSingleRequest(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSingleLatencyRequest(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoHTTP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoHTTP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSingleTCP(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
msg := "hello"
|
|
||||||
res, err := echoTCP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("echoTCP error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg != res {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleTCP(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
msg := "hello" + strconv.Itoa(i)
|
|
||||||
res, err := echoTCP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("echoTCP: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleLatencyTCP(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
msg := "hello" + strconv.Itoa(i)
|
|
||||||
res, err := echoTCP(tt, msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("echoTCP: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleStreamTCP(t *testing.T) {
|
|
||||||
tunnels := map[string]*tunneltest.Tunnel{
|
|
||||||
"http": {
|
|
||||||
Type: tunneltest.TypeHTTP,
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
Handler: handlerEchoHTTP,
|
|
||||||
},
|
|
||||||
"tcp": {
|
|
||||||
Type: tunneltest.TypeTCP,
|
|
||||||
ClientIdent: "http",
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
RemoteAddr: "127.0.0.1:0",
|
|
||||||
Handler: handlerEchoTCP,
|
|
||||||
},
|
|
||||||
"tcp_all": {
|
|
||||||
Type: tunneltest.TypeTCP,
|
|
||||||
ClientIdent: "http",
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
RemoteAddr: "0.0.0.0:0",
|
|
||||||
Handler: handlerEchoTCP,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs, err := tunneltest.UsableAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clients := []string{"tcp"}
|
|
||||||
for i, addr := range addrs {
|
|
||||||
if addr.IP.IsLoopback() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
client := fmt.Sprintf("tcp_%d", i)
|
|
||||||
|
|
||||||
tunnels[client] = &tunneltest.Tunnel{
|
|
||||||
Type: tunneltest.TypeTCP,
|
|
||||||
ClientIdent: "http",
|
|
||||||
LocalAddr: "127.0.0.1:0",
|
|
||||||
RemoteAddrIdent: "tcp_all",
|
|
||||||
IP: addr.IP,
|
|
||||||
Handler: handlerEchoTCP,
|
|
||||||
}
|
|
||||||
|
|
||||||
clients = append(clients, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
tt, err := tunneltest.Serve(tunnels)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tt.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100/len(clients); i++ {
|
|
||||||
wg.Add(len(clients))
|
|
||||||
|
|
||||||
for j, ident := range clients {
|
|
||||||
go func(ident string, i, j int) {
|
|
||||||
defer wg.Done()
|
|
||||||
msg := fmt.Sprintf("hello_%d_client_%d", j, i)
|
|
||||||
res, err := echoTCPIdent(tt, msg, ident)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("echoTCP: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != msg {
|
|
||||||
t.Errorf("got %q, want %q", res, msg)
|
|
||||||
}
|
|
||||||
}(ident, i, j)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
|
@ -1,118 +0,0 @@
|
||||||
package tunneltest
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
recWaitTimeout = 5 * time.Second
|
|
||||||
recBuffer = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// States is a sequence of client state changes.
|
|
||||||
type States []*tunnel.ClientStateChange
|
|
||||||
|
|
||||||
func (s States) String() string {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
fmt.Fprintf(&buf, "[%s", s[0].String())
|
|
||||||
|
|
||||||
for _, s := range s[1:] {
|
|
||||||
fmt.Fprintf(&buf, ",%s", s.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.WriteRune(']')
|
|
||||||
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StateRecorder saves state changes pushed to StateRecorder.C().
|
|
||||||
type StateRecorder struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
ch chan *tunnel.ClientStateChange
|
|
||||||
recorded []*tunnel.ClientStateChange
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStateRecorder() *StateRecorder {
|
|
||||||
rec := &StateRecorder{
|
|
||||||
ch: make(chan *tunnel.ClientStateChange, recBuffer),
|
|
||||||
}
|
|
||||||
|
|
||||||
go rec.record()
|
|
||||||
|
|
||||||
return rec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *StateRecorder) record() {
|
|
||||||
for state := range rec.ch {
|
|
||||||
rec.mu.Lock()
|
|
||||||
rec.recorded = append(rec.recorded, state)
|
|
||||||
rec.mu.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange {
|
|
||||||
return rec.ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error {
|
|
||||||
from := states[0]
|
|
||||||
for _, to := range states[1:] {
|
|
||||||
if err := rec.WaitTransition(from, to); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
from = to
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error {
|
|
||||||
timeout := time.After(recWaitTimeout)
|
|
||||||
|
|
||||||
var lastStates []*tunnel.ClientStateChange
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates))
|
|
||||||
default:
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
lastStates = rec.States()[rec.offset:]
|
|
||||||
|
|
||||||
for i, state := range lastStates {
|
|
||||||
if from != 0 && state.Previous != from {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if to != 0 && state.Current != to {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rec.offset += i
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *StateRecorder) States() []*tunnel.ClientStateChange {
|
|
||||||
rec.mu.Lock()
|
|
||||||
defer rec.mu.Unlock()
|
|
||||||
|
|
||||||
states := make([]*tunnel.ClientStateChange, len(rec.recorded))
|
|
||||||
copy(states, rec.recorded)
|
|
||||||
return states
|
|
||||||
}
|
|
|
@ -1,561 +0,0 @@
|
||||||
package tunneltest
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel"
|
|
||||||
)
|
|
||||||
|
|
||||||
var debugNet = os.Getenv("DEBUGNET") == "1"
|
|
||||||
|
|
||||||
type dbgListener struct {
|
|
||||||
net.Listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l dbgListener) Accept() (net.Conn, error) {
|
|
||||||
conn, err := l.Listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return dbgConn{conn}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type dbgConn struct {
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c dbgConn) Read(p []byte) (int, error) {
|
|
||||||
n, err := c.Conn.Read(p)
|
|
||||||
os.Stderr.Write(p)
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c dbgConn) Write(p []byte) (int, error) {
|
|
||||||
n, err := c.Conn.Write(p)
|
|
||||||
os.Stderr.Write(p)
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func logf(format string, args ...interface{}) {
|
|
||||||
if testing.Verbose() {
|
|
||||||
log.Printf("[tunneltest] "+format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func nonil(err ...error) error {
|
|
||||||
for _, e := range err {
|
|
||||||
if e != nil {
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseHostPort(addr string) (string, int, error) {
|
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := strconv.ParseUint(port, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, int(n), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsableAddrs returns all tcp addresses that we can bind a listener to.
|
|
||||||
func UsableAddrs() ([]*net.TCPAddr, error) {
|
|
||||||
addrs, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var usable []*net.TCPAddr
|
|
||||||
for _, addr := range addrs {
|
|
||||||
if ipNet, ok := addr.(*net.IPNet); ok {
|
|
||||||
if !ipNet.IP.IsLinkLocalUnicast() {
|
|
||||||
usable = append(usable, &net.TCPAddr{IP: ipNet.IP})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(usable) == 0 {
|
|
||||||
return nil, errors.New("no usable addresses found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return usable, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeHTTP = iota
|
|
||||||
TypeTCP
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tunnel represents a single HTTP or TCP tunnel that can be served
|
|
||||||
// by TunnelTest.
|
|
||||||
type Tunnel struct {
|
|
||||||
// Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP.
|
|
||||||
Type int
|
|
||||||
|
|
||||||
// Handler is a handler to use for serving tunneled connections on
|
|
||||||
// local server. The value of this field is required to be of type:
|
|
||||||
//
|
|
||||||
// - http.Handler or http.HandlerFunc for HTTP tunnels
|
|
||||||
// - func(net.Conn) for TCP tunnels
|
|
||||||
//
|
|
||||||
// Required field.
|
|
||||||
Handler interface{}
|
|
||||||
|
|
||||||
// LocalAddr is a network address of local server that handles
|
|
||||||
// connections/requests with Handler.
|
|
||||||
//
|
|
||||||
// Optional field, takes value of "127.0.0.1:0" when empty.
|
|
||||||
LocalAddr string
|
|
||||||
|
|
||||||
// ClientIdent is an identifier of a client that have already
|
|
||||||
// registered a HTTP tunnel and have established control connection.
|
|
||||||
//
|
|
||||||
// If the Type is TypeTCP, instead of creating new client
|
|
||||||
// for this TCP tunnel, we add it to an existing client
|
|
||||||
// specified by the field.
|
|
||||||
//
|
|
||||||
// Optional field for TCP tunnels.
|
|
||||||
// Ignored field for HTTP tunnels.
|
|
||||||
ClientIdent string
|
|
||||||
|
|
||||||
// RemoteAddr is a network address of remote server, which accepts
|
|
||||||
// connections on a tunnel server side.
|
|
||||||
//
|
|
||||||
// Required field for TCP tunnels.
|
|
||||||
// Ignored field for HTTP tunnels.
|
|
||||||
RemoteAddr string
|
|
||||||
|
|
||||||
// RemoteAddrIdent an identifier of an already existing listener,
|
|
||||||
// that listens on multiple interfaces; if the RemoteAddrIdent is valid
|
|
||||||
// identifier the IP field is required to be non-nil and RemoteAddr
|
|
||||||
// is ignored.
|
|
||||||
//
|
|
||||||
// Optional field for TCP tunnels.
|
|
||||||
// Ignored field for HTTP tunnels.
|
|
||||||
RemoteAddrIdent string
|
|
||||||
|
|
||||||
// IP specifies an IP address value for IP-based routing for TCP tunnels.
|
|
||||||
// For more details see inline documentation for (*tunnel.Server).AddAddr.
|
|
||||||
//
|
|
||||||
// Optional field for TCP tunnels.
|
|
||||||
// Ignored field for HTTP tunnels.
|
|
||||||
IP net.IP
|
|
||||||
|
|
||||||
// StateChanges listens on state transitions.
|
|
||||||
//
|
|
||||||
// If ClientIdent field is empty, the StateChanges will receive
|
|
||||||
// state transition events for the newly created client.
|
|
||||||
// Otherwise setting this field is a nop.
|
|
||||||
StateChanges chan<- *tunnel.ClientStateChange
|
|
||||||
}
|
|
||||||
|
|
||||||
type TunnelTest struct {
|
|
||||||
Server *tunnel.Server
|
|
||||||
ServerStateRecorder *StateRecorder
|
|
||||||
Clients map[string]*tunnel.Client
|
|
||||||
Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels)
|
|
||||||
Addrs []*net.TCPAddr
|
|
||||||
Tunnels map[string]*Tunnel
|
|
||||||
DebugNet bool // for debugging network communication
|
|
||||||
|
|
||||||
mu sync.Mutex // protects Listeners
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTunnelTest() (*TunnelTest, error) {
|
|
||||||
rec := NewStateRecorder()
|
|
||||||
|
|
||||||
cfg := &tunnel.ServerConfig{
|
|
||||||
StateChanges: rec.C(),
|
|
||||||
Debug: testing.Verbose(),
|
|
||||||
}
|
|
||||||
s, err := tunnel.NewServer(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugNet {
|
|
||||||
l = dbgListener{l}
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs, err := UsableAddrs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go (&http.Server{Handler: s}).Serve(l)
|
|
||||||
|
|
||||||
return &TunnelTest{
|
|
||||||
Server: s,
|
|
||||||
ServerStateRecorder: rec,
|
|
||||||
Clients: make(map[string]*tunnel.Client),
|
|
||||||
Listeners: map[string][2]net.Listener{"": {l, nil}},
|
|
||||||
Addrs: addrs,
|
|
||||||
Tunnels: make(map[string]*Tunnel),
|
|
||||||
DebugNet: debugNet,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serve creates new TunnelTest that serves the given tunnels.
|
|
||||||
//
|
|
||||||
// If tunnels is nil, DefaultTunnels() are used instead.
|
|
||||||
func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) {
|
|
||||||
tt, err := NewTunnelTest()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tt.Serve(tunnels); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) {
|
|
||||||
// Verify tunnel dependencies for TCP tunnels.
|
|
||||||
if t.Type == TypeTCP {
|
|
||||||
// If tunnel specified by t.Client was not already started,
|
|
||||||
// skip and move on.
|
|
||||||
if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the TCP tunnel whose remote endpoint listens on multiple
|
|
||||||
// interfaces is already served.
|
|
||||||
if t.RemoteAddrIdent != "" {
|
|
||||||
if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP {
|
|
||||||
return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
l, err := net.Listen("tcp", t.LocalAddr)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.DebugNet {
|
|
||||||
l = dbgListener{l}
|
|
||||||
}
|
|
||||||
|
|
||||||
localAddr := l.Addr().String()
|
|
||||||
httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr}
|
|
||||||
tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr}
|
|
||||||
|
|
||||||
cfg := &tunnel.ClientConfig{
|
|
||||||
Identifier: ident,
|
|
||||||
ServerAddr: tt.ServerAddr().String(),
|
|
||||||
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
|
|
||||||
HTTP: httpProxy.Proxy,
|
|
||||||
TCP: tcpProxy.Proxy,
|
|
||||||
}),
|
|
||||||
StateChanges: t.StateChanges,
|
|
||||||
Debug: testing.Verbose(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register tunnel:
|
|
||||||
//
|
|
||||||
// - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient])
|
|
||||||
// - listen on local address and start local server (tt.Listeners[ident][0])
|
|
||||||
// - register tunnel on tunnel.Server
|
|
||||||
//
|
|
||||||
switch t.Type {
|
|
||||||
case TypeHTTP:
|
|
||||||
// TODO(rjeczalik): refactor to separate method
|
|
||||||
|
|
||||||
h, ok := t.Handler.(http.Handler)
|
|
||||||
if !ok {
|
|
||||||
h, ok = t.Handler.(http.HandlerFunc)
|
|
||||||
if !ok {
|
|
||||||
fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request))
|
|
||||||
if !ok {
|
|
||||||
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
h = http.HandlerFunc(fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident)
|
|
||||||
|
|
||||||
go (&http.Server{Handler: h}).Serve(l)
|
|
||||||
|
|
||||||
tt.Server.AddHost(localAddr, ident)
|
|
||||||
|
|
||||||
tt.mu.Lock()
|
|
||||||
tt.Listeners[ident] = [2]net.Listener{l, nil}
|
|
||||||
tt.mu.Unlock()
|
|
||||||
|
|
||||||
if err := tt.addClient(ident, cfg); err != nil {
|
|
||||||
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident)
|
|
||||||
|
|
||||||
case TypeTCP:
|
|
||||||
// TODO(rjeczalik): refactor to separate method
|
|
||||||
|
|
||||||
h, ok := t.Handler.(func(net.Conn))
|
|
||||||
if !ok {
|
|
||||||
return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
logf("serving on local %s for TCP tunnel %q", l.Addr(), ident)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed accepting conn for %q tunnel: %s", ident, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go h(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var remote net.Listener
|
|
||||||
|
|
||||||
if t.RemoteAddrIdent != "" {
|
|
||||||
tt.mu.Lock()
|
|
||||||
remote = tt.Listeners[t.RemoteAddrIdent][1]
|
|
||||||
tt.mu.Unlock()
|
|
||||||
} else {
|
|
||||||
remote, err = net.Listen("tcp", t.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addrIdent holds identifier of client which is going to have registered
|
|
||||||
// tunnel via (*tunnel.Server).AddAddr
|
|
||||||
addrIdent := ident
|
|
||||||
if t.ClientIdent != "" {
|
|
||||||
tt.Clients[ident] = tt.Clients[t.ClientIdent]
|
|
||||||
addrIdent = t.ClientIdent
|
|
||||||
}
|
|
||||||
|
|
||||||
tt.Server.AddAddr(remote, t.IP, addrIdent)
|
|
||||||
|
|
||||||
tt.mu.Lock()
|
|
||||||
tt.Listeners[ident] = [2]net.Listener{l, remote}
|
|
||||||
tt.mu.Unlock()
|
|
||||||
|
|
||||||
if _, ok := tt.Clients[ident]; !ok {
|
|
||||||
if err := tt.addClient(ident, cfg); err != nil {
|
|
||||||
return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return false, fmt.Errorf("unknown %q tunnel type: %d", ident, t.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error {
|
|
||||||
if _, ok := tt.Clients[ident]; ok {
|
|
||||||
return fmt.Errorf("tunnel %q is already being served", ident)
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := tunnel.NewClient(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
tt.Server.OnConnect(ident, func() error {
|
|
||||||
close(done)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
go c.Start()
|
|
||||||
<-c.StartNotify()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(10 * time.Second):
|
|
||||||
return errors.New("timed out after 10s waiting on client to establish control conn")
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
|
|
||||||
tt.Clients[ident] = c
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error {
|
|
||||||
if len(tunnels) == 0 {
|
|
||||||
return errors.New("no tunnels to serve")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since one tunnels depends on others do 3 passes to start them
|
|
||||||
// all, each started tunnel is removed from the tunnels map.
|
|
||||||
// After 3 passes all of them must be started, otherwise the
|
|
||||||
// configuration is bad:
|
|
||||||
//
|
|
||||||
// - first pass starts HTTP tunnels as new client tunnels
|
|
||||||
// - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent)
|
|
||||||
// - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent)
|
|
||||||
//
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
if err := tt.popServedDeps(tunnels); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tunnels) != 0 {
|
|
||||||
unresolved := make([]string, len(tunnels))
|
|
||||||
for ident := range tunnels {
|
|
||||||
unresolved = append(unresolved, ident)
|
|
||||||
}
|
|
||||||
sort.Strings(unresolved)
|
|
||||||
|
|
||||||
return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error {
|
|
||||||
for ident, t := range tunnels {
|
|
||||||
ok, err := tt.serveSingle(ident, t)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
// Remove already started tunnels so they won't get started again.
|
|
||||||
delete(tunnels, ident)
|
|
||||||
tt.Tunnels[ident] = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) {
|
|
||||||
tt.mu.Lock()
|
|
||||||
defer tt.mu.Unlock()
|
|
||||||
|
|
||||||
for _, l := range tt.Listeners {
|
|
||||||
if l[1] == nil {
|
|
||||||
// this listener does not belong to a TCP tunnel
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, remotePort, err := parseHostPort(l[1].Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if port == remotePort {
|
|
||||||
return l[0].Addr().String(), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("no route for %d port", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) ServerAddr() net.Addr {
|
|
||||||
return tt.Listeners[""][0].Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Addr gives server endpoint of the TCP tunnel for the given ident.
|
|
||||||
//
|
|
||||||
// If the tunnel does not exist or is a HTTP one, TunnelAddr return nil.
|
|
||||||
func (tt *TunnelTest) Addr(ident string) net.Addr {
|
|
||||||
l, ok := tt.Listeners[ident]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return l[1].Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request creates a HTTP request to a server endpoint of the HTTP tunnel
|
|
||||||
// for the given ident.
|
|
||||||
//
|
|
||||||
// If the tunnel does not exist, Request returns nil.
|
|
||||||
func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request {
|
|
||||||
l, ok := tt.Listeners[ident]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var raw string
|
|
||||||
if query != nil {
|
|
||||||
raw = query.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Request{
|
|
||||||
Method: "GET",
|
|
||||||
URL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: tt.ServerAddr().String(),
|
|
||||||
Path: "/",
|
|
||||||
RawQuery: raw,
|
|
||||||
},
|
|
||||||
Proto: "HTTP/1.1",
|
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Host: l[0].Addr().String(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TunnelTest) Close() (err error) {
|
|
||||||
// Close tunnel.Clients.
|
|
||||||
clients := make(map[*tunnel.Client]struct{})
|
|
||||||
for _, c := range tt.Clients {
|
|
||||||
clients[c] = struct{}{}
|
|
||||||
}
|
|
||||||
for c := range clients {
|
|
||||||
err = nonil(err, c.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop all TCP/HTTP servers.
|
|
||||||
listeners := make(map[net.Listener]struct{})
|
|
||||||
for _, l := range tt.Listeners {
|
|
||||||
for _, l := range l {
|
|
||||||
if l != nil {
|
|
||||||
listeners[l] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for l := range listeners {
|
|
||||||
err = nonil(err, l.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,120 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
||||||
"github.com/cenkalti/backoff"
|
|
||||||
)
|
|
||||||
|
|
||||||
// async is a helper function to convert a blocking function to a function
|
|
||||||
// returning an error. Useful for plugging function closures into select and co
|
|
||||||
func async(fn func() error) <-chan error {
|
|
||||||
errChan := make(chan error, 0)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case errChan <- fn():
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
close(errChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return errChan
|
|
||||||
}
|
|
||||||
|
|
||||||
type expBackoff struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
bk *backoff.ExponentialBackOff
|
|
||||||
}
|
|
||||||
|
|
||||||
func newForeverBackoff() *expBackoff {
|
|
||||||
eb := &expBackoff{
|
|
||||||
bk: backoff.NewExponentialBackOff(),
|
|
||||||
}
|
|
||||||
eb.bk.MaxElapsedTime = 0 // never stops
|
|
||||||
return eb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (eb *expBackoff) NextBackOff() time.Duration {
|
|
||||||
eb.mu.Lock()
|
|
||||||
defer eb.mu.Unlock()
|
|
||||||
|
|
||||||
return eb.bk.NextBackOff()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (eb *expBackoff) Reset() {
|
|
||||||
eb.mu.Lock()
|
|
||||||
eb.bk.Reset()
|
|
||||||
eb.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
type callbacks struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
name string
|
|
||||||
funcs map[string]func() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCallbacks(name string) *callbacks {
|
|
||||||
return &callbacks{
|
|
||||||
name: name,
|
|
||||||
funcs: make(map[string]func() error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callbacks) add(identifier string, fn func() error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
c.funcs[identifier] = fn
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callbacks) pop(identifier string) (func() error, error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
fn, ok := c.funcs[identifier]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil // nop
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(c.funcs, identifier)
|
|
||||||
|
|
||||||
if fn == nil {
|
|
||||||
return nil, fmt.Errorf("nil callback set for %q client", identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callbacks) call(identifier string) error {
|
|
||||||
fn, err := c.pop(identifier)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if fn == nil {
|
|
||||||
return nil // nop
|
|
||||||
}
|
|
||||||
|
|
||||||
return fn()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns server control url as a string. Reads scheme and remote address from connection.
|
|
||||||
func controlURL(conn net.Conn) string {
|
|
||||||
return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scheme(conn net.Conn) (scheme string) {
|
|
||||||
switch conn.(type) {
|
|
||||||
case *tls.Conn:
|
|
||||||
scheme = "https"
|
|
||||||
default:
|
|
||||||
scheme = "http"
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,179 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/koding/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
type listener struct {
|
|
||||||
net.Listener
|
|
||||||
*vaddrOptions
|
|
||||||
|
|
||||||
done int32
|
|
||||||
|
|
||||||
// ips keeps track of registered clients for ip-based routing;
|
|
||||||
// when last client is deleted from the ip routing map, we stop
|
|
||||||
// listening on connections
|
|
||||||
ips map[string]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type vaddrOptions struct {
|
|
||||||
connCh chan<- net.Conn
|
|
||||||
log logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
type vaddrStorage struct {
|
|
||||||
*vaddrOptions
|
|
||||||
|
|
||||||
listeners map[net.Listener]*listener
|
|
||||||
ports map[int]string // port-based routing: maps port number to identifier
|
|
||||||
ips map[string]string // ip-based routing: maps ip address to identifier
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage {
|
|
||||||
return &vaddrStorage{
|
|
||||||
vaddrOptions: opts,
|
|
||||||
listeners: make(map[net.Listener]*listener),
|
|
||||||
ports: make(map[int]string),
|
|
||||||
ips: make(map[string]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *listener) serve() {
|
|
||||||
for {
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
l.log.Error("failue listening on %q: %s", l.Addr(), err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if atomic.LoadInt32(&l.done) != 0 {
|
|
||||||
l.log.Debug("stopped serving %q", l.Addr())
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.connCh <- conn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *listener) localAddr() string {
|
|
||||||
if addr, ok := l.Addr().(*net.TCPAddr); ok {
|
|
||||||
if addr.IP.Equal(net.IPv4zero) {
|
|
||||||
return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return l.Addr().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *listener) stop() {
|
|
||||||
if atomic.CompareAndSwapInt32(&l.done, 0, 1) {
|
|
||||||
// stop is called when no more connections should be accepted by
|
|
||||||
// the user-provided listener; as we can't simple close the listener
|
|
||||||
// to not break the guarantee given by the (*Server).DeleteAddr
|
|
||||||
// method, we make a dummy connection to break out of serve loop.
|
|
||||||
// It is safe to make a dummy connection, as either the following
|
|
||||||
// dial will time out when the listener is busy accepting connections,
|
|
||||||
// or will get closed immadiately after idle listeners accepts connection
|
|
||||||
// and returns from the serve loop.
|
|
||||||
conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout)
|
|
||||||
if err == nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) {
|
|
||||||
vaddr.mu.Lock()
|
|
||||||
defer vaddr.mu.Unlock()
|
|
||||||
|
|
||||||
lis, ok := vaddr.listeners[l]
|
|
||||||
if !ok {
|
|
||||||
lis = vaddr.newListener(l)
|
|
||||||
vaddr.listeners[l] = lis
|
|
||||||
go lis.serve()
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip != nil {
|
|
||||||
lis.ips[ip.String()] = struct{}{}
|
|
||||||
vaddr.ips[ip.String()] = ident
|
|
||||||
} else {
|
|
||||||
vaddr.ports[mustPort(l)] = ident
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) {
|
|
||||||
vaddr.mu.Lock()
|
|
||||||
defer vaddr.mu.Unlock()
|
|
||||||
|
|
||||||
lis, ok := vaddr.listeners[l]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var stop bool
|
|
||||||
|
|
||||||
if ip != nil {
|
|
||||||
delete(lis.ips, ip.String())
|
|
||||||
delete(vaddr.ips, ip.String())
|
|
||||||
|
|
||||||
stop = len(lis.ips) == 0
|
|
||||||
} else {
|
|
||||||
delete(vaddr.ports, mustPort(l))
|
|
||||||
|
|
||||||
stop = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only stop listening for connections when listener has clients
|
|
||||||
// registered to tunnel the connections to.
|
|
||||||
if stop {
|
|
||||||
lis.stop()
|
|
||||||
delete(vaddr.listeners, l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vaddr *vaddrStorage) newListener(l net.Listener) *listener {
|
|
||||||
return &listener{
|
|
||||||
Listener: l,
|
|
||||||
vaddrOptions: vaddr.vaddrOptions,
|
|
||||||
ips: make(map[string]struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) {
|
|
||||||
vaddr.mu.Lock()
|
|
||||||
defer vaddr.mu.Unlock()
|
|
||||||
|
|
||||||
ip, port, err := parseHostPort(conn.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err)
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// First lookup if there's a ip-based route, then try port-base one.
|
|
||||||
|
|
||||||
if ident, ok := vaddr.ips[ip]; ok {
|
|
||||||
return ident, true
|
|
||||||
}
|
|
||||||
|
|
||||||
ident, ok := vaddr.ports[port]
|
|
||||||
return ident, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustPort(l net.Listener) int {
|
|
||||||
_, port, err := parseHostPort(l.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
// This can happened when user passed custom type that
|
|
||||||
// implements net.Listener, which returns ill-formed
|
|
||||||
// net.Addr value.
|
|
||||||
panic("ill-formed net.Addr: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return port
|
|
||||||
}
|
|
|
@ -1,77 +0,0 @@
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type vhostStorage interface {
|
|
||||||
// AddHost adds the given host and identifier to the storage
|
|
||||||
AddHost(host, identifier string)
|
|
||||||
|
|
||||||
// DeleteHost deletes the given host
|
|
||||||
DeleteHost(host string)
|
|
||||||
|
|
||||||
// GetHost returns the host name for the given identifier
|
|
||||||
GetHost(identifier string) (string, bool)
|
|
||||||
|
|
||||||
// GetIdentifier returns the identifier for the given host
|
|
||||||
GetIdentifier(host string) (string, bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
type virtualHost struct {
|
|
||||||
identifier string
|
|
||||||
}
|
|
||||||
|
|
||||||
// virtualHosts is used for mapping host to users example: host
|
|
||||||
// "fs-1-fatih.kd.io" belongs to user "arslan"
|
|
||||||
type virtualHosts struct {
|
|
||||||
mapping map[string]*virtualHost
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// newVirtualHosts provides an in memory virtual host storage for mapping
|
|
||||||
// virtual hosts to identifiers.
|
|
||||||
func newVirtualHosts() *virtualHosts {
|
|
||||||
return &virtualHosts{
|
|
||||||
mapping: make(map[string]*virtualHost),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtualHosts) AddHost(host, identifier string) {
|
|
||||||
v.Lock()
|
|
||||||
v.mapping[host] = &virtualHost{identifier: identifier}
|
|
||||||
v.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtualHosts) DeleteHost(host string) {
|
|
||||||
v.Lock()
|
|
||||||
delete(v.mapping, host)
|
|
||||||
v.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIdentifier returns the identifier associated with the given host
|
|
||||||
func (v *virtualHosts) GetIdentifier(host string) (string, bool) {
|
|
||||||
v.Lock()
|
|
||||||
ht, ok := v.mapping[host]
|
|
||||||
v.Unlock()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
return ht.identifier, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHost returns the host associated with the given identifier
|
|
||||||
func (v *virtualHosts) GetHost(identifier string) (string, bool) {
|
|
||||||
v.Lock()
|
|
||||||
defer v.Unlock()
|
|
||||||
|
|
||||||
for hostname, hst := range v.mapping {
|
|
||||||
if hst.identifier == identifier {
|
|
||||||
return hostname, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", false
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
package tunnel_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) {
|
|
||||||
conn, err := websocketDial(tt, "http")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Dial()=%s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
want := &EchoMessage{
|
|
||||||
Value: fmt.Sprintf("message #%d", i),
|
|
||||||
Close: i == (n - 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := conn.WriteJSON(want)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
got := &EchoMessage{}
|
|
||||||
|
|
||||||
err = conn.ReadJSON(got)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("(test %s) %d: failed reading: %s", name, i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := fn(w, r); err != nil {
|
|
||||||
t.Errorf("handler func error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebsocket(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil))))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testWebsocket("handlerEchoWS", 100, t, tt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLatencyWebsocket(t *testing.T) {
|
|
||||||
tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep))))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testWebsocket("handlerLatencyEchoWS", 20, t, tt)
|
|
||||||
}
|
|
Loading…
Reference in New Issue