102 lines
2.5 KiB
Go
102 lines
2.5 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/koding/logging"
|
|
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
|
)
|
|
|
|
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
|
|
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
|
|
|
|
var (
|
|
// DefaultProxyFuncs holds global default proxy functions for all transport protocols.
|
|
DefaultProxyFuncs = ProxyFuncs{
|
|
HTTP: new(HTTPProxy).Proxy,
|
|
TCP: new(TCPProxy).Proxy,
|
|
WS: new(HTTPProxy).Proxy,
|
|
}
|
|
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs.
|
|
DefaultProxy = Proxy(ProxyFuncs{})
|
|
)
|
|
|
|
// ProxyFuncs is a collection of ProxyFunc.
|
|
type ProxyFuncs struct {
|
|
// HTTP is custom implementation of HTTP proxing.
|
|
HTTP ProxyFunc
|
|
// TCP is custom implementation of TCP proxing.
|
|
TCP ProxyFunc
|
|
// WS is custom implementation of web socket proxing.
|
|
WS ProxyFunc
|
|
}
|
|
|
|
// Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs.
|
|
func Proxy(p ProxyFuncs) ProxyFunc {
|
|
return func(remote net.Conn, msg *proto.ControlMessage) {
|
|
var f ProxyFunc
|
|
switch msg.Protocol {
|
|
case proto.HTTP:
|
|
f = DefaultProxyFuncs.HTTP
|
|
if p.HTTP != nil {
|
|
f = p.HTTP
|
|
}
|
|
case proto.TCP:
|
|
f = DefaultProxyFuncs.TCP
|
|
if p.TCP != nil {
|
|
f = p.TCP
|
|
}
|
|
case proto.WS:
|
|
f = DefaultProxyFuncs.WS
|
|
if p.WS != nil {
|
|
f = p.WS
|
|
}
|
|
}
|
|
|
|
if f == nil {
|
|
logging.Error("Could not determine proxy function for %v", msg)
|
|
remote.Close()
|
|
}
|
|
|
|
f(remote, msg)
|
|
}
|
|
}
|
|
|
|
// Join copies data between local and remote connections.
|
|
// It reads from one connection and writes to the other.
|
|
// It's a building block for ProxyFunc implementations.
|
|
func Join(local, remote net.Conn, log logging.Logger) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
transfer := func(side string, dst, src net.Conn) {
|
|
log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr())
|
|
|
|
n, err := io.Copy(dst, src)
|
|
if err != nil {
|
|
log.Error("%s: copy error: %s", side, err)
|
|
}
|
|
|
|
if err := src.Close(); err != nil {
|
|
log.Debug("%s: close error: %s", side, err)
|
|
}
|
|
|
|
// not for yamux streams, but for client to local server connections
|
|
if d, ok := dst.(*net.TCPConn); ok {
|
|
if err := d.CloseWrite(); err != nil {
|
|
log.Debug("%s: closeWrite error: %s", side, err)
|
|
}
|
|
|
|
}
|
|
wg.Done()
|
|
log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n)
|
|
}
|
|
|
|
go transfer("remote to local", local, remote)
|
|
go transfer("local to remote", remote, local)
|
|
|
|
wg.Wait()
|
|
}
|