2017-01-20 01:27:14 +00:00
|
|
|
// Package tunnel is a server/client package that makes it possible to proxy
// public connections to your local machine over a tunnel connection from the
// local machine to the public server.
|
|
|
|
package tunnel
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"path"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
|
|
"github.com/hashicorp/yamux"
|
2017-01-22 17:36:44 +00:00
|
|
|
"github.com/koding/logging"
|
2017-01-20 01:27:14 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// errNoClientSession is returned when a request has to be tunneled to a client
// that has no usable control connection or session.
var errNoClientSession = errors.New("no client session established")

// defaultTimeout bounds how long dial waits for the tunnel client to open the
// stream it was asked for.
var defaultTimeout = time.Second * 10
|
|
|
|
|
|
|
|
// Server is responsible for proxying public connections to the client over a
// tunnel connection. It also listens to control messages from the client.
//
// Create a Server with NewServer, which initializes the internal maps; a zero
// Server is not ready for use (e.g. addSession would write to a nil map).
type Server struct {
	// pending contains the channel that is associated with each new tunnel request.
	// NOTE(review): initialized by NewServer but not otherwise referenced in
	// this file — confirm whether it is still used.
	pending map[string]chan net.Conn
	// pendingMu protects the pending map.
	pendingMu sync.Mutex

	// sessions contains the yamux session of each connected client, keyed by
	// client identifier. A session multiplexes many streams over the single
	// tunnel connection.
	sessions map[string]*yamux.Session
	// sessionsMu protects sessions.
	sessionsMu sync.Mutex

	// controls contains the control connection of each client, keyed by
	// client identifier.
	controls *controls

	// virtualHosts is used to map public hosts to remote clients.
	virtualHosts vhostStorage

	// virtualAddrs maps listeners (optionally restricted to a local IP) to
	// client identifiers for raw TCP tunnels; see AddAddr.
	virtualAddrs *vaddrStorage

	// connCh is used to publish accepted connections for tcp tunnels.
	// serveTCP consumes it.
	connCh chan net.Conn

	// onConnectCallbacks contains client callbacks called when control
	// session is established for a client with given identifier.
	onConnectCallbacks *callbacks

	// onDisconnectCallbacks contains client callbacks called when control
	// session is closed for a client with given identifier.
	onDisconnectCallbacks *callbacks

	// states represents current clients' connections state, keyed by client
	// identifier.
	states map[string]ClientState
	// statesMu protects states.
	statesMu sync.RWMutex
	// stateCh notifies receiver about client state changes. It may be nil;
	// changeState sends to it without blocking.
	stateCh chan<- *ClientStateChange

	// httpDirector is provided by ServerConfig, if not nil decorates http requests
	// before forwarding them to client.
	httpDirector func(*http.Request)

	// yamuxConfig is passed to yamux.Server for every new client session.
	yamuxConfig *yamux.Config

	// log is the server's logger (ServerConfig.Log or a default stderr logger).
	log logging.Logger
}
|
|
|
|
|
|
|
|
// ServerConfig defines the configuration for the Server.
type ServerConfig struct {
	// StateChanges receives state transition details each time client
	// connection state changes. The channel is expected to be sufficiently
	// buffered to keep up with event pace: the server never blocks on it, and
	// a change that cannot be delivered immediately is dropped (a warning is
	// logged).
	//
	// If nil, no information about state transitions are dispatched
	// by the library.
	StateChanges chan<- *ClientStateChange

	// Director is a function that modifies HTTP request into a new HTTP request
	// before sending to client. If nil no modifications are done.
	// It runs before the request's Host is used for routing, so changing
	// r.Host also changes which client receives the request.
	Director func(*http.Request)

	// Debug enables debug mode, enable only if you want to debug the server.
	// It only configures the default logger and has no effect when Log is set.
	Debug bool

	// Log defines the logger. If nil a default logging.Logger is used.
	Log logging.Logger

	// YamuxConfig defines the config which passed to every new yamux.Session. If nil
	// yamux.DefaultConfig() is used. NewServer fails if the config does not
	// pass yamux.VerifyConfig.
	YamuxConfig *yamux.Config
}
|
|
|
|
|
|
|
|
// NewServer creates a new Server. The defaults are used if config is nil.
|
|
|
|
func NewServer(cfg *ServerConfig) (*Server, error) {
|
|
|
|
yamuxConfig := yamux.DefaultConfig()
|
|
|
|
if cfg.YamuxConfig != nil {
|
|
|
|
if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
yamuxConfig = cfg.YamuxConfig
|
|
|
|
}
|
|
|
|
|
|
|
|
log := newLogger("tunnel-server", cfg.Debug)
|
|
|
|
if cfg.Log != nil {
|
|
|
|
log = cfg.Log
|
|
|
|
}
|
|
|
|
|
|
|
|
connCh := make(chan net.Conn)
|
|
|
|
|
|
|
|
opts := &vaddrOptions{
|
|
|
|
connCh: connCh,
|
|
|
|
log: log,
|
|
|
|
}
|
|
|
|
|
|
|
|
s := &Server{
|
|
|
|
pending: make(map[string]chan net.Conn),
|
|
|
|
sessions: make(map[string]*yamux.Session),
|
|
|
|
onConnectCallbacks: newCallbacks("OnConnect"),
|
|
|
|
onDisconnectCallbacks: newCallbacks("OnDisconnect"),
|
|
|
|
virtualHosts: newVirtualHosts(),
|
|
|
|
virtualAddrs: newVirtualAddrs(opts),
|
|
|
|
controls: newControls(),
|
|
|
|
states: make(map[string]ClientState),
|
|
|
|
stateCh: cfg.StateChanges,
|
|
|
|
httpDirector: cfg.Director,
|
|
|
|
yamuxConfig: yamuxConfig,
|
|
|
|
connCh: connCh,
|
|
|
|
log: log,
|
|
|
|
}
|
|
|
|
|
|
|
|
go s.serveTCP()
|
|
|
|
|
|
|
|
return s, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// ServeHTTP is a tunnel that creates an http/websocket tunnel between a
|
|
|
|
// public connection and the client connection.
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// if the user didn't add the control and tunnel handler manually, we'll
|
|
|
|
// going to infer and call the respective path handlers.
|
|
|
|
switch path.Clean(r.URL.Path) + "/" {
|
|
|
|
case proto.ControlPath:
|
|
|
|
s.checkConnect(s.controlHandler).ServeHTTP(w, r)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := s.handleHTTP(w, r); err != nil {
|
|
|
|
if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily
|
|
|
|
s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err)
|
|
|
|
}
|
|
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// handleHTTP handles a single HTTP request
|
|
|
|
func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error {
|
|
|
|
s.log.Debug("HandleHTTP request:")
|
|
|
|
s.log.Debug("%v", r)
|
|
|
|
|
|
|
|
if s.httpDirector != nil {
|
|
|
|
s.httpDirector(r)
|
|
|
|
}
|
|
|
|
|
|
|
|
hostPort := strings.ToLower(r.Host)
|
|
|
|
if hostPort == "" {
|
|
|
|
return errors.New("request host is empty")
|
|
|
|
}
|
|
|
|
|
|
|
|
// if someone hits foo.example.com:8080, this should be proxied to
|
|
|
|
// localhost:8080, so send the port to the client so it knows how to proxy
|
|
|
|
// correctly. If no port is available, it's up to client how to interpret it
|
|
|
|
host, port, err := parseHostPort(hostPort)
|
|
|
|
if err != nil {
|
|
|
|
// no need to return, just continue lazily, port will be 0, which in
|
|
|
|
// our case will be proxied to client's local servers port 80
|
|
|
|
s.log.Debug("No port available for %q, sending port 80 to client", hostPort)
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the identifier associated with this host
|
|
|
|
identifier, ok := s.getIdentifier(hostPort)
|
|
|
|
if !ok {
|
|
|
|
// fallback to host
|
|
|
|
identifier, ok = s.getIdentifier(host)
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("no virtual host available for %q", hostPort)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if isWebsocketConn(r) {
|
|
|
|
s.log.Debug("handling websocket connection")
|
|
|
|
|
|
|
|
return s.handleWSConn(w, r, identifier, port)
|
|
|
|
}
|
|
|
|
|
|
|
|
stream, err := s.dial(identifier, proto.HTTP, port)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer func() {
|
|
|
|
s.log.Debug("Closing stream")
|
|
|
|
stream.Close()
|
|
|
|
}()
|
|
|
|
|
|
|
|
if err := r.Write(stream); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
s.log.Debug("Session opened to client, writing request to client")
|
|
|
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("read from tunnel: %s", err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
if resp.Body != nil {
|
|
|
|
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
|
|
|
s.log.Error("resp.Body Close error: %s", err.Error())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
s.log.Debug("Response received, writing back to public connection: %+v", resp)
|
|
|
|
|
|
|
|
copyHeader(w.Header(), resp.Header)
|
|
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
|
|
|
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
|
|
if err == io.ErrUnexpectedEOF {
|
|
|
|
s.log.Debug("Client closed the connection, couldn't copy response")
|
|
|
|
} else {
|
|
|
|
s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) serveTCP() {
|
|
|
|
for conn := range s.connCh {
|
|
|
|
go s.serveTCPConn(conn)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) serveTCPConn(conn net.Conn) {
|
|
|
|
err := s.handleTCPConn(conn)
|
|
|
|
if err != nil {
|
|
|
|
s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err)
|
|
|
|
conn.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error {
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
|
|
if !ok {
|
|
|
|
return errors.New("webserver doesn't support hijacking")
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, _, err := hj.Hijack()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("hijack not possible: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
stream, err := s.dial(ident, proto.WS, port)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := r.Write(stream); err != nil {
|
|
|
|
err = errors.New("unable to write upgrade request: " + err.Error())
|
|
|
|
return nonil(err, stream.Close())
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
|
|
|
if err != nil {
|
|
|
|
err = errors.New("unable to read upgrade response: " + err.Error())
|
|
|
|
return nonil(err, stream.Close())
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := resp.Write(conn); err != nil {
|
|
|
|
err = errors.New("unable to write upgrade response: " + err.Error())
|
|
|
|
return nonil(err, stream.Close())
|
|
|
|
}
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(2)
|
|
|
|
|
|
|
|
go s.proxy(&wg, conn, stream)
|
|
|
|
go s.proxy(&wg, stream, conn)
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
|
|
|
return nonil(stream.Close(), conn.Close())
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) handleTCPConn(conn net.Conn) error {
|
|
|
|
ident, ok := s.virtualAddrs.getIdent(conn)
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("no virtual address available for %s", conn.LocalAddr())
|
|
|
|
}
|
|
|
|
|
|
|
|
_, port, err := parseHostPort(conn.LocalAddr().String())
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
stream, err := s.dial(ident, proto.TCP, port)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(2)
|
|
|
|
|
|
|
|
go s.proxy(&wg, conn, stream)
|
|
|
|
go s.proxy(&wg, stream, conn)
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
|
|
|
return nonil(stream.Close(), conn.Close())
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) {
|
|
|
|
defer wg.Done()
|
|
|
|
|
|
|
|
s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
|
|
|
n, err := io.Copy(dst, src)
|
|
|
|
s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) {
|
|
|
|
control, ok := s.getControl(identifier)
|
|
|
|
if !ok {
|
|
|
|
return nil, errNoClientSession
|
|
|
|
}
|
|
|
|
|
|
|
|
session, err := s.getSession(identifier)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
msg := proto.ControlMessage{
|
|
|
|
Action: proto.RequestClientSession,
|
|
|
|
Protocol: p,
|
|
|
|
LocalPort: port,
|
|
|
|
}
|
|
|
|
|
|
|
|
s.log.Debug("Sending control msg %+v", msg)
|
|
|
|
|
|
|
|
// ask client to open a session to us, so we can accept it
|
|
|
|
if err := control.send(msg); err != nil {
|
|
|
|
// we might have several issues here, either the stream is closed, or
|
|
|
|
// the session is going be shut down, the underlying connection might
|
|
|
|
// be broken. In all cases, it's not reliable anymore having a client
|
|
|
|
// session.
|
|
|
|
control.Close()
|
|
|
|
s.deleteControl(identifier)
|
|
|
|
return nil, errNoClientSession
|
|
|
|
}
|
|
|
|
|
|
|
|
var stream net.Conn
|
|
|
|
acceptStream := func() error {
|
|
|
|
stream, err = session.Accept()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// if we don't receive anything from the client, we'll timeout
|
|
|
|
s.log.Debug("Waiting for session accept")
|
|
|
|
|
|
|
|
select {
|
|
|
|
case err := <-async(acceptStream):
|
|
|
|
return stream, err
|
|
|
|
case <-time.After(defaultTimeout):
|
|
|
|
return nil, errors.New("timeout getting session")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// controlHandler is used to capture incoming tunnel connect requests into raw
|
|
|
|
// tunnel TCP connections.
|
|
|
|
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
|
|
|
|
identifier := r.Header.Get(proto.ClientIdentifierHeader)
|
|
|
|
_, ok := s.getHost(identifier)
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
ct, ok := s.getControl(identifier)
|
|
|
|
if ok {
|
|
|
|
ct.Close()
|
|
|
|
s.deleteControl(identifier)
|
|
|
|
s.deleteSession(identifier)
|
|
|
|
s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier)
|
|
|
|
return fmt.Errorf("control conn for %s already exist. \n", identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
s.log.Debug("Tunnel with identifier %s", identifier)
|
|
|
|
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
|
|
if !ok {
|
|
|
|
return errors.New("webserver doesn't support hijacking")
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, _, err := hj.Hijack()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("hijack not possible: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil {
|
|
|
|
return fmt.Errorf("error writing response: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
|
|
|
return fmt.Errorf("error setting connection deadline: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
s.log.Debug("Creating control session")
|
|
|
|
session, err := yamux.Server(conn, s.yamuxConfig)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
s.addSession(identifier, session)
|
|
|
|
|
|
|
|
var stream net.Conn
|
|
|
|
|
|
|
|
// close and delete the session/stream if something goes wrong
|
|
|
|
defer func() {
|
|
|
|
if ctErr != nil {
|
|
|
|
if stream != nil {
|
|
|
|
stream.Close()
|
|
|
|
}
|
|
|
|
s.deleteSession(identifier)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
acceptStream := func() error {
|
|
|
|
stream, err = session.Accept()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// if we don't receive anything from the client, we'll timeout
|
|
|
|
select {
|
|
|
|
case err := <-async(acceptStream):
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
case <-time.After(time.Second * 10):
|
|
|
|
return errors.New("timeout getting session")
|
|
|
|
}
|
|
|
|
|
|
|
|
s.log.Debug("Initiating handshake protocol")
|
|
|
|
buf := make([]byte, len(proto.HandshakeRequest))
|
|
|
|
if _, err := stream.Read(buf); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if string(buf) != proto.HandshakeRequest {
|
|
|
|
return fmt.Errorf("handshake aborted. got: %s", string(buf))
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// setup control stream and start to listen to messages
|
|
|
|
ct = newControl(stream)
|
|
|
|
s.addControl(identifier, ct)
|
|
|
|
go s.listenControl(ct)
|
|
|
|
|
|
|
|
s.log.Debug("Control connection is setup")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// listenControl listens to messages coming from the client.
|
|
|
|
func (s *Server) listenControl(ct *control) {
|
|
|
|
s.onConnect(ct.identifier)
|
|
|
|
|
|
|
|
for {
|
|
|
|
var msg map[string]interface{}
|
|
|
|
err := ct.dec.Decode(&msg)
|
|
|
|
if err != nil {
|
|
|
|
host, _ := s.getHost(ct.identifier)
|
|
|
|
s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier)
|
|
|
|
|
|
|
|
// close client connection so it reconnects again
|
|
|
|
ct.Close()
|
|
|
|
|
|
|
|
// don't forget to cleanup anything
|
|
|
|
s.deleteControl(ct.identifier)
|
|
|
|
s.deleteSession(ct.identifier)
|
|
|
|
|
|
|
|
s.onDisconnect(ct.identifier, err)
|
|
|
|
|
|
|
|
if err != io.EOF {
|
|
|
|
s.log.Error("decode err: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// right now we don't do anything with the messages, but because the
|
|
|
|
// underlying connection needs to establihsed, we know when we have
|
|
|
|
// disconnection(above), so we can cleanup the connection.
|
|
|
|
s.log.Debug("msg: %s", msg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// OnConnect invokes a callback for client with given identifier,
|
|
|
|
// when it establishes a control session.
|
|
|
|
// After a client is connected, the associated function
|
|
|
|
// is also removed and needs to be added again.
|
|
|
|
func (s *Server) OnConnect(identifier string, fn func() error) {
|
|
|
|
s.onConnectCallbacks.add(identifier, fn)
|
|
|
|
}
|
|
|
|
|
|
|
|
// onConnect sends notifications to listeners (registered in onConnectCallbacks
|
|
|
|
// or stateChanges chanel readers) when client connects.
|
|
|
|
func (s *Server) onConnect(identifier string) {
|
|
|
|
if err := s.onConnectCallbacks.call(identifier); err != nil {
|
|
|
|
s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
s.changeState(identifier, ClientConnected, nil)
|
|
|
|
}
|
|
|
|
|
|
|
|
// OnDisconnect calls the function when the client connected with the
|
|
|
|
// associated identifier disconnects from the server.
|
|
|
|
// After a client is disconnected, the associated function
|
|
|
|
// is also removed and needs to be added again.
|
|
|
|
func (s *Server) OnDisconnect(identifier string, fn func() error) {
|
|
|
|
s.onDisconnectCallbacks.add(identifier, fn)
|
|
|
|
}
|
|
|
|
|
|
|
|
// onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks
|
|
|
|
// or stateChanges chanel readers) when client disconnects.
|
|
|
|
func (s *Server) onDisconnect(identifier string, err error) {
|
|
|
|
if err := s.onDisconnectCallbacks.call(identifier); err != nil {
|
|
|
|
s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
s.changeState(identifier, ClientClosed, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) {
|
|
|
|
s.statesMu.Lock()
|
|
|
|
defer s.statesMu.Unlock()
|
|
|
|
|
|
|
|
prev = s.states[identifier]
|
|
|
|
s.states[identifier] = state
|
|
|
|
|
|
|
|
if s.stateCh != nil {
|
|
|
|
change := &ClientStateChange{
|
|
|
|
Identifier: identifier,
|
|
|
|
Previous: prev,
|
|
|
|
Current: state,
|
|
|
|
Error: err,
|
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
|
|
|
case s.stateCh <- change:
|
|
|
|
default:
|
|
|
|
s.log.Warning("Dropping state change due to slow reader: %s", change)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return prev
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddHost adds the given virtual host and maps it to the identifier.
//
// HTTP requests are routed by looking up the request's host:port and then
// the bare host. A client must also have a host registered before its
// control connection is accepted.
// NOTE(review): handleHTTP lower-cases the request Host before the lookup, so
// host should presumably be given in lower case — confirm.
func (s *Server) AddHost(host, identifier string) {
	s.virtualHosts.AddHost(host, identifier)
}
|
|
|
|
|
|
|
|
// DeleteHost deletes the given virtual host. Once removed any request to this
// host is denied (handleHTTP fails with "no virtual host available").
func (s *Server) DeleteHost(host string) {
	s.virtualHosts.DeleteHost(host)
}
|
|
|
|
|
|
|
|
// AddAddr starts accepting connections on listener l, routing every connection
|
|
|
|
// to a tunnel client given by the identifier.
|
|
|
|
//
|
|
|
|
// When ip parameter is nil, all connections accepted from the listener are
|
|
|
|
// routed to the tunnel client specified by the identifier (port-based routing).
|
|
|
|
//
|
|
|
|
// When ip parameter is non-nil, only those connections are routed whose local
|
|
|
|
// address matches the specified ip (ip-based routing).
|
|
|
|
//
|
|
|
|
// If l listens on multiple interfaces it's desirable to call AddAddr multiple
|
|
|
|
// times with the same l value but different ip one.
|
|
|
|
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
|
|
|
|
s.virtualAddrs.Add(l, ip, identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
// DeleteAddr stops listening for connections on the given listener.
|
|
|
|
//
|
|
|
|
// Upon return no more connections will be tunneled, but as the method does not
|
|
|
|
// close the listener, so any ongoing connection won't get interrupted.
|
|
|
|
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
|
|
|
|
s.virtualAddrs.Delete(l, ip)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) getIdentifier(host string) (string, bool) {
|
|
|
|
identifier, ok := s.virtualHosts.GetIdentifier(host)
|
|
|
|
return identifier, ok
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) getHost(identifier string) (string, bool) {
|
|
|
|
host, ok := s.virtualHosts.GetHost(identifier)
|
|
|
|
return host, ok
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) addControl(identifier string, conn *control) {
|
|
|
|
s.controls.addControl(identifier, conn)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) getControl(identifier string) (*control, bool) {
|
|
|
|
return s.controls.getControl(identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) deleteControl(identifier string) {
|
|
|
|
s.controls.deleteControl(identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
|
|
|
|
s.sessionsMu.Lock()
|
|
|
|
session, ok := s.sessions[identifier]
|
|
|
|
s.sessionsMu.Unlock()
|
|
|
|
|
|
|
|
if !ok {
|
|
|
|
return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
return session, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) addSession(identifier string, session *yamux.Session) {
|
|
|
|
s.sessionsMu.Lock()
|
|
|
|
s.sessions[identifier] = session
|
|
|
|
s.sessionsMu.Unlock()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) deleteSession(identifier string) {
|
|
|
|
s.sessionsMu.Lock()
|
|
|
|
defer s.sessionsMu.Unlock()
|
|
|
|
|
|
|
|
session, ok := s.sessions[identifier]
|
|
|
|
|
|
|
|
if !ok {
|
|
|
|
return // nothing to delete
|
|
|
|
}
|
|
|
|
|
|
|
|
if session != nil {
|
|
|
|
session.GoAway() // don't accept any new connection
|
|
|
|
session.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
delete(s.sessions, identifier)
|
|
|
|
}
|
|
|
|
|
|
|
|
// copyHeader copies every header of src into dst. Each value slice is cloned,
// so dst never shares backing arrays with src. An existing dst entry for the
// same key is replaced, not merged.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		dst[key] = append(make([]string, 0, len(values)), values...)
	}
}
|
|
|
|
|
|
|
|
// checkConnect checks whether the incoming request is HTTP CONNECT method.
|
|
|
|
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
if r.Method != "CONNECT" {
|
|
|
|
http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := fn(w, r); err != nil {
|
|
|
|
s.log.Error("Handler err: %v", err.Error())
|
|
|
|
|
|
|
|
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
|
|
|
|
s.onDisconnect(identifier, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
http.Error(w, err.Error(), 502)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// parseHostPort splits addr ("host:port") into its host and numeric port. The
// port must be a decimal number that fits in 16 bits. On any error the host
// is "" and the port is 0.
func parseHostPort(addr string) (string, int, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err == nil {
		var port uint64
		if port, err = strconv.ParseUint(portStr, 10, 16); err == nil {
			return host, int(port), nil
		}
	}
	return "", 0, err
}
|
|
|
|
|
|
|
|
func isWebsocketConn(r *http.Request) bool {
|
|
|
|
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
|
|
|
|
headerContains(r.Header["Upgrade"], "websocket")
|
|
|
|
}
|
|
|
|
|
|
|
|
// headerContains reports whether any comma-separated token in the header
// values equals value. The comparison ignores case and surrounding
// whitespace; empty tokens are considered too.
// It is a copy of tokenListContainsValue from gorilla/websocket/util.go.
func headerContains(header []string, value string) bool {
	for _, line := range header {
		for {
			comma := strings.IndexByte(line, ',')
			if comma < 0 {
				if strings.EqualFold(strings.TrimSpace(line), value) {
					return true
				}
				break
			}
			if strings.EqualFold(strings.TrimSpace(line[:comma]), value) {
				return true
			}
			line = line[comma+1:]
		}
	}
	return false
}
|
|
|
|
|
|
|
|
// nonil returns the first non-nil error among err, or nil if all are nil.
func nonil(err ...error) error {
	for i := 0; i < len(err); i++ {
		if err[i] != nil {
			return err[i]
		}
	}
	return nil
}
|
|
|
|
|
|
|
|
func newLogger(name string, debug bool) logging.Logger {
|
|
|
|
log := logging.NewLogger(name)
|
|
|
|
logHandler := logging.NewWriterHandler(os.Stderr)
|
|
|
|
logHandler.Colorize = true
|
|
|
|
log.SetHandler(logHandler)
|
|
|
|
|
|
|
|
if debug {
|
|
|
|
log.SetLevel(logging.DEBUG)
|
|
|
|
logHandler.SetLevel(logging.DEBUG)
|
|
|
|
}
|
|
|
|
|
|
|
|
return log
|
|
|
|
}
|