// Copyright (C) 2017 MichaƂ Matczuk // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import ( "fmt" "io/ioutil" "path/filepath" "time" "gopkg.in/yaml.v2" "github.com/mmatczuk/go-http-tunnel/proto" ) type backoffConfig struct { InitialInterval time.Duration `yaml:"interval,omitempty"` Multiplier float64 `yaml:"multiplier,omitempty"` MaxInterval time.Duration `yaml:"max_interval,omitempty"` MaxElapsedTime time.Duration `yaml:"max_time,omitempty"` } type tunnelConfig struct { Protocol string `yaml:"proto,omitempty"` Addr string `yaml:"addr,omitempty"` Auth string `yaml:"auth,omitempty"` Host string `yaml:"host,omitempty"` RemoteAddr string `yaml:"remote_addr,omitempty"` } type config struct { ServerAddr string `yaml:"server_addr,omitempty"` InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"` TLSCrt string `yaml:"tls_crt,omitempty"` TLSKey string `yaml:"tls_key,omitempty"` Backoff *backoffConfig `yaml:"backoff,omitempty"` Tunnels map[string]*tunnelConfig `yaml:"tunnels,omitempty"` } var defaultBackoffConfig = backoffConfig{ InitialInterval: 500 * time.Millisecond, Multiplier: 1.5, MaxInterval: 60 * time.Second, MaxElapsedTime: 15 * time.Minute, } func loadConfiguration(path string) (*config, error) { configBuf, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read file %q: %s", path, err) } // deserialize/parse the config var config config if err = yaml.Unmarshal(configBuf, &config); err != nil { return nil, fmt.Errorf("failed to parse file %q: %s", path, err) } // set default values if config.TLSCrt == "" { config.TLSCrt = filepath.Join(filepath.Dir(path), "client.crt") } if config.TLSKey == "" { config.TLSKey = filepath.Join(filepath.Dir(path), "client.key") } if config.Backoff == nil { config.Backoff = &defaultBackoffConfig } else { if config.Backoff.InitialInterval == 0 { config.Backoff.InitialInterval = defaultBackoffConfig.InitialInterval } if config.Backoff.Multiplier == 0 { config.Backoff.Multiplier = defaultBackoffConfig.Multiplier } if config.Backoff.MaxInterval == 0 { config.Backoff.MaxInterval = defaultBackoffConfig.MaxInterval } if config.Backoff.MaxElapsedTime == 0 { config.Backoff.MaxElapsedTime = defaultBackoffConfig.MaxElapsedTime } } // validate and normalize configuration if config.ServerAddr == "" { return nil, fmt.Errorf("server_addr: missing") } if config.ServerAddr, err = normalizeAddress(config.ServerAddr); err != nil { return nil, fmt.Errorf("server_addr: %s", err) } for name, t := range config.Tunnels { switch t.Protocol { case proto.HTTP: if err := validateHTTP(t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } case proto.TCP, proto.TCP4, proto.TCP6: if err := validateTCP(t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } default: return nil, fmt.Errorf("%s invalid protocol %q", name, t.Protocol) } } return &config, nil } func validateHTTP(t *tunnelConfig) error { var err error if t.Host == "" { return fmt.Errorf("host: missing") } if t.Addr == "" { return fmt.Errorf("addr: missing") } if t.Addr, err = normalizeURL(t.Addr); err != nil { return fmt.Errorf("addr: %s", err) } // unexpected if t.RemoteAddr != "" { return fmt.Errorf("remote_addr: unexpected") } return nil } func validateTCP(t *tunnelConfig) error { var err error if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil { return fmt.Errorf("remote_addr: %s", err) } if t.Addr == "" { return fmt.Errorf("addr: missing") } if t.Addr, err = normalizeAddress(t.Addr); err != nil { return fmt.Errorf("addr: %s", err) } // unexpected if t.Host != "" { return fmt.Errorf("host: unexpected") } if t.Auth != "" { return fmt.Errorf("auth: unexpected") } return nil }