route/vendor/github.com/mmatczuk/go-http-tunnel/client.go

315 lines
6.2 KiB
Go

// Copyright (C) 2017 Michał Matczuk
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tunnel
import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"github.com/mmatczuk/go-http-tunnel/log"
"github.com/mmatczuk/go-http-tunnel/proto"
)
// DefaultTimeout is the general purpose timeout. Within this file it bounds
// the dial performed when ClientConfig.DialTLS is not set.
var DefaultTimeout = 10 * time.Second
// ClientConfig holds everything a Client needs to reach the tunnel server and
// to serve the tunnels it asks for.
type ClientConfig struct {
	// ServerAddr is the TCP address of the tunnel server. Required.
	ServerAddr string

	// TLSClientConfig is the TLS configuration handed to the dial function.
	// Required.
	TLSClientConfig *tls.Config

	// DialTLS optionally overrides how the TLS connection to the server is
	// created. When nil, tls.DialWithDialer is used with a dialer whose
	// timeout is DefaultTimeout.
	DialTLS func(network, addr string, config *tls.Config) (net.Conn, error)

	// Backoff is the retry policy applied when dialing the server fails.
	// When nil a failed dial is not retried.
	Backoff Backoff

	// Tunnels lists the tunnels the client asks the server to open. It is
	// sent to the server, JSON encoded, during the handshake. Required and
	// must not be empty.
	Tunnels map[string]*proto.Tunnel

	// Proxy moves data between the server and local services. Required.
	Proxy ProxyFunc

	// Logger is optional. When nil logging is disabled.
	Logger log.Logger
}
// Client dials the tunnel server, answers its handshake and control messages,
// and delegates the actual data transfer between the server and local
// services to the configured ProxyFunc.
type Client struct {
	// config is the configuration the client was created with.
	config *ClientConfig

	// conn is the current server connection, nil while disconnected.
	conn net.Conn
	// connMu guards conn, serverErr and lastDisconnect.
	connMu sync.Mutex

	// httpServer serves the HTTP/2 requests issued by the tunnel server over
	// conn.
	httpServer *http2.Server

	// serverErr is the error reported by the server during handshake, if
	// any; Start returns it after the connection ends.
	serverErr error
	// lastDisconnect is when the previous connection ended; Start uses it to
	// detect connections that are cut in quick succession.
	lastDisconnect time.Time

	// logger is never nil; NewClient substitutes a no-op logger when none is
	// configured.
	logger log.Logger
}
// NewClient builds a Client from config without connecting it; call Start on
// the result to connect to the server.
//
// It panics when config lacks ServerAddr, TLSClientConfig, Tunnels or Proxy.
// A nil config.Logger is replaced by a no-op logger.
func NewClient(config *ClientConfig) *Client {
	// Validation order matters only for which panic message is reported
	// when several fields are missing.
	switch {
	case config.ServerAddr == "":
		panic("missing ServerAddr")
	case config.TLSClientConfig == nil:
		panic("missing TLSClientConfig")
	case len(config.Tunnels) == 0: // len of a nil map is 0
		panic("missing Tunnels")
	case config.Proxy == nil:
		panic("missing Proxy")
	}

	client := &Client{
		config:     config,
		httpServer: &http2.Server{},
		logger:     config.Logger,
	}
	if client.logger == nil {
		client.logger = log.NewNopLogger()
	}
	return client
}
// Start connects to the server and serves its HTTP/2 requests on that
// connection, reconnecting each time the connection ends.
//
// It returns an error when connecting fails (after the backoff policy, if any,
// is exhausted), when the server reported a handshake error on the connection
// that just ended, or when two consecutive disconnects happen less than five
// seconds apart.
func (c *Client) Start() error {
	c.logger.Log("level", 1, "action", "start")

	handler := http.HandlerFunc(c.serveHTTP)
	for {
		conn, err := c.connect()
		if err != nil {
			return err
		}

		// Once ServeConn returns the connection is treated as gone.
		c.httpServer.ServeConn(conn, &http2.ServeConnOpts{Handler: handler})
		c.logger.Log("level", 1, "action", "disconnected")

		// Reset the per-connection state under the lock and work out
		// whether this disconnect should end the reconnect loop.
		stopErr := func() error {
			c.connMu.Lock()
			defer c.connMu.Unlock()

			now := time.Now()
			previous := c.lastDisconnect
			result := c.serverErr

			c.conn, c.serverErr, c.lastDisconnect = nil, nil, now

			// detect disconnect hiccup
			if result == nil && now.Sub(previous) < 5*time.Second {
				result = fmt.Errorf("connection is being cut")
			}
			return result
		}()
		if stopErr != nil {
			return stopErr
		}
	}
}
// connect dials the server and records the resulting connection on the
// client. It holds connMu for the whole dial, including any backoff sleeps,
// and fails if the client already has a connection.
func (c *Client) connect() (net.Conn, error) {
	c.connMu.Lock()
	defer c.connMu.Unlock()

	if c.conn != nil {
		return nil, fmt.Errorf("already connected")
	}

	established, dialErr := c.dial()
	if dialErr == nil {
		c.conn = established
		return established, nil
	}
	return nil, fmt.Errorf("failed to connect to server: %s", dialErr)
}
// dial opens a TLS connection to the configured server address.
//
// It uses config.DialTLS when set, otherwise tls.DialWithDialer with a
// DefaultTimeout dial timeout. Without a Backoff policy a single attempt is
// made. With one, failed attempts are retried after sleeping for the duration
// returned by NextBackOff until that duration is negative, at which point the
// last dial error is returned wrapped in a "backoff limit exceeded" error.
//
// On failure the returned connection is always a literal nil.
func (c *Client) dial() (net.Conn, error) {
	var (
		network   = "tcp"
		addr      = c.config.ServerAddr
		tlsConfig = c.config.TLSClientConfig
	)

	doDial := func() (net.Conn, error) {
		c.logger.Log(
			"level", 1,
			"action", "dial",
			"network", network,
			"addr", addr,
		)

		var (
			conn net.Conn
			err  error
		)
		if c.config.DialTLS != nil {
			conn, err = c.config.DialTLS(network, addr, tlsConfig)
		} else {
			// Keep the concrete *tls.Conn out of the net.Conn interface
			// until we know it is valid: storing a nil *tls.Conn would
			// produce a non-nil interface value.
			var tlsConn *tls.Conn
			tlsConn, err = tls.DialWithDialer(
				&net.Dialer{Timeout: DefaultTimeout},
				network, addr, tlsConfig,
			)
			if err == nil {
				conn = tlsConn
			}
		}

		if err != nil {
			c.logger.Log(
				"level", 0,
				"msg", "dial failed",
				"network", network,
				"addr", addr,
				"err", err,
			)
			return nil, err
		}
		return conn, nil
	}

	b := c.config.Backoff
	if b == nil {
		return doDial()
	}

	for {
		conn, err := doDial()
		// success
		if err == nil {
			b.Reset()
			return conn, nil
		}

		// failure: a negative duration means the policy gave up
		d := b.NextBackOff()
		if d < 0 {
			return nil, fmt.Errorf("backoff limit exceeded: %s", err)
		}

		// backoff
		c.logger.Log(
			"level", 1,
			"action", "backoff",
			"sleep", d,
		)
		time.Sleep(d)
	}
}
// serveHTTP is the handler Start installs for requests arriving from the
// tunnel server over the established connection.
//
// CONNECT requests are handshake traffic: when the proto.HeaderError header
// is set the server is reporting a handshake error, otherwise the client
// answers with its tunnel list. Every other request must carry a control
// message in its headers; proxy actions are handed to config.Proxy, anything
// else is rejected with 400 Bad Request.
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodConnect {
		if r.Header.Get(proto.HeaderError) != "" {
			c.handleHandshakeError(w, r)
		} else {
			c.handleHandshake(w, r)
		}
		return
	}

	msg, err := proto.ReadControlMessage(r.Header)
	if err != nil {
		c.logger.Log(
			"level", 1,
			"err", err,
		)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	c.logger.Log(
		"level", 2,
		"action", "handle",
		"ctrlMsg", msg,
	)
	switch msg.Action {
	case proto.ActionProxy:
		c.config.Proxy(w, r.Body, msg)
	default:
		c.logger.Log(
			"level", 0,
			"msg", "unknown action",
			"ctrlMsg", msg,
		)
		// err is nil on this path (the control message parsed fine), so it
		// must not be dereferenced; report the problem explicitly.
		http.Error(w, "unknown action", http.StatusBadRequest)
	}
	c.logger.Log(
		"level", 2,
		"action", "done",
		"ctrlMsg", msg,
	)
}
// handleHandshakeError records the handshake error reported by the server in
// the proto.HeaderError request header. The error is stored in serverErr so
// that Start returns it once the connection ends. No response body or status
// is written explicitly.
func (c *Client) handleHandshakeError(w http.ResponseWriter, r *http.Request) {
	// The header value comes from the remote side; pass it as an argument
	// rather than as the format string so a '%' in it is kept verbatim.
	err := fmt.Errorf("%s", r.Header.Get(proto.HeaderError))

	c.logger.Log(
		"level", 1,
		"action", "handshake error",
		"addr", r.RemoteAddr,
		"err", err,
	)

	c.connMu.Lock()
	c.serverErr = fmt.Errorf("server error: %s", err)
	c.connMu.Unlock()
}
// handleHandshake answers the server's handshake request with the JSON
// encoding of config.Tunnels and status 200.
//
// The tunnel list is encoded before any status is written, so an encoding
// failure is reported as 500 Internal Server Error instead of a 200 with an
// empty body. Failures to write the response are logged.
func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
	c.logger.Log(
		"level", 1,
		"action", "handshake",
		"addr", r.RemoteAddr,
	)

	b, err := json.Marshal(c.config.Tunnels)
	if err != nil {
		c.logger.Log(
			"level", 0,
			"msg", "handshake failed",
			"err", err,
		)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(b); err != nil {
		c.logger.Log(
			"level", 0,
			"msg", "handshake failed",
			"err", err,
		)
	}
}
// Stop closes the current server connection, if there is one, and clears it
// from the client. The error returned by closing the connection is discarded.
func (c *Client) Stop() {
	c.connMu.Lock()
	defer c.connMu.Unlock()

	c.logger.Log("level", 1, "action", "stop")

	if c.conn == nil {
		return
	}
	c.conn.Close()
	c.conn = nil
}