315 lines
6.2 KiB
Go
315 lines
6.2 KiB
Go
// Copyright (C) 2017 Michał Matczuk
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package tunnel
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/mmatczuk/go-http-tunnel/log"
|
|
"github.com/mmatczuk/go-http-tunnel/proto"
|
|
)
|
|
|
|
var (
|
|
// DefaultTimeout specifies general purpose timeout.
|
|
DefaultTimeout = 10 * time.Second
|
|
)
|
|
|
|
// ClientConfig is configuration of the Client.
|
|
type ClientConfig struct {
|
|
// ServerAddr specifies TCP address of the tunnel server.
|
|
ServerAddr string
|
|
// TLSClientConfig specifies the tls configuration to use with
|
|
// tls.Client.
|
|
TLSClientConfig *tls.Config
|
|
// DialTLS specifies an optional dial function that creates a tls
|
|
// connection to the server. If DialTLS is nil, tls.Dial is used.
|
|
DialTLS func(network, addr string, config *tls.Config) (net.Conn, error)
|
|
// Backoff specifies backoff policy on server connection retry. If nil
|
|
// when dial fails it will not be retried.
|
|
Backoff Backoff
|
|
// Tunnels specifies the tunnels client requests to be opened on server.
|
|
Tunnels map[string]*proto.Tunnel
|
|
// Proxy is ProxyFunc responsible for transferring data between server
|
|
// and local services.
|
|
Proxy ProxyFunc
|
|
// Logger is optional logger. If nil logging is disabled.
|
|
Logger log.Logger
|
|
}
|
|
|
|
// Client is responsible for creating connection to the server, handling control
|
|
// messages. It uses ProxyFunc for transferring data between server and local
|
|
// services.
|
|
type Client struct {
|
|
config *ClientConfig
|
|
conn net.Conn
|
|
connMu sync.Mutex
|
|
httpServer *http2.Server
|
|
serverErr error
|
|
lastDisconnect time.Time
|
|
logger log.Logger
|
|
}
|
|
|
|
// NewClient creates a new unconnected Client based on configuration. Caller
|
|
// must invoke Start() on returned instance in order to connect server.
|
|
func NewClient(config *ClientConfig) *Client {
|
|
if config.ServerAddr == "" {
|
|
panic("missing ServerAddr")
|
|
}
|
|
if config.TLSClientConfig == nil {
|
|
panic("missing TLSClientConfig")
|
|
}
|
|
if config.Tunnels == nil || len(config.Tunnels) == 0 {
|
|
panic("missing Tunnels")
|
|
}
|
|
if config.Proxy == nil {
|
|
panic("missing Proxy")
|
|
}
|
|
|
|
logger := config.Logger
|
|
if logger == nil {
|
|
logger = log.NewNopLogger()
|
|
}
|
|
|
|
c := &Client{
|
|
config: config,
|
|
httpServer: &http2.Server{},
|
|
logger: logger,
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// Start connects client to the server, it returns error if there is a
// connection error, or server cannot open requested tunnels. On connection
// error a backoff policy is used to reestablish the connection. When connected
// HTTP/2 server is started to handle ControlMessages.
//
// Start blocks until the client stops. It returns nil when the connection was
// closed by Stop, the server-reported error if the handshake failed, or an
// error if the connection drops again less than 5 seconds after the previous
// disconnect.
func (c *Client) Start() error {
	c.logger.Log(
		"level", 1,
		"action", "start",
	)

	for {
		conn, err := c.connect()
		if err != nil {
			return err
		}

		// ServeConn blocks until conn is closed or fails.
		c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
			Handler: http.HandlerFunc(c.serveHTTP),
		})

		c.logger.Log(
			"level", 1,
			"action", "disconnected",
		)

		c.connMu.Lock()
		// Only Stop (and this loop, below) clears c.conn. If it is already
		// nil here, the disconnect was requested via Stop and we must not
		// reconnect.
		stopped := c.conn == nil
		now := time.Now()
		err = c.serverErr

		// detect disconnect hiccup
		if err == nil && !stopped && now.Sub(c.lastDisconnect) < 5*time.Second {
			err = fmt.Errorf("connection is being cut")
		}

		c.conn = nil
		c.serverErr = nil
		c.lastDisconnect = now
		c.connMu.Unlock()

		if err != nil {
			return err
		}
		if stopped {
			return nil
		}
	}
}
|
|
|
|
// connect dials the server and records the new connection as c.conn. It
// refuses to dial when a connection is already recorded. connMu is held for
// the whole dial, including any backoff sleeps performed by dial.
func (c *Client) connect() (net.Conn, error) {
	c.connMu.Lock()
	defer c.connMu.Unlock()

	if c.conn != nil {
		return nil, fmt.Errorf("already connected")
	}

	established, dialErr := c.dial()
	if dialErr != nil {
		return nil, fmt.Errorf("failed to connect to server: %s", dialErr)
	}

	c.conn = established
	return established, nil
}
|
|
|
|
// dial opens a TLS connection to c.config.ServerAddr over TCP. It uses
// c.config.DialTLS when set, otherwise tls.DialWithDialer with a
// DefaultTimeout connect timeout.
//
// Without a Backoff policy a single attempt is made. With one, failed attempts
// are retried after b.NextBackOff() until it returns a negative duration, and
// b.Reset() is called on success. NOTE: when called from connect, these sleeps
// happen while connMu is held.
func (c *Client) dial() (net.Conn, error) {
	var (
		network   = "tcp"
		addr      = c.config.ServerAddr
		tlsConfig = c.config.TLSClientConfig
	)

	// doDial performs one logged dial attempt.
	doDial := func() (net.Conn, error) {
		c.logger.Log(
			"level", 1,
			"action", "dial",
			"network", network,
			"addr", addr,
		)

		var (
			conn net.Conn
			err  error
		)
		if c.config.DialTLS != nil {
			conn, err = c.config.DialTLS(network, addr, tlsConfig)
		} else {
			conn, err = tls.DialWithDialer(
				&net.Dialer{Timeout: DefaultTimeout},
				network, addr, tlsConfig,
			)
		}

		if err != nil {
			c.logger.Log(
				"level", 0,
				"msg", "dial failed",
				"network", network,
				"addr", addr,
				"err", err,
			)
		}

		return conn, err
	}

	b := c.config.Backoff
	if b == nil {
		return doDial()
	}

	for {
		conn, err := doDial()

		// success
		if err == nil {
			b.Reset()
			return conn, nil
		}

		// failure: a negative delay means the policy gave up
		d := b.NextBackOff()
		if d < 0 {
			return nil, fmt.Errorf("backoff limit exceeded: %s", err)
		}

		// backoff
		c.logger.Log(
			"level", 1,
			"action", "backoff",
			"sleep", d,
		)
		time.Sleep(d)
	}
}
|
|
|
|
// serveHTTP handles every request the server sends over the tunnel
// connection. CONNECT requests are handshakes: they are routed to
// handleHandshakeError when they carry proto.HeaderError, otherwise to
// handleHandshake. All other requests must carry a control message in their
// headers; ActionProxy messages are passed to the configured Proxy, and any
// other action is answered with 400 Bad Request.
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodConnect {
		if r.Header.Get(proto.HeaderError) != "" {
			c.handleHandshakeError(w, r)
		} else {
			c.handleHandshake(w, r)
		}
		return
	}

	msg, err := proto.ReadControlMessage(r.Header)
	if err != nil {
		c.logger.Log(
			"level", 1,
			"err", err,
		)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	c.logger.Log(
		"level", 2,
		"action", "handle",
		"ctrlMsg", msg,
	)
	switch msg.Action {
	case proto.ActionProxy:
		c.config.Proxy(w, r.Body, msg)
	default:
		c.logger.Log(
			"level", 0,
			"msg", "unknown action",
			"ctrlMsg", msg,
		)
		// err is nil on this path, so the message must not come from it
		// (err.Error() would panic with a nil dereference).
		http.Error(w, fmt.Sprintf("unknown action: %v", msg.Action), http.StatusBadRequest)
	}
	c.logger.Log(
		"level", 2,
		"action", "done",
		"ctrlMsg", msg,
	)
}
|
|
|
|
// handleHandshakeError records the error the server reported in the
// proto.HeaderError header of a handshake request, so that Start can return
// it after the connection ends.
func (c *Client) handleHandshakeError(w http.ResponseWriter, r *http.Request) {
	// The header value is peer-supplied text, not a format string: passing it
	// to fmt.Errorf as the format would mangle any '%' sequences it contains.
	err := fmt.Errorf("%s", r.Header.Get(proto.HeaderError))

	c.logger.Log(
		"level", 1,
		"action", "handshake error",
		"addr", r.RemoteAddr,
		"err", err,
	)

	c.connMu.Lock()
	c.serverErr = fmt.Errorf("server error: %s", err)
	c.connMu.Unlock()
}
|
|
|
|
// handleHandshake answers a handshake request with the JSON-encoded tunnel
// definitions from c.config.Tunnels.
//
// The tunnels are encoded before any status is written, so an encoding
// failure yields an explicit 500 response instead of an empty 200 body.
func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
	c.logger.Log(
		"level", 1,
		"action", "handshake",
		"addr", r.RemoteAddr,
	)

	b, err := json.Marshal(c.config.Tunnels)
	if err != nil {
		c.logger.Log(
			"level", 0,
			"msg", "handshake failed",
			"err", err,
		)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(b); err != nil {
		c.logger.Log(
			"level", 0,
			"msg", "handshake write failed",
			"err", err,
		)
	}
}
|
|
|
|
// Stop disconnects client from server.
|
|
func (c *Client) Stop() {
|
|
c.connMu.Lock()
|
|
defer c.connMu.Unlock()
|
|
|
|
c.logger.Log(
|
|
"level", 1,
|
|
"action", "stop",
|
|
)
|
|
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
}
|
|
c.conn = nil
|
|
}
|