package tun2

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"html"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/Xe/ln"
	failure "github.com/dgryski/go-failure"
	"github.com/mtneug/pkg/ulid"
	cmap "github.com/streamrail/concurrent-map"
	kcp "github.com/xtaci/kcp-go"
	"github.com/xtaci/smux"
)

// Error values
var (
	// ErrNoSuchBackend is returned by KillBackend when no connection has the
	// requested ID, and is logged by RoundTrip when a domain has no usable
	// backend.
	ErrNoSuchBackend = errors.New("tun2: there is no such backend")

	// ErrAuthMismatch is logged by HandleConn when a backend's token lacks the
	// "connect" scope or belongs to a different user than the requested domain.
	ErrAuthMismatch = errors.New("tun2: authentication doesn't match database records")

	// ErrCantRemoveWhatDoesntExist is logged by RemoveConn when the routing
	// entry for a disconnecting backend's domain is not a backend list.
	ErrCantRemoveWhatDoesntExist = errors.New("tun2: this connection does not exist, cannot remove it")
)

// ServerConfig holds the settings that NewServer, ListenAndServe and
// HandleConn read when building and running a Server.
//
// NewServer replaces a nil SmuxConf with smux.DefaultConfig() and, in every
// case, overwrites its KeepAliveInterval (1s) and KeepAliveTimeout (15s).
type ServerConfig struct {
	TCPAddr   string      // TLS-over-TCP listen address for backends; empty disables it
	KCPAddr   string      // KCP listen address for backends (TLS is layered on each session); empty disables it
	TLSConfig *tls.Config // server TLS settings used by both listeners

	SmuxConf *smux.Config // stream-multiplexing settings for each backend session
	Storage  Storage      // consulted by HandleConn to authenticate backends
}

// Storage is the minimal subset of features that tun2's Server needs out of a
// persistence layer.
type Storage interface {
	// HasToken looks up an auth token and returns the user that owns it and
	// the scopes granted to it. HandleConn treats a non-nil error as an
	// unknown token and looks for a "connect" entry in the returned scopes.
	HasToken(token string) (user string, scopes []string, err error)
	// HasRoute looks up a domain and returns the user that owns it. HandleConn
	// treats a non-nil error as an unknown domain and requires the owner to
	// match the token's user.
	HasRoute(domain string) (user string, err error)
}

// Server routes frontend HTTP traffic to backend TCP traffic.
//
// Backends dial in through the listeners started by ListenAndServe and are
// registered by HandleConn; frontend requests are dispatched to them through
// RoundTrip.
type Server struct {
	cfg *ServerConfig

	connlock sync.Mutex               // guards conns
	conns    map[net.Conn]*Connection // registered backends, keyed by the net.Conn given to HandleConn

	domains cmap.ConcurrentMap // domain name -> []*Connection serving that domain
}

// NewServer creates a new Server instance with a given config, acquiring all
// relevant resources.
//
// cfg must be non-nil; it is modified in place and retained by the returned
// Server. A nil cfg.SmuxConf is replaced with smux.DefaultConfig(), and the
// smux keepalive is always forced to a 1s interval with a 15s timeout.
func NewServer(cfg *ServerConfig) (*Server, error) {
	if cfg == nil {
		return nil, errors.New("tun2: config must be specified")
	}

	smuxConf := cfg.SmuxConf
	if smuxConf == nil {
		smuxConf = smux.DefaultConfig()
		cfg.SmuxConf = smuxConf
	}
	smuxConf.KeepAliveInterval = time.Second
	smuxConf.KeepAliveTimeout = 15 * time.Second

	return &Server{
		cfg:     cfg,
		conns:   make(map[net.Conn]*Connection),
		domains: cmap.New(),
	}, nil
}

// backendMatcher is a predicate that selects which connections a backend
// listing should include.
type backendMatcher func(*Connection) bool

// getBackendsForMatcher snapshots every registered connection accepted by bm
// as a Backend value, or returns nil when nothing matches.
//
// connlock is held for the whole scan, including the calls to bm, so bm must
// not call anything that takes connlock (sync.Mutex is not reentrant).
func (s *Server) getBackendsForMatcher(bm backendMatcher) []Backend {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	var matched []Backend
	for _, cn := range s.conns {
		if bm(cn) {
			var proto string
			if cn.isKCP {
				proto = "kcp"
			} else {
				proto = "tcp"
			}

			matched = append(matched, Backend{
				ID:     cn.id,
				Proto:  proto,
				User:   cn.user,
				Domain: cn.domain,
				Phi:    float32(cn.detector.Phi(time.Now())),
				Host:   cn.conn.RemoteAddr().String(),
				Usable: cn.usable,
			})
		}
	}

	return matched
}

