565 lines
15 KiB
Go
565 lines
15 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
"github.com/hashicorp/yamux"
|
|
"github.com/koding/logging"
|
|
)
|
|
|
|
//go:generate stringer -type ClientState
|
|
|
|
// ErrRedialAborted is the error attached to the ClientClosed state change
// that Start emits when the backoff policy returns a negative duration,
// i.e. when no more reconnection attempts will be made.
var ErrRedialAborted = errors.New("unable to restore the connection, aborting")
|
|
|
|
// ClientState represents the state of the client's connection to the tunnel
// server. It is stored as a uint32 so it can be read and updated with the
// sync/atomic helpers.
type ClientState uint32
|
|
|
|
// ClientState enumeration.
const (
	// ClientUnknown is the zero value: no transition has been recorded yet.
	ClientUnknown ClientState = iota
	// ClientStarted is set once, when Start begins.
	ClientStarted
	// ClientConnecting is set at the top of every connection attempt.
	ClientConnecting
	// ClientConnected is set when the client starts listening on the
	// control stream.
	ClientConnected
	// ClientDisconnected is set when the control connection is lost.
	ClientDisconnected
	// ClientClosed is set right before Start returns, either because the
	// client was closed or because the backoff policy gave up.
	ClientClosed // keep it always last
)
|
|
|
|
// ClientStateChange represents a single client state transition. Values of
// this type are delivered on ClientConfig.StateChanges.
type ClientStateChange struct {
	// Identifier is the ClientConfig.Identifier of the client that changed
	// state.
	Identifier string
	// Previous is the state the client was in before the transition.
	Previous ClientState
	// Current is the state the client is moving to.
	Current ClientState
	// Error is the error associated with the transition; nil if there is none.
	Error error
}
|
|
|
|
// Strings implements the fmt.Stringer interface.
|
|
func (cs *ClientStateChange) String() string {
|
|
if cs.Error != nil {
|
|
return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error)
|
|
}
|
|
return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current)
|
|
}
|
|
|
|
// Backoff defines the behavior of staggering reconnection retries.
type Backoff interface {
	// NextBackOff returns the duration to sleep before the next
	// reconnection attempt. If the returned value is negative, retrying is
	// aborted and the client is closed with ErrRedialAborted.
	NextBackOff() time.Duration

	// Reset signals that the client is starting or that a reconnection was
	// successful; the next call to NextBackOff should return the duration
	// desired for the first reconnection attempt.
	Reset()
}
|
|
|
|
// Client is responsible for creating a control connection to a tunnel server,
// creating new tunnels and proxying them to the tunnel server.
type Client struct {
	// session is the underlying yamux session. It is assigned on every
	// connection attempt that gets as far as creating a session.
	session *yamux.Session

	// config holds the ClientConfig
	config *ClientConfig

	// yamuxConfig is passed to every new yamux.Session.
	yamuxConfig *yamux.Config

	// proxy handles local server communication.
	proxy ProxyFunc

	// startNotify is a channel the user can obtain to be notified when the
	// client is connected to the server. The preferred way of doing this,
	// however, is StateChanges in ClientConfig, where the user can provide
	// their own channel.
	startNotify chan bool
	// closed is a flag set when the client is closed, either through Close()
	// or because the backoff policy aborted reconnecting.
	closed bool
	// closedMu guards both the closed flag and the startNotify channel. Since
	// the library owns the channel, it is cleared when trying to reconnect.
	closedMu sync.RWMutex

	// reqWg counts in-flight HTTP proxy requests so that shutdown can wait
	// for them to finish.
	reqWg sync.WaitGroup
	// ctrlWg counts running control listeners; connect waits on it before
	// replacing the session.
	ctrlWg sync.WaitGroup

	// state is the current connection state. It is read and updated with
	// sync/atomic operations.
	state ClientState

	// redialBackoff is used to reconnect in exponential backoff intervals
	redialBackoff Backoff

	// log is the logger used by the client.
	log logging.Logger
}
|
|
|
|
// ClientConfig defines the configuration for the Client.
type ClientConfig struct {
	// Identifier is the secret token that needs to be passed to the server.
	// Required if FetchIdentifier is not set.
	Identifier string

	// FetchIdentifier can be used to fetch the identifier. Required if
	// Identifier is not set. When set, it takes precedence over Identifier
	// and is called on every connection attempt.
	FetchIdentifier func() (string, error)

	// ServerAddr defines the TCP address of the tunnel server to be connected.
	// Required if FetchServerAddr is not set.
	ServerAddr string

	// FetchServerAddr can be used to fetch the tunnel server address.
	// Required if ServerAddr is not set. When set, it takes precedence over
	// ServerAddr and is called on every connection attempt.
	FetchServerAddr func() (string, error)

	// Dial provides a custom transport layer for client-server communication.
	// It is always called with network "tcp".
	//
	// If nil, the default implementation is net.Dial("tcp", address).
	//
	// It can be used for connection monitoring, setting different timeouts or
	// securing the connection.
	Dial func(network, address string) (net.Conn, error)

	// Proxy defines custom proxying logic. This is an optional extension
	// point where you can provide your local server selection or
	// communication rules. It must not be combined with LocalAddr or
	// FetchLocalAddr.
	Proxy ProxyFunc

	// StateChanges receives state transition details each time the client
	// connection state changes. The channel is expected to be sufficiently
	// buffered to keep up with the event pace: sends never block, and a
	// transition that cannot be delivered immediately is dropped with a
	// warning in the log.
	//
	// If nil, no information about state transitions is dispatched
	// by the library.
	StateChanges chan<- *ClientStateChange

	// Backoff is used to control the behavior of the staggered reconnection
	// loop.
	//
	// If nil, the default backoff policy is used, which makes the client
	// never give up on reconnection.
	//
	// If a custom backoff is used, the client will emit ErrRedialAborted
	// with the ClientClosed event when no more reconnection attempts should
	// be made.
	Backoff Backoff

	// YamuxConfig defines the config which is passed to every new
	// yamux.Session. If nil, yamux.DefaultConfig() is used.
	YamuxConfig *yamux.Config

	// Log defines the logger. If nil, a default logging.Logger is used.
	Log logging.Logger

	// Debug is passed to the constructor of the default logger; presumably
	// it enables debug-level output for the client (confirm against
	// newLogger). It has no effect on a logger supplied through Log.
	Debug bool

	// DEPRECATED:

	// LocalAddr is the local address HTTP and WebSocket traffic is proxied to.
	//
	// Deprecated: set Proxy instead (NewClient translates this field into an
	// HTTPProxy with the same LocalAddr).
	LocalAddr string

	// FetchLocalAddr resolves the local address for TCP traffic by port.
	//
	// Deprecated: set Proxy instead (NewClient translates this field into a
	// TCPProxy with the same FetchLocalAddr).
	FetchLocalAddr func(port int) (string, error)
}
|
|
|
|
// verify is used to verify the ClientConfig
|
|
func (c *ClientConfig) verify() error {
|
|
if c.ServerAddr == "" && c.FetchServerAddr == nil {
|
|
return errors.New("neither ServerAddr nor FetchServerAddr is set")
|
|
}
|
|
|
|
if c.Identifier == "" && c.FetchIdentifier == nil {
|
|
return errors.New("neither Identifier nor FetchIdentifier is set")
|
|
}
|
|
|
|
if c.YamuxConfig != nil {
|
|
if err := yamux.VerifyConfig(c.YamuxConfig); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) {
|
|
return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewClient creates a new tunnel that is established between the serverAddr
|
|
// and localAddr. It exits if it can't create a new control connection to the
|
|
// server. If localAddr is empty client will always try to proxy to a local
|
|
// port.
|
|
func NewClient(cfg *ClientConfig) (*Client, error) {
|
|
if err := cfg.verify(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
yamuxConfig := yamux.DefaultConfig()
|
|
if cfg.YamuxConfig != nil {
|
|
yamuxConfig = cfg.YamuxConfig
|
|
}
|
|
|
|
var proxy = DefaultProxy
|
|
if cfg.Proxy != nil {
|
|
proxy = cfg.Proxy
|
|
}
|
|
// DEPRECATED API SUPPORT
|
|
if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil {
|
|
var f ProxyFuncs
|
|
if cfg.LocalAddr != "" {
|
|
f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
|
f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy
|
|
}
|
|
if cfg.FetchLocalAddr != nil {
|
|
f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy
|
|
}
|
|
proxy = Proxy(f)
|
|
}
|
|
|
|
var bo Backoff = newForeverBackoff()
|
|
if cfg.Backoff != nil {
|
|
bo = cfg.Backoff
|
|
}
|
|
|
|
log := newLogger("tunnel-client", cfg.Debug)
|
|
if cfg.Log != nil {
|
|
log = cfg.Log
|
|
}
|
|
|
|
client := &Client{
|
|
config: cfg,
|
|
yamuxConfig: yamuxConfig,
|
|
proxy: proxy,
|
|
startNotify: make(chan bool, 1),
|
|
redialBackoff: bo,
|
|
log: log,
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// Start starts the client and connects to the server with the identifier.
|
|
// client.FetchIdentifier() will be used if it's not nil. It's supports
|
|
// reconnecting with exponential backoff intervals when the connection to the
|
|
// server disconnects. Call client.Close() to shutdown the client completely. A
|
|
// successful connection will cause StartNotify() to receive a value.
|
|
func (c *Client) Start() {
|
|
fetchIdent := func() (string, error) {
|
|
if c.config.FetchIdentifier != nil {
|
|
return c.config.FetchIdentifier()
|
|
}
|
|
|
|
return c.config.Identifier, nil
|
|
}
|
|
|
|
fetchServerAddr := func() (string, error) {
|
|
if c.config.FetchServerAddr != nil {
|
|
return c.config.FetchServerAddr()
|
|
}
|
|
|
|
return c.config.ServerAddr, nil
|
|
}
|
|
|
|
c.changeState(ClientStarted, nil)
|
|
|
|
c.redialBackoff.Reset()
|
|
var lastErr error
|
|
for {
|
|
prev := c.changeState(ClientConnecting, lastErr)
|
|
|
|
if c.isRetry(prev) {
|
|
dur := c.redialBackoff.NextBackOff()
|
|
if dur < 0 {
|
|
c.setClosed(true)
|
|
c.changeState(ClientClosed, ErrRedialAborted)
|
|
return
|
|
}
|
|
|
|
time.Sleep(dur)
|
|
|
|
// exit if closed
|
|
if c.isClosed() {
|
|
c.changeState(ClientClosed, lastErr)
|
|
return
|
|
}
|
|
}
|
|
|
|
identifier, err := fetchIdent()
|
|
if err != nil {
|
|
lastErr = err
|
|
c.log.Critical("client fetch identifier error: %s", err)
|
|
continue
|
|
}
|
|
|
|
serverAddr, err := fetchServerAddr()
|
|
if err != nil {
|
|
lastErr = err
|
|
c.log.Critical("client fetch server address error: %s", err)
|
|
continue
|
|
}
|
|
|
|
c.setClosed(false)
|
|
|
|
if err := c.connect(identifier, serverAddr); err != nil {
|
|
lastErr = err
|
|
c.log.Debug("client connect error: %s", err)
|
|
}
|
|
|
|
// exit if closed
|
|
if c.isClosed() {
|
|
c.changeState(ClientClosed, lastErr)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close closes the client and shutdowns the connection to the tunnel server
|
|
func (c *Client) Close() error {
|
|
defer c.setClosed(true)
|
|
|
|
if c.session == nil {
|
|
return errors.New("session is not initialized")
|
|
}
|
|
|
|
// wait until all connections are finished
|
|
waitCh := make(chan struct{})
|
|
go func() {
|
|
if err := c.session.GoAway(); err != nil {
|
|
c.log.Debug("Session go away failed: %s", err)
|
|
}
|
|
|
|
c.reqWg.Wait()
|
|
close(waitCh)
|
|
}()
|
|
select {
|
|
case <-waitCh:
|
|
// ok
|
|
case <-time.After(time.Second * 10):
|
|
c.log.Info("Timeout waiting for connections to finish")
|
|
}
|
|
|
|
if err := c.session.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isClosed securely checks if client is marked as closed.
|
|
func (c *Client) isClosed() bool {
|
|
c.closedMu.RLock()
|
|
defer c.closedMu.RUnlock()
|
|
return c.closed
|
|
}
|
|
|
|
// setClosed securely marks client as closed (or not closed). If not closed
|
|
// also empty the value inside the startNotify channel by retrieving it (if any),
|
|
// so it doesn't block during connect, when the client was closed and started again,
|
|
// and startNotify was never listened to.
|
|
func (c *Client) setClosed(closed bool) {
|
|
c.closedMu.Lock()
|
|
defer c.closedMu.Unlock()
|
|
c.closed = closed
|
|
|
|
if !closed {
|
|
// clear channel
|
|
select {
|
|
case <-c.startNotify:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
// startNotifyIfNeeded sends ok to startNotify channel if it's listened to.
|
|
// This function is called by connect when connection was successful.
|
|
func (c *Client) startNotifyIfNeeded() {
|
|
c.closedMu.RLock()
|
|
if !c.closed {
|
|
c.log.Debug("sending ok to startNotify chan")
|
|
select {
|
|
case c.startNotify <- true:
|
|
default:
|
|
// reaching here means the client never read the signal via
|
|
// StartNotify(). This is OK, we shouldn't except it the consumer
|
|
// to read from this channel. It's optional, so we just drop the
|
|
// signal.
|
|
c.log.Debug("startNotify message was dropped")
|
|
}
|
|
}
|
|
c.closedMu.RUnlock()
|
|
}
|
|
|
|
// StartNotify returns a channel that receives a single value when the client
|
|
// established a successful connection to the server.
|
|
func (c *Client) StartNotify() <-chan bool {
|
|
return c.startNotify
|
|
}
|
|
|
|
func (c *Client) changeState(state ClientState, err error) (prev ClientState) {
|
|
prev = ClientState(atomic.LoadUint32((*uint32)(&c.state)))
|
|
|
|
if c.config.StateChanges != nil {
|
|
change := &ClientStateChange{
|
|
Identifier: c.config.Identifier,
|
|
Previous: ClientState(prev),
|
|
Current: state,
|
|
Error: err,
|
|
}
|
|
|
|
select {
|
|
case c.config.StateChanges <- change:
|
|
default:
|
|
c.log.Warning("Dropping state change due to slow reader: %s", change)
|
|
}
|
|
}
|
|
|
|
atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state))
|
|
|
|
return prev
|
|
}
|
|
|
|
func (c *Client) isRetry(state ClientState) bool {
|
|
return state != ClientStarted && state != ClientClosed
|
|
}
|
|
|
|
func (c *Client) connect(identifier, serverAddr string) error {
|
|
c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier)
|
|
|
|
conn, err := c.dial(serverAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
remoteURL := controlURL(conn)
|
|
c.log.Debug("CONNECT to %q", remoteURL)
|
|
req, err := http.NewRequest("CONNECT", remoteURL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
|
|
}
|
|
|
|
req.Header.Set(proto.ClientIdentifierHeader, identifier)
|
|
|
|
c.log.Debug("Writing request to TCP: %+v", req)
|
|
|
|
if err := req.Write(conn); err != nil {
|
|
return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err)
|
|
}
|
|
|
|
c.log.Debug("Reading response from TCP")
|
|
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
|
if err != nil {
|
|
return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected {
|
|
out, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err)
|
|
}
|
|
|
|
return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out))
|
|
}
|
|
|
|
c.ctrlWg.Wait() // wait until previous listenControl observes disconnection
|
|
|
|
c.session, err = yamux.Client(conn, c.yamuxConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("session initialization failed: %s", err)
|
|
}
|
|
|
|
var stream net.Conn
|
|
openStream := func() error {
|
|
// this is blocking until client opens a session to us
|
|
stream, err = c.session.Open()
|
|
return err
|
|
}
|
|
|
|
// if we don't receive anything from the server, we'll timeout
|
|
select {
|
|
case err := <-async(openStream):
|
|
if err != nil {
|
|
return fmt.Errorf("waiting for session to open failed: %s", err)
|
|
}
|
|
case <-time.After(time.Second * 10):
|
|
if stream != nil {
|
|
stream.Close()
|
|
}
|
|
return errors.New("timeout opening session")
|
|
}
|
|
|
|
if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil {
|
|
return fmt.Errorf("writing handshake request failed: %s", err)
|
|
}
|
|
|
|
buf := make([]byte, len(proto.HandshakeResponse))
|
|
if _, err := stream.Read(buf); err != nil {
|
|
return fmt.Errorf("reading handshake response failed: %s", err)
|
|
}
|
|
|
|
if string(buf) != proto.HandshakeResponse {
|
|
return fmt.Errorf("invalid handshake response, received: %s", string(buf))
|
|
}
|
|
|
|
ct := newControl(stream)
|
|
c.log.Debug("client has started successfully")
|
|
c.redialBackoff.Reset() // we successfully connected, so we can reset the backoff
|
|
|
|
c.startNotifyIfNeeded()
|
|
|
|
return c.listenControl(ct)
|
|
}
|
|
|
|
func (c *Client) dial(serverAddr string) (net.Conn, error) {
|
|
if c.config.Dial != nil {
|
|
return c.config.Dial("tcp", serverAddr)
|
|
}
|
|
|
|
return net.Dial("tcp", serverAddr)
|
|
}
|
|
|
|
func (c *Client) listenControl(ct *control) error {
|
|
c.ctrlWg.Add(1)
|
|
defer c.ctrlWg.Done()
|
|
|
|
c.changeState(ClientConnected, nil)
|
|
|
|
for {
|
|
var msg proto.ControlMessage
|
|
if err := ct.dec.Decode(&msg); err != nil {
|
|
c.reqWg.Wait() // wait until all requests are finished
|
|
c.session.GoAway()
|
|
c.session.Close()
|
|
c.changeState(ClientDisconnected, err)
|
|
|
|
return fmt.Errorf("failure decoding control message: %s", err)
|
|
}
|
|
|
|
c.log.Debug("Received control msg %+v", msg)
|
|
c.log.Debug("Opening a new stream from server session")
|
|
|
|
remote, err := c.session.Open()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
isHTTP := msg.Protocol == proto.HTTP
|
|
if isHTTP {
|
|
c.reqWg.Add(1)
|
|
}
|
|
go func() {
|
|
c.proxy(remote, &msg)
|
|
if isHTTP {
|
|
c.reqWg.Done()
|
|
}
|
|
remote.Close()
|
|
}()
|
|
}
|
|
}
|