// Package tunnel is a server/client package that enables proxying public
// connections to your local machine over a tunnel connection from the local
// machine to the public server.
package tunnel

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.xeserv.us/xena/route/lib/tunnel/proto"
	"github.com/hashicorp/yamux"
	"github.com/koding/logging"
)

var (
	// errNoClientSession is returned by dial when no control connection is
	// registered for the identifier, or when sending on it fails.
	errNoClientSession = errors.New("no client session established")

	// defaultTimeout bounds how long dial waits for the client to open the
	// stream that was requested over the control connection.
	defaultTimeout = 10 * time.Second
)

// Server is responsible for proxying public connections to the client over a
// tunnel connection. It also listens to control messages from the client.
type Server struct {
	// pending contains the channel that is associated with each new tunnel request.
	pending map[string]chan net.Conn
	// pendingMu protects the pending map.
	pendingMu sync.Mutex

	// sessions contains a session per client identifier.
	// A session provides multiplexing over one connection.
	sessions map[string]*yamux.Session
	// sessionsMu protects sessions.
	sessionsMu sync.Mutex

	// controls contains the control connection from the client to the server.
	controls *controls

	// virtualHosts is used to map public hosts to remote clients.
	virtualHosts vhostStorage

	// virtualAddrs is used to map connections accepted on registered
	// listeners (see AddAddr) to remote clients.
	virtualAddrs *vaddrStorage

	// connCh is used to publish accepted connections for tcp tunnels.
	connCh chan net.Conn

	// onConnectCallbacks contains client callbacks called when control
	// session is established for a client with given identifier.
	onConnectCallbacks *callbacks

	// onDisconnectCallbacks contains client callbacks called when control
	// session is closed for a client with given identifier.
	onDisconnectCallbacks *callbacks

	// states represents current clients' connections state.
	states map[string]ClientState
	// statesMu protects states.
	statesMu sync.RWMutex
	// stateCh notifies receiver about client state changes.
	stateCh chan<- *ClientStateChange

	// httpDirector is provided by ServerConfig, if not nil decorates http requests
	// before forwarding them to client.
	httpDirector func(*http.Request)

	// yamuxConfig is passed to every new yamux.Session.
	yamuxConfig *yamux.Config

	// log is the logger used for all server diagnostics.
	log logging.Logger
}

// ServerConfig defines the configuration for the Server.
type ServerConfig struct {
	// StateChanges receives state transition details each time client
	// connection state changes. The channel is expected to be sufficiently
	// buffered to keep up with event pace.
	//
	// If nil, no information about state transitions is dispatched
	// by the library.
	StateChanges chan<- *ClientStateChange

	// Director is a function that modifies HTTP request into a new HTTP request
	// before sending to client. If nil no modifications are done.
	Director func(*http.Request)

	// Debug enables debug mode, enable only if you want to debug the server.
	// It is handed to the default logger that NewServer creates; a logger
	// supplied through Log replaces that default.
	Debug bool

	// Log defines the logger. If nil a default logging.Logger is used.
	Log logging.Logger

	// YamuxConfig defines the config which is passed to every new
	// yamux.Session. If nil yamux.DefaultConfig() is used.
	YamuxConfig *yamux.Config
}