// KillBackend calls cancel() on the registered connection whose ID is id, or
// returns ErrNoSuchBackend when no registered connection has that ID.
//
// The cancel happens while connlock is held. NOTE(review): teardown and
// unregistration are presumably carried out by that connection's HandleConn
// loop once its context is cancelled — confirm against Connection.cancel.
func (s *Server) KillBackend(id string) error {
	s.connlock.Lock()
	defer s.connlock.Unlock()

	var target *Connection
	for _, cn := range s.conns {
		if cn.id == id {
			target = cn
			break
		}
	}

	if target == nil {
		return ErrNoSuchBackend
	}

	target.cancel()
	return nil
}

// GetBackendsForDomain lists the registered backends serving domain.
func (s *Server) GetBackendsForDomain(domain string) []Backend {
	servesDomain := func(cn *Connection) bool { return cn.domain == domain }
	return s.getBackendsForMatcher(servesDomain)
}

// GetBackendsForUser lists the registered backends that authenticated as the
// user named uname.
func (s *Server) GetBackendsForUser(uname string) []Backend {
	ownedByUser := func(cn *Connection) bool { return cn.user == uname }
	return s.getBackendsForMatcher(ownedByUser)
}

// GetAllBackends lists every registered backend.
func (s *Server) GetAllBackends() []Backend {
	matchAll := backendMatcher(func(*Connection) bool { return true })
	return s.getBackendsForMatcher(matchAll)
}

// ListenAndServe starts the backend TCP/KCP listeners and relays backend
// traffic to and from them.
//
// Every configured listener is bound before anything is served. An address
// that cannot be listened on is therefore returned as an error, after closing
// any listener that was already opened, rather than panicking on a background
// goroutine.
//
// On success it starts one accept loop per listener plus a failure-detector
// monitor, then returns nil immediately. Those goroutines never exit.
func (s *Server) ListenAndServe() error {
	// The context only carries log metadata for the long-lived goroutines
	// below. It deliberately is not cancelled when this function returns.
	ctx := context.Background()

	ln.Log(ctx, ln.F{
		"action": "listen_and_serve_called",
	})

	var (
		tcpL, kcpL net.Listener
		err        error
	)

	if s.cfg.TCPAddr != "" {
		if tcpL, err = tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig); err != nil {
			return err
		}
	}

	if s.cfg.KCPAddr != "" {
		if kcpL, err = kcp.Listen(s.cfg.KCPAddr); err != nil {
			if tcpL != nil {
				// Best effort: the listen error is what the caller needs to see.
				_ = tcpL.Close()
			}
			return err
		}
	}

	if tcpL != nil {
		ln.Log(ctx, ln.F{
			"action": "tcp+tls_listening",
			"addr":   tcpL.Addr(),
		})
		go s.acceptBackends(ctx, tcpL, false)
	}

	if kcpL != nil {
		ln.Log(ctx, ln.F{
			"action": "kcp+tls_listening",
			"addr":   kcpL.Addr(),
		})
		go s.acceptBackends(ctx, kcpL, true)
	}

	// XXX experimental, might get rid of this inside this process
	go s.logPhiFailures(ctx)

	return nil
}

