route/vendor/github.com/mmatczuk/go-http-tunnel/server.go

649 lines
13 KiB
Go

// Copyright (C) 2017 Michał Matczuk
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tunnel
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"golang.org/x/net/http2"
"github.com/mmatczuk/go-http-tunnel/id"
"github.com/mmatczuk/go-http-tunnel/log"
"github.com/mmatczuk/go-http-tunnel/proto"
)
// ServerConfig defines configuration for the Server.
type ServerConfig struct {
	// Addr is the TCP address to listen on for client connections. It is only
	// consulted when Listener is nil, and in that case it must be non-empty:
	// NewServer fails with "missing Addr" otherwise.
	Addr string
	// TLSConfig specifies the TLS configuration to use with tls.Listen. It is
	// only consulted when Listener is nil, and in that case it must be
	// non-nil: NewServer fails with "missing TLSConfig" otherwise.
	TLSConfig *tls.Config
	// Listener specifies an optional listener for client connections. If nil,
	// tls.Listen("tcp", Addr, TLSConfig) is used.
	Listener net.Listener
	// Logger is an optional logger. If nil, logging is disabled (a no-op
	// logger is substituted by NewServer).
	Logger log.Logger
}
// Server is responsible for proxying public connections to the client over a
// tunnel connection.
type Server struct {
	// registry is embedded; NOTE(review): methods such as IsSubscribed,
	// Subscriber and set that are called on Server in this file are presumably
	// promoted from it — confirm against the registry definition.
	*registry
	// config is the configuration the server was created with.
	config *ServerConfig
	// listener accepts control connections from tunnel clients (see Start).
	listener net.Listener
	// connPool holds the client connections keyed by identifier and is also
	// installed as the connection pool of the HTTP/2 transport.
	connPool *connPool
	// httpClient sends requests to tunnel clients through the HTTP/2
	// transport backed by connPool.
	httpClient *http.Client
	// logger is never nil after NewServer; a no-op logger is used when the
	// configuration does not provide one.
	logger log.Logger
}
// NewServer builds a Server from config.
//
// The control listener is taken from config.Listener or, when that is nil,
// opened via tls.Listen using config.Addr and config.TLSConfig. A missing
// logger is replaced by a no-op logger. The returned server owns an HTTP/2
// transport whose connection pool is the server's own connPool, so requests
// issued through the server's HTTP client travel over client connections.
//
// An error is returned when the listener cannot be obtained.
func NewServer(config *ServerConfig) (*Server, error) {
	ln, err := listener(config)
	if err != nil {
		return nil, fmt.Errorf("tls listener failed: %s", err)
	}

	lg := config.Logger
	if lg == nil {
		lg = log.NewNopLogger()
	}

	srv := &Server{
		registry: newRegistry(lg),
		config:   config,
		listener: ln,
		logger:   lg,
	}

	// The transport and the pool reference each other: the pool is created
	// with the transport, then installed back as the transport's ConnPool.
	transport := &http2.Transport{}
	srv.connPool = newConnPool(transport, srv.disconnected)
	transport.ConnPool = srv.connPool
	srv.httpClient = &http.Client{Transport: transport}

	return srv, nil
}
// listener returns the listener used for client connections: the one supplied
// in config when present, otherwise a new TLS listener on config.Addr. When no
// listener is supplied, both Addr and TLSConfig are mandatory.
func listener(config *ServerConfig) (net.Listener, error) {
	switch {
	case config.Listener != nil:
		return config.Listener, nil
	case config.Addr == "":
		return nil, errors.New("missing Addr")
	case config.TLSConfig == nil:
		return nil, errors.New("missing TLSConfig")
	default:
		return tls.Listen("tcp", config.Addr, config.TLSConfig)
	}
}
// disconnected releases what the registry holds for a client: the client's
// registry entry is cleared and every listener recorded in it is closed. It is
// handed to the connection pool as a callback in NewServer.
func (s *Server) disconnected(identifier id.ID) {
	s.logger.Log(
		"level", 1,
		"action", "disconnected",
		"identifier", identifier,
	)

	item := s.registry.clear(identifier)
	if item != nil {
		for _, ln := range item.Listeners {
			s.logger.Log(
				"level", 2,
				"action", "close listener",
				"identifier", identifier,
				"addr", ln.Addr(),
			)
			ln.Close()
		}
	}
}
// Start accepts connections from clients until the control listener is
// closed, handling each accepted connection in its own goroutine. Other accept
// errors are logged and the loop carries on. To accept http traffic from end
// users the server must additionally be run as a handler on an http server.
func (s *Server) Start() {
	addr := s.listener.Addr().String()

	s.logger.Log(
		"level", 1,
		"action", "start",
		"addr", addr,
	)

	for {
		conn, err := s.listener.Accept()
		switch {
		case err == nil:
			go s.handleClient(conn)
		case strings.Contains(err.Error(), "use of closed network connection"):
			// The listener was closed (see Stop): leave the accept loop.
			s.logger.Log(
				"level", 1,
				"action", "control connection listener closed",
				"addr", addr,
			)
			return
		default:
			s.logger.Log(
				"level", 0,
				"msg", "accept control connection failed",
				"addr", addr,
				"err", err,
			)
		}
	}
}
// handleClient performs the handshake with a freshly accepted client
// connection and, on success, registers the tunnels the client announces.
//
// The connection must be a *tls.Conn; the client identifier is derived from it
// with id.PeerID and must already be subscribed. The connection deadline is
// then cleared, the connection is added to the pool, and a CONNECT request
// (bounded by DefaultTimeout) is sent to the client. The response body is
// expected to be a JSON object mapping tunnel names to proto.Tunnel values.
//
// On any failure the connection is closed. If it had already been added to
// the pool, the client is first notified of the error and the connection is
// removed from the pool.
func (s *Server) handleClient(conn net.Conn) {
	logger := log.NewContext(s.logger).With("addr", conn.RemoteAddr())

	logger.Log(
		"level", 1,
		"action", "try connect",
	)

	// Declared up front: goto may not jump over variable declarations.
	var (
		identifier id.ID
		req        *http.Request
		resp       *http.Response
		tunnels    map[string]*proto.Tunnel
		err        error
		ok         bool

		// inConnPool records whether the reject path has to undo AddConn.
		inConnPool bool
	)

	tlsConn, ok := conn.(*tls.Conn)
	if !ok {
		logger.Log(
			"level", 0,
			"msg", "invalid connection type",
			"err", fmt.Errorf("expected tls conn, got %T", conn),
		)
		goto reject
	}

	identifier, err = id.PeerID(tlsConn)
	if err != nil {
		logger.Log(
			"level", 2,
			"msg", "certificate error",
			"err", err,
		)
		goto reject
	}

	logger = logger.With("identifier", identifier)

	if !s.IsSubscribed(identifier) {
		logger.Log(
			"level", 2,
			"msg", "unknown client",
		)
		goto reject
	}

	// The zero time clears any deadline on the connection.
	if err = conn.SetDeadline(time.Time{}); err != nil {
		logger.Log(
			"level", 2,
			"msg", "setting infinite deadline failed",
			"err", err,
		)
		goto reject
	}

	// err is scoped to this statement; the reject path does not need it here
	// because inConnPool is still false, so no notification is sent.
	if err := s.connPool.AddConn(conn, identifier); err != nil {
		logger.Log(
			"level", 2,
			"msg", "adding connection failed",
			"err", err,
		)
		goto reject
	}
	inConnPool = true

	req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
	if err != nil {
		logger.Log(
			"level", 2,
			"msg", "handshake request creation failed",
			"err", err,
		)
		goto reject
	}

	// Scoped in a block so ctx and cancel are not declarations that a goto
	// jumps over; the deferred cancel still runs when handleClient returns.
	{
		ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
		defer cancel()
		req = req.WithContext(ctx)
	}

	resp, err = s.httpClient.Do(req)
	if err != nil {
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}
	// The caller owns the response body; close it on every path from here on
	// so the handshake stream is released.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		err = fmt.Errorf("Status %s", resp.Status)
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}

	if resp.ContentLength == 0 {
		err = fmt.Errorf("Tunnels Content-Legth: 0")
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}

	// Cap how much of the body is read while decoding the tunnel list.
	if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil {
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}

	if len(tunnels) == 0 {
		err = fmt.Errorf("No tunnels")
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}

	if err = s.addTunnels(tunnels, identifier); err != nil {
		logger.Log(
			"level", 2,
			"msg", "handshake failed",
			"err", err,
		)
		goto reject
	}

	logger.Log(
		"level", 1,
		"action", "connected",
	)

	return

