From 48327a766903ba3abd405381bfeefc1e45d46155 Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Mon, 23 Jan 2017 09:07:39 -0800 Subject: [PATCH] lib/tunnel: use httputil.ReverseProxy --- lib/tunnel/httpproxy.go | 63 ++++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/lib/tunnel/httpproxy.go b/lib/tunnel/httpproxy.go index 2e18ab1..9aa8a83 100644 --- a/lib/tunnel/httpproxy.go +++ b/lib/tunnel/httpproxy.go @@ -7,9 +7,12 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" + "net/url" + "sync" - "github.com/koding/logging" "git.xeserv.us/xena/route/lib/tunnel/proto" + "github.com/koding/logging" ) var ( @@ -29,9 +32,6 @@ type HTTPProxy struct { // LocalAddr defines the TCP address of the local server. // This is optional if you want to specify a single TCP address. LocalAddr string - // FetchLocalAddr is used for looking up TCP address of the server. - // This is optional if you want to specify a dynamic TCP address based on incommig port. - FetchLocalAddr func(port int) (string, error) // ErrorResp is custom response send to tunnel server when client cannot // establish connection to local server. If not set a default "no local server" // response is sent. @@ -39,6 +39,9 @@ type HTTPProxy struct { // Log is a custom logger that can be used for the proxy. // If not set a "http" logger is used. Log logging.Logger + + hs *http.Server + rp *httputil.ReverseProxy } // Proxy is a ProxyFunc. @@ -57,25 +60,28 @@ func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) { var localAddr = fmt.Sprintf("127.0.0.1:%d", port) if p.LocalAddr != "" { localAddr = p.LocalAddr - } else if p.FetchLocalAddr != nil { - l, err := p.FetchLocalAddr(msg.LocalPort) - if err != nil { - log.Warning("Failed to get custom local address: %s", err) - p.sendError(remote) - return + } + + if p.hs == nil { + su, _ := url.Parse(fmt.Sprintf("http://%s", p.LocalAddr)) + p.rp = httputil.NewSingleHostReverseProxy(su) + p.hs = &http.Server{ + Handler: p.rp, } - localAddr = l } log.Debug("Dialing local server %q", localAddr) - local, err := net.DialTimeout("tcp", localAddr, defaultTimeout) + + sl := singleListener{ + conn: remote, + } + + err := p.hs.Serve(sl) if err != nil { log.Error("Dialing local server %q failed: %s", localAddr, err) p.sendError(remote) return } - - Join(local, remote, log) } func (p *HTTPProxy) sendError(remote net.Conn) { @@ -113,3 +119,32 @@ func (p *HTTPProxy) log() logging.Logger { } return httpLog } + +// A singleListener is a net.Listener that returns a single connection, then +// gives the error io.EOF. +type singleListener struct { + conn net.Conn + once sync.Once +} + +func (s singleListener) Accept() (net.Conn, error) { + var c net.Conn + s.once.Do(func() { + c = s.conn + }) + if c != nil { + return c, nil + } + return nil, io.EOF +} + +func (s singleListener) Close() error { + s.once.Do(func() { + s.conn.Close() + }) + return nil +} + +func (s singleListener) Addr() net.Addr { + return s.conn.LocalAddr() +}