// Copyright (C) 2017 Michał Matczuk
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tunnel

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/net/http2"

	"github.com/mmatczuk/go-http-tunnel/id"
	"github.com/mmatczuk/go-http-tunnel/log"
	"github.com/mmatczuk/go-http-tunnel/proto"
)

// ServerConfig defines configuration for the Server.
type ServerConfig struct {
	// Addr is the TCP address to listen on for client (control) connections.
	// It is required unless Listener is set: an empty Addr makes NewServer
	// return a "missing Addr" error.
	Addr string
	// TLSConfig specifies the tls configuration to use with tls.Listener.
	// Like Addr, it is required unless Listener is set.
	TLSConfig *tls.Config
	// Listener specifies optional listener for client connections. If nil
	// tls.Listen("tcp", Addr, TLSConfig) is used. Connections accepted from it
	// must be *tls.Conn, because clients are identified by their TLS peer
	// certificate; any other connection type is rejected.
	Listener net.Listener
	// Logger is optional logger. If nil logging is disabled.
	Logger log.Logger
}

// Server is responsible for proxying public connections to the client over a
// tunnel connection.
type Server struct {
	// registry tracks subscribed clients and the hosts/listeners registered
	// for each of them.
	*registry
	// config is the configuration the server was created with.
	config *ServerConfig

	// listener accepts control connections from clients.
	listener net.Listener
	// connPool holds the accepted client connections, keyed by client
	// identifier, and serves as the ConnPool of httpClient's transport.
	connPool *connPool
	// httpClient sends requests to clients over the connections in connPool.
	httpClient *http.Client
	// logger is never nil; a no-op logger is used when none is configured.
	logger log.Logger
}