// acceptBackends accepts backend connections from l forever and hands each
// one to HandleConn on its own goroutine.
//
// When isKCP is set, the raw KCP session is wrapped in server-side TLS first.
// TCP connections already come from tls.Listen. Accept errors are logged and
// retried with a capped exponential backoff (5ms doubling up to 1s). This keeps
// a persistent failure, such as running out of file descriptors, from spinning
// the CPU. A failed Accept never reaches the nil-conn handling below.
func (s *Server) acceptBackends(ctx context.Context, l net.Listener, isKCP bool) {
	kind := "tcp"
	if isKCP {
		kind = "kcp"
	}

	var backoff time.Duration
	for {
		conn, err := l.Accept()
		if err != nil {
			ln.Error(ctx, err, ln.F{"kind": kind, "addr": l.Addr().String()})

			backoff *= 2
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		ln.Log(ctx, ln.F{
			"action": "new_client",
			"kcp":    isKCP,
			"addr":   conn.RemoteAddr(),
		})

		if isKCP {
			conn = tls.Server(conn, s.cfg.TLSConfig)
		}

		go s.HandleConn(conn, isKCP)
	}
}

// logPhiFailures checks every registered backend's failure detector once per
// second. It logs those whose phi value exceeds 0.8 and never returns.
func (s *Server) logPhiFailures(ctx context.Context) {
	for {
		time.Sleep(time.Second)

		now := time.Now()

		s.connlock.Lock()
		for _, c := range s.conns {
			if failureChance := c.detector.Phi(now); failureChance > 0.8 {
				ln.Log(ctx, c.F(), ln.F{
					"action": "phi_failure_detection",
					"value":  failureChance,
				})
			}
		}
		s.connlock.Unlock()
	}
}

// HandleConn starts up the needed mechanisms to relay HTTP traffic to/from
// the currently connected backend.
//
// It drives the whole lifecycle of one backend connection and blocks until
// the backend goes away:
//
//  1. Establish an smux session over c and open a control stream.
//  2. Decode the backend's Auth message from that stream as JSON.
//  3. Require that the domain and the token both exist in s.cfg.Storage, that
//     the token carries the "connect" scope, and that the token's user owns
//     the domain.
//  4. Register the connection in s.conns and in the domain's backend list.
//  5. Ping the backend every 5 seconds, calling connection.cancel() on the
//     first failure. Once ctx is done, unregister and close the connection.
//
// Any failure before registration is logged, and c is closed on return.
func (s *Server) HandleConn(c net.Conn, isKCP bool) {
	defer c.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// logFailure records why the handshake with this backend was abandoned.
	logFailure := func(err error, action string) {
		ln.Error(ctx, err, ln.F{
			"action": action,
			"local":  c.LocalAddr().String(),
			"remote": c.RemoteAddr().String(),
		})
	}

	session, err := smux.Server(c, s.cfg.SmuxConf)
	if err != nil {
		logFailure(err, "session_failure")
		return // the deferred c.Close() releases the transport
	}
	defer session.Close()

	controlStream, err := session.OpenStream()
	if err != nil {
		logFailure(err, "control_stream_failure")
		return
	}
	defer controlStream.Close()

	auth := &Auth{}
	if err := json.NewDecoder(controlStream).Decode(auth); err != nil {
		logFailure(err, "control_stream_auth_decoding_failure")
		return
	}

	routeUser, err := s.cfg.Storage.HasRoute(auth.Domain)
	if err != nil {
		logFailure(err, "nosuch_domain")
		return
	}

	tokenUser, scopes, err := s.cfg.Storage.HasToken(auth.Token)
	if err != nil {
		logFailure(err, "nosuch_token")
		return
	}

	connectAllowed := false
	for _, sc := range scopes {
		if sc == "connect" {
			connectAllowed = true
			break
		}
	}
	if !connectAllowed {
		// A token without the "connect" scope must not be admitted as a backend.
		logFailure(ErrAuthMismatch, "token_not_authorized")
		return
	}

	if routeUser != tokenUser {
		logFailure(ErrAuthMismatch, "auth_mismatch")
		return
	}

	// usable is set here, before the connection is published through s.conns
	// and s.domains. That way readers that find it there never race with a
	// later write to the field.
	connection := &Connection{
		id:       ulid.New().String(),
		conn:     c,
		isKCP:    isKCP,
		session:  session,
		user:     tokenUser,
		domain:   auth.Domain,
		cf:       cancel,
		detector: failure.New(15, 1),
		Auth:     auth,
		usable:   true,
	}

	defer func() {
		if r := recover(); r != nil {
			ln.Log(ctx, connection, ln.F{"action": "connection handler panic", "err": r})
		}
	}()

	ln.Log(ctx, ln.F{
		"action": "backend_connected",
	}, connection.F())

	// The Get-then-Set on s.domains is not atomic by itself, so registrations
	// are serialized under connlock. Otherwise two backends for the same domain
	// could drop one another.
	s.connlock.Lock()
	s.conns[c] = connection

	var existing []*Connection
	if val, found := s.domains.Get(auth.Domain); found {
		if list, isList := val.([]*Connection); isList {
			existing = list
		} else {
			s.domains.Remove(auth.Domain)
		}
	}

	// Build a fresh slice rather than appending in place. RoundTrip may be
	// reading the old one without a lock, and an in-place append could write
	// into its backing array.
	updated := make([]*Connection, 0, len(existing)+1)
	updated = append(updated, existing...)
	updated = append(updated, connection)
	s.domains.Set(auth.Domain, updated)
	s.connlock.Unlock()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := connection.Ping(); err != nil {
				connection.cancel()
			}
		case <-ctx.Done():
			s.RemoveConn(ctx, connection)
			connection.Close()

			return
		}
	}
}

