649 lines
13 KiB
Go
649 lines
13 KiB
Go
|
// Copyright (C) 2017 Michał Matczuk
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package tunnel
|
||
|
|
||
|
import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/net/http2"

	"github.com/mmatczuk/go-http-tunnel/id"
	"github.com/mmatczuk/go-http-tunnel/log"
	"github.com/mmatczuk/go-http-tunnel/proto"
)
|
||
|
|
||
|
// ServerConfig defines configuration for the Server.
|
||
|
type ServerConfig struct {
|
||
|
// Addr is TCP address to listen for client connections. If empty ":0"
|
||
|
// is used.
|
||
|
Addr string
|
||
|
// TLSConfig specifies the tls configuration to use with tls.Listener.
|
||
|
TLSConfig *tls.Config
|
||
|
// Listener specifies optional listener for client connections. If nil
|
||
|
// tls.Listen("tcp", Addr, TLSConfig) is used.
|
||
|
Listener net.Listener
|
||
|
// Logger is optional logger. If nil logging is disabled.
|
||
|
Logger log.Logger
|
||
|
}
|
||
|
|
||
|
// Server is responsible for proxying public connections to the client over a
// tunnel connection.
//
// Clients connect on the control listener and, after a successful handshake
// (see handleClient), their tunnels are registered in the embedded registry.
// Public HTTP traffic reaches the Server through ServeHTTP; TCP/UNIX traffic
// arrives on per-tunnel listeners opened for the client (see addTunnels).
type Server struct {
	// registry tracks subscribed clients and the hosts/listeners registered
	// for them (used via set, clear, Subscriber, IsSubscribed).
	*registry
	// config is the configuration passed to NewServer.
	config *ServerConfig
	// listener accepts client control connections.
	listener net.Listener
	// connPool holds client control connections keyed by identifier and is
	// the HTTP/2 transport's connection pool; it calls disconnected when a
	// client goes away.
	connPool *connPool
	// httpClient sends requests to clients through connPool.
	httpClient *http.Client
	// logger is never nil; NewServer substitutes a no-op logger.
	logger log.Logger
}
|
||
|
|
||
|
// NewServer creates a new Server.
|
||
|
func NewServer(config *ServerConfig) (*Server, error) {
|
||
|
listener, err := listener(config)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("tls listener failed: %s", err)
|
||
|
}
|
||
|
|
||
|
logger := config.Logger
|
||
|
if logger == nil {
|
||
|
logger = log.NewNopLogger()
|
||
|
}
|
||
|
|
||
|
s := &Server{
|
||
|
registry: newRegistry(logger),
|
||
|
config: config,
|
||
|
listener: listener,
|
||
|
logger: logger,
|
||
|
}
|
||
|
|
||
|
t := &http2.Transport{}
|
||
|
pool := newConnPool(t, s.disconnected)
|
||
|
t.ConnPool = pool
|
||
|
s.connPool = pool
|
||
|
s.httpClient = &http.Client{Transport: t}
|
||
|
|
||
|
return s, nil
|
||
|
}
|
||
|
|
||
|
func listener(config *ServerConfig) (net.Listener, error) {
|
||
|
if config.Listener != nil {
|
||
|
return config.Listener, nil
|
||
|
}
|
||
|
|
||
|
if config.Addr == "" {
|
||
|
return nil, errors.New("missing Addr")
|
||
|
}
|
||
|
if config.TLSConfig == nil {
|
||
|
return nil, errors.New("missing TLSConfig")
|
||
|
}
|
||
|
|
||
|
return tls.Listen("tcp", config.Addr, config.TLSConfig)
|
||
|
}
|
||
|
|
||
|
// disconnected clears resources used by client, it's invoked by connection pool
|
||
|
// when client goes away.
|
||
|
func (s *Server) disconnected(identifier id.ID) {
|
||
|
s.logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "disconnected",
|
||
|
"identifier", identifier,
|
||
|
)
|
||
|
|
||
|
i := s.registry.clear(identifier)
|
||
|
if i == nil {
|
||
|
return
|
||
|
}
|
||
|
for _, l := range i.Listeners {
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "close listener",
|
||
|
"identifier", identifier,
|
||
|
"addr", l.Addr(),
|
||
|
)
|
||
|
l.Close()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Start starts accepting connections form clients. For accepting http traffic
|
||
|
// from end users server must be run as handler on http server.
|
||
|
func (s *Server) Start() {
|
||
|
addr := s.listener.Addr().String()
|
||
|
|
||
|
s.logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "start",
|
||
|
"addr", addr,
|
||
|
)
|
||
|
|
||
|
for {
|
||
|
conn, err := s.listener.Accept()
|
||
|
if err != nil {
|
||
|
if strings.Contains(err.Error(), "use of closed network connection") {
|
||
|
s.logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "control connection listener closed",
|
||
|
"addr", addr,
|
||
|
)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
s.logger.Log(
|
||
|
"level", 0,
|
||
|
"msg", "accept control connection failed",
|
||
|
"addr", addr,
|
||
|
"err", err,
|
||
|
)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
go s.handleClient(conn)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Server) handleClient(conn net.Conn) {
|
||
|
logger := log.NewContext(s.logger).With("addr", conn.RemoteAddr())
|
||
|
|
||
|
logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "try connect",
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
identifier id.ID
|
||
|
req *http.Request
|
||
|
resp *http.Response
|
||
|
tunnels map[string]*proto.Tunnel
|
||
|
err error
|
||
|
ok bool
|
||
|
|
||
|
inConnPool bool
|
||
|
)
|
||
|
|
||
|
tlsConn, ok := conn.(*tls.Conn)
|
||
|
if !ok {
|
||
|
logger.Log(
|
||
|
"level", 0,
|
||
|
"msg", "invalid connection type",
|
||
|
"err", fmt.Errorf("expected tls conn, got %T", conn),
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
identifier, err = id.PeerID(tlsConn)
|
||
|
if err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "certificate error",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
logger = logger.With("identifier", identifier)
|
||
|
|
||
|
if !s.IsSubscribed(identifier) {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "unknown client",
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "setting infinite deadline failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if err := s.connPool.AddConn(conn, identifier); err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "adding connection failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
inConnPool = true
|
||
|
|
||
|
req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
|
||
|
if err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake request creation failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
{
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||
|
defer cancel()
|
||
|
req = req.WithContext(ctx)
|
||
|
}
|
||
|
|
||
|
resp, err = s.httpClient.Do(req)
|
||
|
if err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if resp.StatusCode != http.StatusOK {
|
||
|
err = fmt.Errorf("Status %s", resp.Status)
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if resp.ContentLength == 0 {
|
||
|
err = fmt.Errorf("Tunnels Content-Legth: 0")
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if len(tunnels) == 0 {
|
||
|
err = fmt.Errorf("No tunnels")
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
if err = s.addTunnels(tunnels, identifier); err != nil {
|
||
|
logger.Log(
|
||
|
"level", 2,
|
||
|
"msg", "handshake failed",
|
||
|
"err", err,
|
||
|
)
|
||
|
goto reject
|
||
|
}
|
||
|
|
||
|
logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "connected",
|
||
|
)
|
||
|
|
||
|
return
|
||
|
|
||
|
reject:
|
||
|
logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "rejected",
|
||
|
)
|
||
|
|
||
|
if inConnPool {
|
||
|
s.notifyError(err, identifier)
|
||
|
s.connPool.DeleteConn(identifier)
|
||
|
}
|
||
|
|
||
|
conn.Close()
|
||
|
}
|
||
|
|
||
|
// notifyError tries to send error to client.
|
||
|
func (s *Server) notifyError(serverError error, identifier id.ID) {
|
||
|
if serverError == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
|
||
|
if err != nil {
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "client error notification failed",
|
||
|
"identifier", identifier,
|
||
|
"err", err,
|
||
|
)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
req.Header.Set(proto.HeaderError, serverError.Error())
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||
|
defer cancel()
|
||
|
|
||
|
s.httpClient.Do(req.WithContext(ctx))
|
||
|
}
|
||
|
|
||
|
// addTunnels invokes addHost or addListener based on data from proto.Tunnel. If
|
||
|
// a tunnel cannot be added whole batch is reverted.
|
||
|
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
|
||
|
i := &RegistryItem{
|
||
|
Hosts: []*HostAuth{},
|
||
|
Listeners: []net.Listener{},
|
||
|
}
|
||
|
|
||
|
var err error
|
||
|
for name, t := range tunnels {
|
||
|
switch t.Protocol {
|
||
|
case proto.HTTP:
|
||
|
i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)})
|
||
|
case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX:
|
||
|
var l net.Listener
|
||
|
l, err = net.Listen(t.Protocol, t.Addr)
|
||
|
if err != nil {
|
||
|
goto rollback
|
||
|
}
|
||
|
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "open listener",
|
||
|
"identifier", identifier,
|
||
|
"addr", l.Addr(),
|
||
|
)
|
||
|
|
||
|
i.Listeners = append(i.Listeners, l)
|
||
|
default:
|
||
|
err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol)
|
||
|
goto rollback
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err = s.set(i, identifier)
|
||
|
if err != nil {
|
||
|
goto rollback
|
||
|
}
|
||
|
|
||
|
for _, l := range i.Listeners {
|
||
|
go s.listen(l, identifier)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
|
||
|
rollback:
|
||
|
for _, l := range i.Listeners {
|
||
|
l.Close()
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Unsubscribe removes client from registry, disconnects client if already
|
||
|
// connected and returns it's RegistryItem.
|
||
|
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
|
||
|
s.connPool.DeleteConn(identifier)
|
||
|
return s.registry.Unsubscribe(identifier)
|
||
|
}
|
||
|
|
||
|
func (s *Server) listen(l net.Listener, identifier id.ID) {
|
||
|
addr := l.Addr().String()
|
||
|
|
||
|
for {
|
||
|
conn, err := l.Accept()
|
||
|
if err != nil {
|
||
|
if strings.Contains(err.Error(), "use of closed network connection") {
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "listener closed",
|
||
|
"identifier", identifier,
|
||
|
"addr", addr,
|
||
|
)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
s.logger.Log(
|
||
|
"level", 0,
|
||
|
"msg", "accept connection failed",
|
||
|
"identifier", identifier,
|
||
|
"addr", addr,
|
||
|
"err", err,
|
||
|
)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
msg := &proto.ControlMessage{
|
||
|
Action: proto.ActionProxy,
|
||
|
Protocol: l.Addr().Network(),
|
||
|
ForwardedFor: conn.RemoteAddr().String(),
|
||
|
ForwardedBy: l.Addr().String(),
|
||
|
}
|
||
|
go func() {
|
||
|
if err := s.proxyConn(identifier, conn, msg); err != nil {
|
||
|
s.logger.Log(
|
||
|
"level", 0,
|
||
|
"msg", "proxy error",
|
||
|
"identifier", identifier,
|
||
|
"ctrlMsg", msg,
|
||
|
"err", err,
|
||
|
)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ServeHTTP proxies http connection to the client.
|
||
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
|
resp, err := s.RoundTrip(r)
|
||
|
if err == errUnauthorised {
|
||
|
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
|
||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||
|
return
|
||
|
}
|
||
|
if err != nil {
|
||
|
s.logger.Log(
|
||
|
"level", 0,
|
||
|
"action", "round trip failed",
|
||
|
"addr", r.RemoteAddr,
|
||
|
"url", r.URL,
|
||
|
"err", err,
|
||
|
)
|
||
|
|
||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
copyHeader(w.Header(), resp.Header)
|
||
|
w.WriteHeader(resp.StatusCode)
|
||
|
if resp.Body != nil {
|
||
|
transfer(w, resp.Body, log.NewContext(s.logger).With(
|
||
|
"dir", "client to user",
|
||
|
"dst", r.RemoteAddr,
|
||
|
"src", r.Host,
|
||
|
))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RoundTrip is http.RoundTriper implementation.
|
||
|
func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
|
||
|
msg := &proto.ControlMessage{
|
||
|
Action: proto.ActionProxy,
|
||
|
Protocol: proto.HTTP,
|
||
|
ForwardedFor: r.RemoteAddr,
|
||
|
ForwardedBy: r.Host,
|
||
|
}
|
||
|
|
||
|
identifier, auth, ok := s.Subscriber(r.Host)
|
||
|
if !ok {
|
||
|
return nil, errClientNotSubscribed
|
||
|
}
|
||
|
if auth != nil {
|
||
|
user, password, _ := r.BasicAuth()
|
||
|
if auth.User != user || auth.Password != password {
|
||
|
return nil, errUnauthorised
|
||
|
}
|
||
|
r.Header.Del("Authorization")
|
||
|
}
|
||
|
|
||
|
return s.proxyHTTP(identifier, r, msg)
|
||
|
}
|
||
|
|
||
|
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "proxy",
|
||
|
"identifier", identifier,
|
||
|
"ctrlMsg", msg,
|
||
|
)
|
||
|
|
||
|
defer conn.Close()
|
||
|
|
||
|
pr, pw := io.Pipe()
|
||
|
defer pr.Close()
|
||
|
defer pw.Close()
|
||
|
|
||
|
req, err := s.connectRequest(identifier, msg, pr)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
done := make(chan struct{})
|
||
|
go func() {
|
||
|
transfer(pw, conn, log.NewContext(s.logger).With(
|
||
|
"dir", "user to client",
|
||
|
"dst", identifier,
|
||
|
"src", conn.RemoteAddr(),
|
||
|
))
|
||
|
close(done)
|
||
|
}()
|
||
|
|
||
|
resp, err := s.httpClient.Do(req)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("io error: %s", err)
|
||
|
}
|
||
|
|
||
|
transfer(conn, resp.Body, log.NewContext(s.logger).With(
|
||
|
"dir", "client to user",
|
||
|
"dst", conn.RemoteAddr(),
|
||
|
"src", identifier,
|
||
|
))
|
||
|
|
||
|
<-done
|
||
|
|
||
|
s.logger.Log(
|
||
|
"level", 2,
|
||
|
"action", "proxy done",
|
||
|
"identifier", identifier,
|
||
|
"ctrlMsg", msg,
|
||
|
)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// proxyHTTP forwards the end user's request r to the client identified by
// identifier and returns the client's response. The caller is responsible for
// resp.Body.
//
// r is serialized with (*http.Request).Write into a pipe that serves as the
// body of a PUT request to the client (see connectRequest); msg is carried in
// that request's headers. Serialization runs in a goroutine so it can proceed
// concurrently with s.httpClient.Do reading the pipe.
//
// NOTE(review): both pipe ends are closed by the deferred calls as soon as
// this function returns, i.e. once Do has returned the response headers. If
// the request body is still being streamed at that point, the writer
// goroutine gets io.ErrClosedPipe and the body is truncated. Confirm that the
// client always consumes the whole request before responding.
func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.ControlMessage) (*http.Response, error) {
	s.logger.Log(
		"level", 2,
		"action", "proxy",
		"identifier", identifier,
		"ctrlMsg", msg,
	)

	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()

	req, err := s.connectRequest(identifier, msg, pr)
	if err != nil {
		return nil, fmt.Errorf("proxy request error: %s", err)
	}

	go func() {
		// Count the bytes written for the "transferred" log entry below.
		cw := &countWriter{pw, 0}
		err := r.Write(cw)
		if err != nil {
			s.logger.Log(
				"level", 0,
				"msg", "proxy error",
				"identifier", identifier,
				"ctrlMsg", msg,
				"err", err,
			)
		}

		s.logger.Log(
			"level", 3,
			"action", "transferred",
			"identifier", identifier,
			"bytes", cw.count,
			"dir", "user to client",
			"dst", r.Host,
			"src", r.RemoteAddr,
		)

		if r.Body != nil {
			r.Body.Close()
		}
	}()

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("io error: %s", err)
	}

	s.logger.Log(
		"level", 2,
		"action", "proxy done",
		"identifier", identifier,
		"ctrlMsg", msg,
		"status code", resp.StatusCode,
	)

	return resp, nil
}
|
||
|
|
||
|
// connectRequest creates HTTP request to client with a given identifier having
|
||
|
// control message and data input stream, output data stream results from
|
||
|
// response the created request.
|
||
|
func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) {
|
||
|
req, err := http.NewRequest(http.MethodPut, s.connPool.URL(identifier), r)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("could not create request: %s", err)
|
||
|
}
|
||
|
msg.Update(req.Header)
|
||
|
|
||
|
return req, nil
|
||
|
}
|
||
|
|
||
|
// Addr returns network address clients connect to.
|
||
|
func (s *Server) Addr() string {
|
||
|
if s.listener == nil {
|
||
|
return ""
|
||
|
}
|
||
|
return s.listener.Addr().String()
|
||
|
}
|
||
|
|
||
|
// Stop closes the server.
|
||
|
func (s *Server) Stop() {
|
||
|
s.logger.Log(
|
||
|
"level", 1,
|
||
|
"action", "stop",
|
||
|
)
|
||
|
|
||
|
if s.listener != nil {
|
||
|
s.listener.Close()
|
||
|
}
|
||
|
}
|