// NewServer creates a new Server.
func NewServer(config *ServerConfig) (*Server, error) { listener, err := listener(config) if err != nil { return nil, fmt.Errorf("tls listener failed: %s", err) } logger := config.Logger if logger == nil { logger = log.NewNopLogger() } s := &Server{ registry: newRegistry(logger), config: config, listener: listener, logger: logger, } t := &http2.Transport{} pool := newConnPool(t, s.disconnected) t.ConnPool = pool s.connPool = pool s.httpClient = &http.Client{Transport: t} return s, nil } func listener(config *ServerConfig) (net.Listener, error) { if config.Listener != nil { return config.Listener, nil } if config.Addr == "" { return nil, errors.New("missing Addr") } if config.TLSConfig == nil { return nil, errors.New("missing TLSConfig") } return tls.Listen("tcp", config.Addr, config.TLSConfig) } // disconnected clears resources used by client, it's invoked by connection pool // when client goes away. func (s *Server) disconnected(identifier id.ID) { s.logger.Log( "level", 1, "action", "disconnected", "identifier", identifier, ) i := s.registry.clear(identifier) if i == nil { return } for _, l := range i.Listeners { s.logger.Log( "level", 2, "action", "close listener", "identifier", identifier, "addr", l.Addr(), ) l.Close() } } // Start starts accepting connections form clients. For accepting http traffic // from end users server must be run as handler on http server. 
func (s *Server) Start() { addr := s.listener.Addr().String() s.logger.Log( "level", 1, "action", "start", "addr", addr, ) for { conn, err := s.listener.Accept() if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { s.logger.Log( "level", 1, "action", "control connection listener closed", "addr", addr, ) return } s.logger.Log( "level", 0, "msg", "accept control connection failed", "addr", addr, "err", err, ) continue } go s.handleClient(conn) } } func (s *Server) handleClient(conn net.Conn) { logger := log.NewContext(s.logger).With("addr", conn.RemoteAddr()) logger.Log( "level", 1, "action", "try connect", ) var ( identifier id.ID req *http.Request resp *http.Response tunnels map[string]*proto.Tunnel err error ok bool inConnPool bool ) tlsConn, ok := conn.(*tls.Conn) if !ok { logger.Log( "level", 0, "msg", "invalid connection type", "err", fmt.Errorf("expected tls conn, got %T", conn), ) goto reject } identifier, err = id.PeerID(tlsConn) if err != nil { logger.Log( "level", 2, "msg", "certificate error", "err", err, ) goto reject } logger = logger.With("identifier", identifier) if !s.IsSubscribed(identifier) { logger.Log( "level", 2, "msg", "unknown client", ) goto reject } if err = conn.SetDeadline(time.Time{}); err != nil { logger.Log( "level", 2, "msg", "setting infinite deadline failed", "err", err, ) goto reject } if err := s.connPool.AddConn(conn, identifier); err != nil { logger.Log( "level", 2, "msg", "adding connection failed", "err", err, ) goto reject } inConnPool = true req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) if err != nil { logger.Log( "level", 2, "msg", "handshake request creation failed", "err", err, ) goto reject } { ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() req = req.WithContext(ctx) } resp, err = s.httpClient.Do(req) if err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } if resp.StatusCode 
!= http.StatusOK { err = fmt.Errorf("Status %s", resp.Status) logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } if resp.ContentLength == 0 { err = fmt.Errorf("Tunnels Content-Legth: 0") logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } if len(tunnels) == 0 { err = fmt.Errorf("No tunnels") logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } if err = s.addTunnels(tunnels, identifier); err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) goto reject } logger.Log( "level", 1, "action", "connected", ) return reject: logger.Log( "level", 1, "action", "rejected", ) if inConnPool { s.notifyError(err, identifier) s.connPool.DeleteConn(identifier) } conn.Close() } // notifyError tries to send error to client. func (s *Server) notifyError(serverError error, identifier id.ID) { if serverError == nil { return } req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) if err != nil { s.logger.Log( "level", 2, "action", "client error notification failed", "identifier", identifier, "err", err, ) return } req.Header.Set(proto.HeaderError, serverError.Error()) ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() s.httpClient.Do(req.WithContext(ctx)) } // addTunnels invokes addHost or addListener based on data from proto.Tunnel. If // a tunnel cannot be added whole batch is reverted. 
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error { i := &RegistryItem{ Hosts: []*HostAuth{}, Listeners: []net.Listener{}, } var err error for name, t := range tunnels { switch t.Protocol { case proto.HTTP: i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)}) case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX: var l net.Listener l, err = net.Listen(t.Protocol, t.Addr) if err != nil { goto rollback } s.logger.Log( "level", 2, "action", "open listener", "identifier", identifier, "addr", l.Addr(), ) i.Listeners = append(i.Listeners, l) default: err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol) goto rollback } } err = s.set(i, identifier) if err != nil { goto rollback } for _, l := range i.Listeners { go s.listen(l, identifier) } return nil rollback: for _, l := range i.Listeners { l.Close() } return err } // Unsubscribe removes client from registry, disconnects client if already // connected and returns it's RegistryItem. func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { s.connPool.DeleteConn(identifier) return s.registry.Unsubscribe(identifier) } func (s *Server) listen(l net.Listener, identifier id.ID) { addr := l.Addr().String() for { conn, err := l.Accept() if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { s.logger.Log( "level", 2, "action", "listener closed", "identifier", identifier, "addr", addr, ) return } s.logger.Log( "level", 0, "msg", "accept connection failed", "identifier", identifier, "addr", addr, "err", err, ) continue } msg := &proto.ControlMessage{ Action: proto.ActionProxy, Protocol: l.Addr().Network(), ForwardedFor: conn.RemoteAddr().String(), ForwardedBy: l.Addr().String(), } go func() { if err := s.proxyConn(identifier, conn, msg); err != nil { s.logger.Log( "level", 0, "msg", "proxy error", "identifier", identifier, "ctrlMsg", msg, "err", err, ) } }() } } // ServeHTTP proxies http connection to the client. 
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, err := s.RoundTrip(r) if err == errUnauthorised { w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"") http.Error(w, err.Error(), http.StatusUnauthorized) return } if err != nil { s.logger.Log( "level", 0, "action", "round trip failed", "addr", r.RemoteAddr, "url", r.URL, "err", err, ) http.Error(w, err.Error(), http.StatusBadGateway) return } copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) if resp.Body != nil { transfer(w, resp.Body, log.NewContext(s.logger).With( "dir", "client to user", "dst", r.RemoteAddr, "src", r.Host, )) } } // RoundTrip is http.RoundTriper implementation. func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { msg := &proto.ControlMessage{ Action: proto.ActionProxy, Protocol: proto.HTTP, ForwardedFor: r.RemoteAddr, ForwardedBy: r.Host, } identifier, auth, ok := s.Subscriber(r.Host) if !ok { return nil, errClientNotSubscribed } if auth != nil { user, password, _ := r.BasicAuth() if auth.User != user || auth.Password != password { return nil, errUnauthorised } r.Header.Del("Authorization") } return s.proxyHTTP(identifier, r, msg) } func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error { s.logger.Log( "level", 2, "action", "proxy", "identifier", identifier, "ctrlMsg", msg, ) defer conn.Close() pr, pw := io.Pipe() defer pr.Close() defer pw.Close() req, err := s.connectRequest(identifier, msg, pr) if err != nil { return err } done := make(chan struct{}) go func() { transfer(pw, conn, log.NewContext(s.logger).With( "dir", "user to client", "dst", identifier, "src", conn.RemoteAddr(), )) close(done) }() resp, err := s.httpClient.Do(req) if err != nil { return fmt.Errorf("io error: %s", err) } transfer(conn, resp.Body, log.NewContext(s.logger).With( "dir", "client to user", "dst", conn.RemoteAddr(), "src", identifier, )) <-done s.logger.Log( "level", 2, "action", "proxy done", 
"identifier", identifier, "ctrlMsg", msg, ) return nil } func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.ControlMessage) (*http.Response, error) { s.logger.Log( "level", 2, "action", "proxy", "identifier", identifier, "ctrlMsg", msg, ) pr, pw := io.Pipe() defer pr.Close() defer pw.Close() req, err := s.connectRequest(identifier, msg, pr) if err != nil { return nil, fmt.Errorf("proxy request error: %s", err) } go func() { cw := &countWriter{pw, 0} err := r.Write(cw) if err != nil { s.logger.Log( "level", 0, "msg", "proxy error", "identifier", identifier, "ctrlMsg", msg, "err", err, ) } s.logger.Log( "level", 3, "action", "transferred", "identifier", identifier, "bytes", cw.count, "dir", "user to client", "dst", r.Host, "src", r.RemoteAddr, ) if r.Body != nil { r.Body.Close() } }() resp, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("io error: %s", err) } s.logger.Log( "level", 2, "action", "proxy done", "identifier", identifier, "ctrlMsg", msg, "status code", resp.StatusCode, ) return resp, nil } // connectRequest creates HTTP request to client with a given identifier having // control message and data input stream, output data stream results from // response the created request. func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) { req, err := http.NewRequest(http.MethodPut, s.connPool.URL(identifier), r) if err != nil { return nil, fmt.Errorf("could not create request: %s", err) } msg.Update(req.Header) return req, nil } // Addr returns network address clients connect to. func (s *Server) Addr() string { if s.listener == nil { return "" } return s.listener.Addr().String() } // Stop closes the server. func (s *Server) Stop() { s.logger.Log( "level", 1, "action", "stop", ) if s.listener != nil { s.listener.Close() } }