route/vendor/github.com/mmatczuk/go-http-tunnel/cmd/tunnel/tunnel.go

169 lines
3.7 KiB
Go

// Copyright (C) 2017 Michał Matczuk
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
"os"
"sort"
"gopkg.in/yaml.v2"
"github.com/cenkalti/backoff"
"github.com/mmatczuk/go-http-tunnel"
"github.com/mmatczuk/go-http-tunnel/id"
"github.com/mmatczuk/go-http-tunnel/log"
"github.com/mmatczuk/go-http-tunnel/proto"
)
// main is the entry point of the tunnel client command.
//
// It parses the command line, sets up the logger and loads the configuration
// file, then dispatches on opts.command:
//
//   - "id":    prints the identifier computed by id.New from the raw bytes of
//     the first certificate in the configured key pair, then returns.
//   - "list":  prints the names of all configured tunnels in sorted order,
//     then returns.
//   - "start": narrows config.Tunnels down to the tunnels named in opts.args
//     and falls through to start the client.
//
// Any other command falls through with config.Tunnels left as loaded. The
// client is then built from the configuration and started. Every failure is
// reported through fatal, which terminates the process.
func main() {
	opts, err := parseArgs()
	if err != nil {
		// The error text must be passed as an argument, never as the format
		// string itself: fatal is printf-style, so a '%' inside the message
		// would otherwise be interpreted as a formatting verb.
		fatal("%s", err)
	}

	if opts.version {
		fmt.Println(version)
		return
	}

	logger, err := log.NewLogger(opts.logTo, opts.logLevel)
	if err != nil {
		fatal("failed to init logger: %s", err)
	}

	// Read the configuration file.
	config, err := loadConfiguration(opts.config)
	if err != nil {
		fatal("configuration error: %s", err)
	}

	switch opts.command {
	case "id":
		cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey)
		if err != nil {
			fatal("failed to load key pair: %s", err)
		}
		x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			fatal("failed to parse certificate: %s", err)
		}
		fmt.Println(id.New(x509Cert.Raw))
		return
	case "list":
		// Map iteration order is random; sort so the listing is stable.
		names := make([]string, 0, len(config.Tunnels))
		for n := range config.Tunnels {
			names = append(names, n)
		}
		sort.Strings(names)
		for _, n := range names {
			fmt.Println(n)
		}
		return
	case "start":
		// Keep only the tunnels requested on the command line. The local is
		// deliberately not called "tunnels" so it does not shadow the
		// package-level tunnels function used further down.
		selected := make(map[string]*tunnelConfig, len(opts.args))
		for _, arg := range opts.args {
			t, ok := config.Tunnels[arg]
			if !ok {
				fatal("no such tunnel %q", arg)
			}
			selected[arg] = t
		}
		config.Tunnels = selected
	}

	cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey)
	if err != nil {
		fatal("failed to load certificate: %s", err)
	}

	// Log the effective configuration (after any "start" filtering) as YAML.
	b, err := yaml.Marshal(config)
	if err != nil {
		fatal("failed to dump config: %s", err)
	}
	logger.Log("config", string(b))

	client := tunnel.NewClient(&tunnel.ClientConfig{
		ServerAddr:      config.ServerAddr,
		TLSClientConfig: tlsConfig(cert, config),
		Backoff:         expBackoff(config.Backoff),
		Tunnels:         tunnels(config.Tunnels),
		Proxy:           proxy(config.Tunnels, logger),
		Logger:          logger,
	})
	if err := client.Start(); err != nil {
		fatal("%s", err)
	}
}
// tlsConfig assembles the TLS settings handed to the tunnel client: the given
// key pair is offered as the only certificate, and server certificate
// verification is skipped when config.InsecureSkipVerify is set.
func tlsConfig(cert tls.Certificate, config *config) *tls.Config {
	c := new(tls.Config)
	c.Certificates = []tls.Certificate{cert}
	c.InsecureSkipVerify = config.InsecureSkipVerify
	return c
}
// expBackoff returns an exponential backoff policy that starts from the
// defaults of backoff.NewExponentialBackOff and overrides the initial
// interval, multiplier, maximum interval and maximum elapsed time with the
// values from config.
func expBackoff(config *backoffConfig) *backoff.ExponentialBackOff {
	policy := backoff.NewExponentialBackOff()
	policy.InitialInterval, policy.Multiplier = config.InitialInterval, config.Multiplier
	policy.MaxInterval, policy.MaxElapsedTime = config.MaxInterval, config.MaxElapsedTime
	return policy
}
// tunnels converts the configured tunnels into their proto.Tunnel
// counterparts, keyed by the same tunnel names. Note that proto.Tunnel.Addr
// is filled from the configuration's RemoteAddr, not from its Addr field.
func tunnels(m map[string]*tunnelConfig) map[string]*proto.Tunnel {
	out := make(map[string]*proto.Tunnel, len(m))
	for name := range m {
		src := m[name]
		converted := new(proto.Tunnel)
		converted.Protocol = src.Protocol
		converted.Host = src.Host
		converted.Auth = src.Auth
		converted.Addr = src.RemoteAddr
		out[name] = converted
	}
	return out
}
// proxy builds the ProxyFunc used by the client from the configured tunnels.
//
// HTTP tunnels are indexed by their Host and map to the parsed URL of their
// Addr; TCP, TCP4 and TCP6 tunnels are indexed by their RemoteAddr and map to
// their Addr. Tunnels with any other protocol are ignored. An HTTP tunnel
// whose Addr cannot be parsed as a URL terminates the process via fatal.
func proxy(m map[string]*tunnelConfig, logger log.Logger) tunnel.ProxyFunc {
	var (
		httpTargets = make(map[string]*url.URL)
		tcpTargets  = make(map[string]string)
	)

	for _, tc := range m {
		if tc.Protocol == proto.HTTP {
			target, err := url.Parse(tc.Addr)
			if err != nil {
				fatal("invalid tunnel address: %s", err)
			}
			httpTargets[tc.Host] = target
			continue
		}
		if tc.Protocol == proto.TCP || tc.Protocol == proto.TCP4 || tc.Protocol == proto.TCP6 {
			tcpTargets[tc.RemoteAddr] = tc.Addr
		}
	}

	// Each proxy gets its own logger context tagged with the protocol name.
	httpProxy := tunnel.NewMultiHTTPProxy(httpTargets, log.NewContext(logger).WithPrefix("proxy", "HTTP"))
	tcpProxy := tunnel.NewMultiTCPProxy(tcpTargets, log.NewContext(logger).WithPrefix("proxy", "TCP"))

	return tunnel.Proxy(tunnel.ProxyFuncs{
		HTTP: httpProxy.Proxy,
		TCP:  tcpProxy.Proxy,
	})
}
// fatal formats a message printf-style, writes it to standard error followed
// by a newline, and terminates the process with exit status 1. It never
// returns.
func fatal(format string, a ...interface{}) {
	msg := fmt.Sprintf(format, a...)
	fmt.Fprintln(os.Stderr, msg)
	os.Exit(1)
}