// NewServer creates a new Server. The defaults are used if config is nil.
func NewServer(cfg *ServerConfig) (*Server, error) { yamuxConfig := yamux.DefaultConfig() if cfg.YamuxConfig != nil { if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil { return nil, err } yamuxConfig = cfg.YamuxConfig } log := newLogger("tunnel-server", cfg.Debug) if cfg.Log != nil { log = cfg.Log } connCh := make(chan net.Conn) opts := &vaddrOptions{ connCh: connCh, log: log, } s := &Server{ pending: make(map[string]chan net.Conn), sessions: make(map[string]*yamux.Session), onConnectCallbacks: newCallbacks("OnConnect"), onDisconnectCallbacks: newCallbacks("OnDisconnect"), virtualHosts: newVirtualHosts(), virtualAddrs: newVirtualAddrs(opts), controls: newControls(), states: make(map[string]ClientState), stateCh: cfg.StateChanges, httpDirector: cfg.Director, yamuxConfig: yamuxConfig, connCh: connCh, log: log, } go s.serveTCP() return s, nil } // ServeHTTP is a tunnel that creates an http/websocket tunnel between a // public connection and the client connection. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // if the user didn't add the control and tunnel handler manually, we'll // going to infer and call the respective path handlers. 
switch path.Clean(r.URL.Path) + "/" { case proto.ControlPath: s.checkConnect(s.controlHandler).ServeHTTP(w, r) return } if err := s.handleHTTP(w, r); err != nil { if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err) } http.Error(w, err.Error(), http.StatusBadGateway) } } // handleHTTP handles a single HTTP request func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error { s.log.Debug("HandleHTTP request:") s.log.Debug("%v", r) if s.httpDirector != nil { s.httpDirector(r) } hostPort := strings.ToLower(r.Host) if hostPort == "" { return errors.New("request host is empty") } // if someone hits foo.example.com:8080, this should be proxied to // localhost:8080, so send the port to the client so it knows how to proxy // correctly. If no port is available, it's up to client how to interpret it host, port, err := parseHostPort(hostPort) if err != nil { // no need to return, just continue lazily, port will be 0, which in // our case will be proxied to client's local servers port 80 s.log.Debug("No port available for %q, sending port 80 to client", hostPort) } // get the identifier associated with this host identifier, ok := s.getIdentifier(hostPort) if !ok { // fallback to host identifier, ok = s.getIdentifier(host) if !ok { return fmt.Errorf("no virtual host available for %q", hostPort) } } if isWebsocketConn(r) { s.log.Debug("handling websocket connection") return s.handleWSConn(w, r, identifier, port) } stream, err := s.dial(identifier, proto.HTTP, port) if err != nil { return err } defer func() { s.log.Debug("Closing stream") stream.Close() }() if err := r.Write(stream); err != nil { return err } s.log.Debug("Session opened to client, writing request to client") resp, err := http.ReadResponse(bufio.NewReader(stream), r) if err != nil { return fmt.Errorf("read from tunnel: %s", err.Error()) } defer func() { if resp.Body != nil { 
if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF { s.log.Error("resp.Body Close error: %s", err.Error()) } } }() s.log.Debug("Response received, writing back to public connection: %+v", resp) copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) if _, err := io.Copy(w, resp.Body); err != nil { if err == io.ErrUnexpectedEOF { s.log.Debug("Client closed the connection, couldn't copy response") } else { s.log.Error("copy err: %s", err) // do not return, because we might write multipe headers } } return nil } func (s *Server) serveTCP() { for conn := range s.connCh { go s.serveTCPConn(conn) } } func (s *Server) serveTCPConn(conn net.Conn) { err := s.handleTCPConn(conn) if err != nil { s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err) conn.Close() } } func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error { hj, ok := w.(http.Hijacker) if !ok { return errors.New("webserver doesn't support hijacking") } conn, _, err := hj.Hijack() if err != nil { return fmt.Errorf("hijack not possible: %s", err) } stream, err := s.dial(ident, proto.WS, port) if err != nil { return err } if err := r.Write(stream); err != nil { err = errors.New("unable to write upgrade request: " + err.Error()) return nonil(err, stream.Close()) } resp, err := http.ReadResponse(bufio.NewReader(stream), r) if err != nil { err = errors.New("unable to read upgrade response: " + err.Error()) return nonil(err, stream.Close()) } if err := resp.Write(conn); err != nil { err = errors.New("unable to write upgrade response: " + err.Error()) return nonil(err, stream.Close()) } var wg sync.WaitGroup wg.Add(2) go s.proxy(&wg, conn, stream) go s.proxy(&wg, stream, conn) wg.Wait() return nonil(stream.Close(), conn.Close()) } func (s *Server) handleTCPConn(conn net.Conn) error { ident, ok := s.virtualAddrs.getIdent(conn) if !ok { return fmt.Errorf("no virtual address available for %s", conn.LocalAddr()) } _, port, err := 
parseHostPort(conn.LocalAddr().String()) if err != nil { return err } stream, err := s.dial(ident, proto.TCP, port) if err != nil { return err } var wg sync.WaitGroup wg.Add(2) go s.proxy(&wg, conn, stream) go s.proxy(&wg, stream, conn) wg.Wait() return nonil(stream.Close(), conn.Close()) } func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) { defer wg.Done() s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr()) n, err := io.Copy(dst, src) s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err) } func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) { control, ok := s.getControl(identifier) if !ok { return nil, errNoClientSession } session, err := s.getSession(identifier) if err != nil { return nil, err } msg := proto.ControlMessage{ Action: proto.RequestClientSession, Protocol: p, LocalPort: port, } s.log.Debug("Sending control msg %+v", msg) // ask client to open a session to us, so we can accept it if err := control.send(msg); err != nil { // we might have several issues here, either the stream is closed, or // the session is going be shut down, the underlying connection might // be broken. In all cases, it's not reliable anymore having a client // session. control.Close() s.deleteControl(identifier) return nil, errNoClientSession } var stream net.Conn acceptStream := func() error { stream, err = session.Accept() return err } // if we don't receive anything from the client, we'll timeout s.log.Debug("Waiting for session accept") select { case err := <-async(acceptStream): return stream, err case <-time.After(defaultTimeout): return nil, errors.New("timeout getting session") } } // controlHandler is used to capture incoming tunnel connect requests into raw // tunnel TCP connections. 
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) { identifier := r.Header.Get(proto.ClientIdentifierHeader) _, ok := s.getHost(identifier) if !ok { return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier) } ct, ok := s.getControl(identifier) if ok { ct.Close() s.deleteControl(identifier) s.deleteSession(identifier) s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier) return fmt.Errorf("control conn for %s already exist. \n", identifier) } s.log.Debug("Tunnel with identifier %s", identifier) hj, ok := w.(http.Hijacker) if !ok { return errors.New("webserver doesn't support hijacking") } conn, _, err := hj.Hijack() if err != nil { return fmt.Errorf("hijack not possible: %s", err) } if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil { return fmt.Errorf("error writing response: %s", err) } if err := conn.SetDeadline(time.Time{}); err != nil { return fmt.Errorf("error setting connection deadline: %s", err) } s.log.Debug("Creating control session") session, err := yamux.Server(conn, s.yamuxConfig) if err != nil { return err } s.addSession(identifier, session) var stream net.Conn // close and delete the session/stream if something goes wrong defer func() { if ctErr != nil { if stream != nil { stream.Close() } s.deleteSession(identifier) } }() acceptStream := func() error { stream, err = session.Accept() return err } // if we don't receive anything from the client, we'll timeout select { case err := <-async(acceptStream): if err != nil { return err } case <-time.After(time.Second * 10): return errors.New("timeout getting session") } s.log.Debug("Initiating handshake protocol") buf := make([]byte, len(proto.HandshakeRequest)) if _, err := stream.Read(buf); err != nil { return err } if string(buf) != proto.HandshakeRequest { return fmt.Errorf("handshake aborted. 
got: %s", string(buf)) } if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil { return err } // setup control stream and start to listen to messages ct = newControl(stream) s.addControl(identifier, ct) go s.listenControl(ct) s.log.Debug("Control connection is setup") return nil } // listenControl listens to messages coming from the client. func (s *Server) listenControl(ct *control) { s.onConnect(ct.identifier) for { var msg map[string]interface{} err := ct.dec.Decode(&msg) if err != nil { host, _ := s.getHost(ct.identifier) s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier) // close client connection so it reconnects again ct.Close() // don't forget to cleanup anything s.deleteControl(ct.identifier) s.deleteSession(ct.identifier) s.onDisconnect(ct.identifier, err) if err != io.EOF { s.log.Error("decode err: %s", err) } return } // right now we don't do anything with the messages, but because the // underlying connection needs to establihsed, we know when we have // disconnection(above), so we can cleanup the connection. s.log.Debug("msg: %s", msg) } } // OnConnect invokes a callback for client with given identifier, // when it establishes a control session. // After a client is connected, the associated function // is also removed and needs to be added again. func (s *Server) OnConnect(identifier string, fn func() error) { s.onConnectCallbacks.add(identifier, fn) } // onConnect sends notifications to listeners (registered in onConnectCallbacks // or stateChanges chanel readers) when client connects. func (s *Server) onConnect(identifier string) { if err := s.onConnectCallbacks.call(identifier); err != nil { s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err) } s.changeState(identifier, ClientConnected, nil) } // OnDisconnect calls the function when the client connected with the // associated identifier disconnects from the server. 
// After a client is disconnected, the associated function // is also removed and needs to be added again. func (s *Server) OnDisconnect(identifier string, fn func() error) { s.onDisconnectCallbacks.add(identifier, fn) } // onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks // or stateChanges chanel readers) when client disconnects. func (s *Server) onDisconnect(identifier string, err error) { if err := s.onDisconnectCallbacks.call(identifier); err != nil { s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err) } s.changeState(identifier, ClientClosed, err) } func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) { s.statesMu.Lock() defer s.statesMu.Unlock() prev = s.states[identifier] s.states[identifier] = state if s.stateCh != nil { change := &ClientStateChange{ Identifier: identifier, Previous: prev, Current: state, Error: err, } select { case s.stateCh <- change: default: s.log.Warning("Dropping state change due to slow reader: %s", change) } } return prev } // AddHost adds the given virtual host and maps it to the identifier. func (s *Server) AddHost(host, identifier string) { s.virtualHosts.AddHost(host, identifier) } // DeleteHost deletes the given virtual host. Once removed any request to this // host is denied. func (s *Server) DeleteHost(host string) { s.virtualHosts.DeleteHost(host) } // AddAddr starts accepting connections on listener l, routing every connection // to a tunnel client given by the identifier. // // When ip parameter is nil, all connections accepted from the listener are // routed to the tunnel client specified by the identifier (port-based routing). // // When ip parameter is non-nil, only those connections are routed whose local // address matches the specified ip (ip-based routing). // // If l listens on multiple interfaces it's desirable to call AddAddr multiple // times with the same l value but different ip one. 
func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) {
	s.virtualAddrs.Add(l, ip, identifier)
}

