157 lines
4.0 KiB
Go
157 lines
4.0 KiB
Go
// Copyright (C) 2017 MichaĆ Matczuk
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"io/ioutil"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"gopkg.in/yaml.v2"
|
|
|
|
"github.com/mmatczuk/go-http-tunnel/proto"
|
|
)
|
|
|
|
type backoffConfig struct {
|
|
InitialInterval time.Duration `yaml:"interval,omitempty"`
|
|
Multiplier float64 `yaml:"multiplier,omitempty"`
|
|
MaxInterval time.Duration `yaml:"max_interval,omitempty"`
|
|
MaxElapsedTime time.Duration `yaml:"max_time,omitempty"`
|
|
}
|
|
|
|
type tunnelConfig struct {
|
|
Protocol string `yaml:"proto,omitempty"`
|
|
Addr string `yaml:"addr,omitempty"`
|
|
Auth string `yaml:"auth,omitempty"`
|
|
Host string `yaml:"host,omitempty"`
|
|
RemoteAddr string `yaml:"remote_addr,omitempty"`
|
|
}
|
|
|
|
type config struct {
|
|
ServerAddr string `yaml:"server_addr,omitempty"`
|
|
InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"`
|
|
TLSCrt string `yaml:"tls_crt,omitempty"`
|
|
TLSKey string `yaml:"tls_key,omitempty"`
|
|
Backoff *backoffConfig `yaml:"backoff,omitempty"`
|
|
Tunnels map[string]*tunnelConfig `yaml:"tunnels,omitempty"`
|
|
}
|
|
|
|
var defaultBackoffConfig = backoffConfig{
|
|
InitialInterval: 500 * time.Millisecond,
|
|
Multiplier: 1.5,
|
|
MaxInterval: 60 * time.Second,
|
|
MaxElapsedTime: 15 * time.Minute,
|
|
}
|
|
|
|
func loadConfiguration(path string) (*config, error) {
|
|
configBuf, err := ioutil.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read file %q: %s", path, err)
|
|
}
|
|
|
|
// deserialize/parse the config
|
|
var config config
|
|
if err = yaml.Unmarshal(configBuf, &config); err != nil {
|
|
return nil, fmt.Errorf("failed to parse file %q: %s", path, err)
|
|
}
|
|
|
|
// set default values
|
|
if config.TLSCrt == "" {
|
|
config.TLSCrt = filepath.Join(filepath.Dir(path), "client.crt")
|
|
}
|
|
if config.TLSKey == "" {
|
|
config.TLSKey = filepath.Join(filepath.Dir(path), "client.key")
|
|
}
|
|
|
|
if config.Backoff == nil {
|
|
config.Backoff = &defaultBackoffConfig
|
|
} else {
|
|
if config.Backoff.InitialInterval == 0 {
|
|
config.Backoff.InitialInterval = defaultBackoffConfig.InitialInterval
|
|
}
|
|
if config.Backoff.Multiplier == 0 {
|
|
config.Backoff.Multiplier = defaultBackoffConfig.Multiplier
|
|
}
|
|
if config.Backoff.MaxInterval == 0 {
|
|
config.Backoff.MaxInterval = defaultBackoffConfig.MaxInterval
|
|
}
|
|
if config.Backoff.MaxElapsedTime == 0 {
|
|
config.Backoff.MaxElapsedTime = defaultBackoffConfig.MaxElapsedTime
|
|
}
|
|
}
|
|
|
|
// validate and normalize configuration
|
|
if config.ServerAddr == "" {
|
|
return nil, fmt.Errorf("server_addr: missing")
|
|
}
|
|
|
|
if config.ServerAddr, err = normalizeAddress(config.ServerAddr); err != nil {
|
|
return nil, fmt.Errorf("server_addr: %s", err)
|
|
}
|
|
|
|
for name, t := range config.Tunnels {
|
|
switch t.Protocol {
|
|
case proto.HTTP:
|
|
if err := validateHTTP(t); err != nil {
|
|
return nil, fmt.Errorf("%s %s", name, err)
|
|
}
|
|
case proto.TCP, proto.TCP4, proto.TCP6:
|
|
if err := validateTCP(t); err != nil {
|
|
return nil, fmt.Errorf("%s %s", name, err)
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("%s invalid protocol %q", name, t.Protocol)
|
|
}
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
func validateHTTP(t *tunnelConfig) error {
|
|
var err error
|
|
if t.Host == "" {
|
|
return fmt.Errorf("host: missing")
|
|
}
|
|
if t.Addr == "" {
|
|
return fmt.Errorf("addr: missing")
|
|
}
|
|
if t.Addr, err = normalizeURL(t.Addr); err != nil {
|
|
return fmt.Errorf("addr: %s", err)
|
|
}
|
|
|
|
// unexpected
|
|
|
|
if t.RemoteAddr != "" {
|
|
return fmt.Errorf("remote_addr: unexpected")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateTCP(t *tunnelConfig) error {
|
|
var err error
|
|
if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil {
|
|
return fmt.Errorf("remote_addr: %s", err)
|
|
}
|
|
if t.Addr == "" {
|
|
return fmt.Errorf("addr: missing")
|
|
}
|
|
if t.Addr, err = normalizeAddress(t.Addr); err != nil {
|
|
return fmt.Errorf("addr: %s", err)
|
|
}
|
|
|
|
// unexpected
|
|
|
|
if t.Host != "" {
|
|
return fmt.Errorf("host: unexpected")
|
|
}
|
|
if t.Auth != "" {
|
|
return fmt.Errorf("auth: unexpected")
|
|
}
|
|
|
|
return nil
|
|
}
|