reject:
	logger.Log(
		"level", 1,
		"action", "rejected",
	)

	if inConnPool {
		// Tell the client why it was rejected while the connection is still
		// in the pool, then drop it from the pool.
		s.notifyError(err, identifier)
		s.connPool.DeleteConn(identifier)
	}

	conn.Close()
}
// notifyError tries to send serverError to the client with the given
// identifier. The error text travels in the proto.HeaderError header of a
// CONNECT request bounded by DefaultTimeout. A nil serverError is a no-op.
//
// Delivery is best effort: failures are only logged. The response body is
// closed so the request's stream is released.
func (s *Server) notifyError(serverError error, identifier id.ID) {
	if serverError == nil {
		return
	}

	req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
	if err != nil {
		s.logger.Log(
			"level", 2,
			"action", "client error notification failed",
			"identifier", identifier,
			"err", err,
		)
		return
	}

	req.Header.Set(proto.HeaderError, serverError.Error())

	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
	defer cancel()

	resp, err := s.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		s.logger.Log(
			"level", 2,
			"action", "client error notification failed",
			"identifier", identifier,
			"err", err,
		)
		return
	}
	// The reply carries nothing of interest; just release it.
	resp.Body.Close()
}
// addTunnels registers the tunnels announced by the client with the given
// identifier. HTTP tunnels contribute a host entry (with its auth); TCP, TCP4,
// TCP6 and UNIX tunnels open a listener on the tunnel's address. The collected
// item is stored in the registry and an accept loop is started for each
// listener.
//
// The batch is all-or-nothing: if a listener cannot be opened, a protocol is
// unsupported, or the registry rejects the item, every listener opened so far
// is closed and the error is returned.
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
	item := &RegistryItem{
		Hosts:     []*HostAuth{},
		Listeners: []net.Listener{},
	}

	// abort closes the listeners opened so far and hands back cause.
	abort := func(cause error) error {
		for _, ln := range item.Listeners {
			ln.Close()
		}
		return cause
	}

	var err error

	for name, tun := range tunnels {
		switch tun.Protocol {
		case proto.HTTP:
			host := &HostAuth{tun.Host, NewAuth(tun.Auth)}
			item.Hosts = append(item.Hosts, host)

		case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX:
			var ln net.Listener
			if ln, err = net.Listen(tun.Protocol, tun.Addr); err != nil {
				return abort(err)
			}

			s.logger.Log(
				"level", 2,
				"action", "open listener",
				"identifier", identifier,
				"addr", ln.Addr(),
			)

			item.Listeners = append(item.Listeners, ln)

		default:
			err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, tun.Protocol)
			return abort(err)
		}
	}

	if err = s.set(item, identifier); err != nil {
		return abort(err)
	}

	for _, ln := range item.Listeners {
		go s.listen(ln, identifier)
	}

	return nil
}
// Unsubscribe drops the client's connection from the connection pool and then
// removes the client from the registry, returning the registry's result for
// that identifier.
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
	// Pool first, registry second: keep this order.
	s.connPool.DeleteConn(identifier)

	item := s.registry.Unsubscribe(identifier)
	return item
}
// listen accepts connections on l on behalf of the client with the given
// identifier and proxies each of them to that client in its own goroutine.
// The loop ends when the listener is closed; other accept errors are logged
// and accepting continues.
func (s *Server) listen(l net.Listener, identifier id.ID) {
	addr := l.Addr().String()

	// serve proxies a single accepted connection and logs a failure.
	serve := func(c net.Conn, m *proto.ControlMessage) {
		if err := s.proxyConn(identifier, c, m); err != nil {
			s.logger.Log(
				"level", 0,
				"msg", "proxy error",
				"identifier", identifier,
				"ctrlMsg", m,
				"err", err,
			)
		}
	}

	for {
		conn, err := l.Accept()
		if err == nil {
			// The control message is built here, before the goroutine starts.
			go serve(conn, &proto.ControlMessage{
				Action:       proto.ActionProxy,
				Protocol:     l.Addr().Network(),
				ForwardedFor: conn.RemoteAddr().String(),
				ForwardedBy:  l.Addr().String(),
			})
			continue
		}

		if !strings.Contains(err.Error(), "use of closed network connection") {
			s.logger.Log(
				"level", 0,
				"msg", "accept connection failed",
				"identifier", identifier,
				"addr", addr,
				"err", err,
			)
			continue
		}

		s.logger.Log(
			"level", 2,
			"action", "listener closed",
			"identifier", identifier,
			"addr", addr,
		)
		return
	}
}
// ServeHTTP proxies an http request to the client and streams the client's
// response back to the user.
//
// An errUnauthorised result from RoundTrip is answered with 401 and a
// WWW-Authenticate challenge; any other error is logged and answered with
// 502 Bad Gateway. Otherwise the response headers and status code are copied
// to w, followed by the body.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	resp, err := s.RoundTrip(r)
	if err == errUnauthorised {
		w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}
	if err != nil {
		s.logger.Log(
			"level", 0,
			"action", "round trip failed",
			"addr", r.RemoteAddr,
			"url", r.URL,
			"err", err,
		)

		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}

	copyHeader(w.Header(), resp.Header)
	w.WriteHeader(resp.StatusCode)

	if resp.Body == nil {
		return
	}
	// This handler owns the response body: close it once streaming is done so
	// the underlying stream to the client is released. Nothing visible here
	// shows transfer closing it.
	defer resp.Body.Close()

	transfer(w, resp.Body, log.NewContext(s.logger).With(
		"dir", "client to user",
		"dst", r.RemoteAddr,
		"src", r.Host,
	))
}
// RoundTrip is an http.RoundTripper implementation. It looks up the
// subscriber for r.Host and forwards the request to that client.
//
// errClientNotSubscribed is returned when no subscriber is found for the
// host. When the subscriber has auth configured, the request's basic auth
// credentials must match or errUnauthorised is returned; on a match the
// Authorization header is removed from r before the request is proxied.
func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
	identifier, auth, ok := s.Subscriber(r.Host)
	if !ok {
		return nil, errClientNotSubscribed
	}

	if auth != nil {
		user, password, _ := r.BasicAuth()
		// NOTE(review): plain string comparison of credentials; consider a
		// constant-time comparison if timing side channels are a concern.
		if !(auth.User == user && auth.Password == password) {
			return nil, errUnauthorised
		}
		r.Header.Del("Authorization")
	}

	return s.proxyHTTP(identifier, r, &proto.ControlMessage{
		Action:       proto.ActionProxy,
		Protocol:     proto.HTTP,
		ForwardedFor: r.RemoteAddr,
		ForwardedBy:  r.Host,
	})
}
// proxyConn relays a raw user connection to the client with the given
// identifier. Bytes read from conn are piped into the body of a request sent
// to the client, and the body of the client's response is copied back to
// conn. The call returns once both directions have finished; conn is always
// closed on return.
//
// An error is returned when the request cannot be created or sent.
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
	s.logger.Log(
		"level", 2,
		"action", "proxy",
		"identifier", identifier,
		"ctrlMsg", msg,
	)

	defer conn.Close()

	// The read end of the pipe is the request body; the write end is fed
	// from the user connection.
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()

	req, err := s.connectRequest(identifier, msg, pr)
	if err != nil {
		return err
	}

	// user -> client direction runs concurrently with the request.
	done := make(chan struct{})
	go func() {
		transfer(pw, conn, log.NewContext(s.logger).With(
			"dir", "user to client",
			"dst", identifier,
			"src", conn.RemoteAddr(),
		))
		close(done)
	}()

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("io error: %s", err)
	}
	// The caller owns the response body; close it so the stream to the
	// client is released when proxying ends.
	defer resp.Body.Close()

	// client -> user direction.
	transfer(conn, resp.Body, log.NewContext(s.logger).With(
		"dir", "client to user",
		"dst", conn.RemoteAddr(),
		"src", identifier,
	))

	// Wait for the user -> client copy to finish before tearing down.
	<-done

	s.logger.Log(
		"level", 2,
		"action", "proxy done",
		"identifier", identifier,
		"ctrlMsg", msg,
	)

	return nil
}
// proxyHTTP forwards the user's http request r to the client with the given
// identifier. The request is serialised in wire format into a pipe whose read
// end is the body of the request sent to the client; the client's response is
// returned to the caller, who takes ownership of its body.
//
// Both pipe ends are closed when proxyHTTP returns. An error is returned when
// the request to the client cannot be created or sent.
func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.ControlMessage) (*http.Response, error) {
	s.logger.Log(
		"level", 2,
		"action", "proxy",
		"identifier", identifier,
		"ctrlMsg", msg,
	)

	reader, writer := io.Pipe()
	defer reader.Close()
	defer writer.Close()

	req, err := s.connectRequest(identifier, msg, reader)
	if err != nil {
		return nil, fmt.Errorf("proxy request error: %s", err)
	}

	// pump writes the user's request into the pipe, reports how many bytes
	// went through, and closes the user's request body afterwards.
	pump := func() {
		counter := &countWriter{writer, 0}
		if werr := r.Write(counter); werr != nil {
			s.logger.Log(
				"level", 0,
				"msg", "proxy error",
				"identifier", identifier,
				"ctrlMsg", msg,
				"err", werr,
			)
		}

		s.logger.Log(
			"level", 3,
			"action", "transferred",
			"identifier", identifier,
			"bytes", counter.count,
			"dir", "user to client",
			"dst", r.Host,
			"src", r.RemoteAddr,
		)

		if body := r.Body; body != nil {
			body.Close()
		}
	}
	go pump()

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("io error: %s", err)
	}

	s.logger.Log(
		"level", 2,
		"action", "proxy done",
		"identifier", identifier,
		"ctrlMsg", msg,
		"status code", resp.StatusCode,
	)

	return resp, nil
}
// connectRequest builds the PUT request sent to the client with the given
// identifier. r becomes the request body (the data stream towards the client)
// and msg is applied to the request headers via msg.Update. The data coming
// back from the client is the response to this request.
func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) {
	target := s.connPool.URL(identifier)

	req, err := http.NewRequest(http.MethodPut, target, r)
	if err != nil {
		return nil, fmt.Errorf("could not create request: %s", err)
	}

	msg.Update(req.Header)

	return req, nil
}
// Addr reports the network address clients connect to, or the empty string
// when the server has no listener.
func (s *Server) Addr() string {
	if ln := s.listener; ln != nil {
		return ln.Addr().String()
	}
	return ""
}
// Stop shuts the server down by closing its control listener, if it has one.
func (s *Server) Stop() {
	s.logger.Log(
		"level", 1,
		"action", "stop",
	)

	if s.listener == nil {
		return
	}
	s.listener.Close()
}