// DeleteAddr stops listening for connections on the given listener.
//
// Upon return no more connections will be tunneled, but because the method
// does not close the listener, any ongoing connection won't get interrupted.
func (s *Server) DeleteAddr(l net.Listener, ip net.IP) {
	s.virtualAddrs.Delete(l, ip)
}

// getIdentifier looks up the client identifier mapped to host.
func (s *Server) getIdentifier(host string) (string, bool) {
	return s.virtualHosts.GetIdentifier(host)
}

// getHost looks up the virtual host mapped to identifier.
func (s *Server) getHost(identifier string) (string, bool) {
	return s.virtualHosts.GetHost(identifier)
}

// addControl registers conn as the control connection of identifier.
func (s *Server) addControl(identifier string, conn *control) {
	s.controls.addControl(identifier, conn)
}

// getControl returns the control connection registered for identifier.
func (s *Server) getControl(identifier string) (*control, bool) {
	return s.controls.getControl(identifier)
}

// deleteControl removes the control connection registered for identifier.
func (s *Server) deleteControl(identifier string) {
	s.controls.deleteControl(identifier)
}

// getSession returns the session stored for identifier, or an error when
// there is none.
func (s *Server) getSession(identifier string) (*yamux.Session, error) {
	s.sessionsMu.Lock()
	defer s.sessionsMu.Unlock()

	if sess, found := s.sessions[identifier]; found {
		return sess, nil
	}

	return nil, fmt.Errorf("no session available for identifier: '%s'", identifier)
}

