route/lib/tunnel/proxy.go

102 lines
2.5 KiB
Go
Raw Normal View History

2017-01-20 01:27:14 +00:00
package tunnel
import (
"io"
"net"
"sync"
"github.com/koding/logging"
"git.xeserv.us/xena/route/lib/tunnel/proto"
)
// ProxyFunc is responsible for forwarding a remote connection to a local
// server and writing the response back.
//
// remote is the connection to be forwarded. msg is the control message that
// accompanies it; its Protocol field is what Proxy inspects to choose a
// handler.
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
// DefaultProxyFuncs holds the package-wide fallback proxy function for each
// transport protocol. HTTP and WS are each served by their own HTTPProxy
// value; TCP is served by a TCPProxy value.
var DefaultProxyFuncs = ProxyFuncs{
	HTTP: new(HTTPProxy).Proxy,
	TCP:  new(TCPProxy).Proxy,
	WS:   new(HTTPProxy).Proxy,
}

// DefaultProxy is the ProxyFunc built from an empty ProxyFuncs, so every
// protocol is handled by the corresponding entry in DefaultProxyFuncs.
var DefaultProxy = Proxy(ProxyFuncs{})
// ProxyFuncs groups one ProxyFunc per supported transport protocol.
// A nil field means "no custom implementation"; Proxy then falls back to the
// matching field of DefaultProxyFuncs.
type ProxyFuncs struct {
	// HTTP is a custom implementation of HTTP proxying.
	HTTP ProxyFunc

	// TCP is a custom implementation of TCP proxying.
	TCP ProxyFunc

	// WS is a custom implementation of web socket proxying.
	WS ProxyFunc
}
// Proxy returns a ProxyFunc that dispatches each connection on msg.Protocol.
//
// For every supported protocol (proto.HTTP, proto.TCP, proto.WS) the custom
// function from p is used when it is non-nil; otherwise the corresponding
// function from DefaultProxyFuncs is used.
//
// If no function can be determined (unknown protocol, or both the custom and
// the default function are nil), the error is logged, remote is closed, and
// the returned ProxyFunc returns without forwarding anything.
func Proxy(p ProxyFuncs) ProxyFunc {
	return func(remote net.Conn, msg *proto.ControlMessage) {
		var f ProxyFunc

		switch msg.Protocol {
		case proto.HTTP:
			f = DefaultProxyFuncs.HTTP
			if p.HTTP != nil {
				f = p.HTTP
			}
		case proto.TCP:
			f = DefaultProxyFuncs.TCP
			if p.TCP != nil {
				f = p.TCP
			}
		case proto.WS:
			f = DefaultProxyFuncs.WS
			if p.WS != nil {
				f = p.WS
			}
		}

		if f == nil {
			logging.Error("Could not determine proxy function for %v", msg)
			remote.Close()
			// Stop here: invoking a nil ProxyFunc would panic, and the
			// connection has already been closed.
			return
		}

		f(remote, msg)
	}
}
// Join pipes data between the local and remote connections in both
// directions at once and blocks until both directions have finished.
// It is intended as a building block for ProxyFunc implementations.
//
// For each direction, once copying ends the source connection is closed and,
// when the destination is a *net.TCPConn, its write side is shut down.
// Copy failures are logged at error level; close failures at debug level.
func Join(local, remote net.Conn, log logging.Logger) {
	var pending sync.WaitGroup
	pending.Add(2)

	pipe := func(label string, to, from net.Conn) {
		log.Debug("proxing %s -> %s", from.RemoteAddr(), to.RemoteAddr())

		written, copyErr := io.Copy(to, from)
		if copyErr != nil {
			log.Error("%s: copy error: %s", label, copyErr)
		}

		closeErr := from.Close()
		if closeErr != nil {
			log.Debug("%s: close error: %s", label, closeErr)
		}

		// Half-close applies to client-to-local-server TCP connections,
		// not to yamux streams.
		switch tcp := to.(type) {
		case *net.TCPConn:
			if halfErr := tcp.CloseWrite(); halfErr != nil {
				log.Debug("%s: closeWrite error: %s", label, halfErr)
			}
		}

		pending.Done()
		log.Debug("done proxing %s -> %s: %d bytes", from.RemoteAddr(), to.RemoteAddr(), written)
	}

	go pipe("remote to local", local, remote)
	go pipe("local to remote", remote, local)

	pending.Wait()
}