// RemoveConn removes a connection.
//
// It deletes connection from s.conns and from its domain's backend list. The
// domain entry is removed once that list is empty. Then it logs the
// disconnect. If the domain entry holds something other than a backend list,
// the error is logged and the domain table is left untouched.
//
// The domain list is replaced with a freshly built slice instead of being
// edited in place, because RoundTrip ranges over it without holding a lock.
// The update is done under connlock so it cannot interleave with other
// registrations or removals that also take connlock.
func (s *Server) RemoveConn(ctx context.Context, connection *Connection) {
	auth := connection.Auth

	s.connlock.Lock()
	delete(s.conns, connection.conn)

	var conns []*Connection
	if val, found := s.domains.Get(auth.Domain); found {
		list, isList := val.([]*Connection)
		if !isList {
			s.connlock.Unlock()

			ln.Error(ctx, ErrCantRemoveWhatDoesntExist, connection.F(), ln.F{
				"action": "looking_up_for_disconnect_removal",
			})
			return
		}
		conns = list
	}

	remaining := make([]*Connection, 0, len(conns))
	for _, cntn := range conns {
		if cntn.id != connection.id {
			remaining = append(remaining, cntn)
		}
	}

	if len(remaining) != 0 {
		s.domains.Set(auth.Domain, remaining)
	} else {
		s.domains.Remove(auth.Domain)
	}
	s.connlock.Unlock()

	ln.Log(ctx, connection.F(), ln.F{
		"action": "client_disconnecting",
	})
}

// gen502Page builds the synthetic "502 Bad Gateway" response that is returned
// when no backend is available for req.Host.
//
// The Host and X-Request-Id values come from the client. They are
// HTML-escaped before being placed into the page to prevent reflected
// markup/script injection.
//
// The response gets its own header map containing only Content-Type and, when
// the request carried one, X-Request-Id. The request's headers (cookies,
// credentials, ...) are therefore neither modified nor echoed back.
func gen502Page(req *http.Request) *http.Response {
	const template = `<html><head><title>no backends connected</title></head><body><h1>no backends connected</h1><p>Please ensure a backend is running for ${HOST}. This is request ID ${REQ_ID}.</p></body></html>`

	reqID := req.Header.Get("X-Request-Id")

	resbody := []byte(os.Expand(template, func(in string) string {
		switch in {
		case "HOST":
			return html.EscapeString(req.Host)
		case "REQ_ID":
			return html.EscapeString(reqID)
		}

		return "<unknown>"
	}))

	reshdr := http.Header{}
	reshdr.Set("Content-Type", "text/html; charset=utf-8")
	if reqID != "" {
		// Keep the request ID visible to the client for correlation.
		reshdr.Set("X-Request-Id", reqID)
	}

	return &http.Response{
		Status:     fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway),
		StatusCode: http.StatusBadGateway,
		Body:       ioutil.NopCloser(bytes.NewReader(resbody)),

		Proto:         req.Proto,
		ProtoMajor:    req.ProtoMajor,
		ProtoMinor:    req.ProtoMinor,
		Header:        reshdr,
		ContentLength: int64(len(resbody)),
		Close:         true,
		Request:       req,
	}
}

// RoundTrip sends a HTTP request to a backend and then returns its response.
//
// One backend is picked at random from the usable connections registered for
// req.Host. When there is none, including when the routing entry is not a
// backend list, a generated 502 page is returned with a nil error. When the
// chosen backend's round trip fails, that connection is cancelled and the
// error is returned.
func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()

	var candidates []*Connection
	if val, found := s.domains.Get(req.Host); found {
		// A failed assertion leaves candidates nil, which falls through to
		// the "no backend" response below.
		candidates, _ = val.([]*Connection)
	}

	var usable []*Connection
	for _, cand := range candidates {
		if cand.usable {
			usable = append(usable, cand)
		}
	}

	if len(usable) == 0 {
		ln.Error(ctx, ErrNoSuchBackend, ln.F{
			"action": "no_backend_connected",
			"remote": req.RemoteAddr,
			"host":   req.Host,
			"uri":    req.RequestURI,
		})

		return gen502Page(req), nil
	}

	backend := usable[rand.Intn(len(usable))]

	resp, err := backend.RoundTrip(req)
	if err != nil {
		ln.Error(ctx, err, backend, ln.F{
			"action": "connection_roundtrip",
		})

		backend.cancel()
		return nil, err
	}

	ln.Log(ctx, backend, ln.F{
		"action":         "http traffic",
		"remote_addr":    req.RemoteAddr,
		"host":           req.Host,
		"uri":            req.URL.Path,
		"status":         resp.Status,
		"status_code":    resp.StatusCode,
		"content_length": resp.ContentLength,
	})

	return resp, nil
}

// Auth is the authentication info the client passes to the server.
// HandleConn decodes it as JSON from the control stream it opens on each new
// backend session.
type Auth struct {
	Token  string `json:"token"`  // looked up with Storage.HasToken
	Domain string `json:"domain"` // looked up with Storage.HasRoute; the backend is registered under this domain
}

// defaultUser is a hard-coded user name. It is not referenced anywhere in
// this file. NOTE(review): confirm it is still used elsewhere in the package
// before relying on or removing it.
const defaultUser = "Cadey"