// addSession stores session under identifier, replacing any previous entry.
func (s *Server) addSession(identifier string, session *yamux.Session) {
	s.sessionsMu.Lock()
	defer s.sessionsMu.Unlock()

	s.sessions[identifier] = session
}

// deleteSession shuts down and forgets the session stored for identifier.
// It does nothing when no session is stored.
func (s *Server) deleteSession(identifier string) {
	s.sessionsMu.Lock()
	defer s.sessionsMu.Unlock()

	sess, found := s.sessions[identifier]
	if found {
		if sess != nil {
			sess.GoAway() // don't accept any new connection
			sess.Close()
		}

		delete(s.sessions, identifier)
	}
}

// copyHeader copies every header of src into dst, giving dst its own copy of
// each value slice. Keys already present in dst are overwritten.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		cloned := make([]string, len(values))
		copy(cloned, values)
		dst[name] = cloned
	}
}

// checkConnect checks whether the incoming request is HTTP CONNECT method.
func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "CONNECT" { http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed) return } if err := fn(w, r); err != nil { s.log.Error("Handler err: %v", err.Error()) if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" { s.onDisconnect(identifier, err) } http.Error(w, err.Error(), 502) } }) } func parseHostPort(addr string) (string, int, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return "", 0, err } n, err := strconv.ParseUint(port, 10, 16) if err != nil { return "", 0, err } return host, int(n), nil } func isWebsocketConn(r *http.Request) bool { return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") && headerContains(r.Header["Upgrade"], "websocket") } // headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go func headerContains(header []string, value string) bool { for _, h := range header { for _, v := range strings.Split(h, ",") { if strings.EqualFold(strings.TrimSpace(v), value) { return true } } } return false } func nonil(err ...error) error { for _, e := range err { if e != nil { return e } } return nil } func newLogger(name string, debug bool) logging.Logger { log := logging.NewLogger(name) logHandler := logging.NewWriterHandler(os.Stderr) logHandler.Colorize = true log.SetHandler(logHandler) if debug { log.SetLevel(logging.DEBUG) logHandler.SetLevel(logging.DEBUG) } return log }