package tun2

import (
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"

	kcp "github.com/xtaci/kcp-go"
	"github.com/xtaci/smux"
)

// Client connects to a tun2 server and serves the HTTP requests tunneled to it
// by reverse-proxying them to the backend named in its ClientConfig.
// Construct one with NewClient.
type Client struct {
	cfg *ClientConfig
}

// ClientConfig holds the settings a Client uses to reach the tun2 server and
// the backend it forwards requests to.
type ClientConfig struct {
	TLSConfig  *tls.Config // TLS settings for the server connection (used by both "tcp" and "kcp")
	ConnType   string      // transport to the server: "tcp" or "kcp"
	ServerAddr string      // server address in host:port form, dialed by Connect
	Token      string      // authentication token sent to the server on the control stream
	Domain     string      // domain sent alongside Token in the auth message (presumably the domain to route; confirm with server)
	BackendURL string      // URL that tunneled requests are reverse-proxied to
}

// NewClient returns a Client that connects using cfg.
// It fails only when cfg is nil; the fields of cfg are not validated here.
func NewClient(cfg *ClientConfig) (*Client, error) {
	if cfg != nil {
		return &Client{cfg: cfg}, nil
	}
	return nil, errors.New("tun2: client config needed")
}

// Connect dials the configured ServerAddr and then serves tunneled requests,
// blocking until serving stops; it returns the error that ended the session.
func (c *Client) Connect() error {
	addr := c.cfg.ServerAddr
	return c.connect(addr)
}

// connect opens a TLS connection to serverAddr over the transport selected by
// ClientConfig.ConnType and starts an smux client session on it. It accepts the
// first stream opened by the server as the control stream and writes the JSON
// auth message (Token and Domain) to it. It then serves every further stream
// as an HTTP connection, reverse-proxying requests to ClientConfig.BackendURL.
//
// It blocks until serving stops and returns the error that ended it.
func (c *Client) connect(serverAddr string) error {
	target, err := url.Parse(c.cfg.BackendURL)
	if err != nil {
		return err
	}

	conn, err := c.dialServer(serverAddr)
	if err != nil {
		return err
	}
	// Closing the *tls.Conn also closes the underlying TCP/KCP connection.
	defer conn.Close()

	log.Printf("tun2: connected to %s (%v)", conn.RemoteAddr(), c.cfg.ConnType)

	session, err := smux.Client(conn, smux.DefaultConfig())
	if err != nil {
		return err
	}
	defer session.Close()

	// The server opens the first stream; it carries the auth handshake.
	controlStream, err := session.AcceptStream()
	if err != nil {
		return err
	}

	authData, err := json.Marshal(&Auth{
		Token:  c.cfg.Token,
		Domain: c.cfg.Domain,
	})
	if err != nil {
		return err
	}

	if _, err := controlStream.Write(authData); err != nil {
		return err
	}

	log.Println("tun2: client set up and waiting for requests")

	s := &http.Server{
		Handler: httputil.NewSingleHostReverseProxy(target),
	}
	return s.Serve(&smuxListener{
		conn:    conn,
		session: session,
	})
}

// dialServer opens a TLS connection to serverAddr over the transport named by
// ClientConfig.ConnType ("tcp" or "kcp"). Any other ConnType is an error.
func (c *Client) dialServer(serverAddr string) (net.Conn, error) {
	switch c.cfg.ConnType {
	case "tcp":
		tlsConn, err := tls.Dial("tcp", serverAddr, c.cfg.TLSConfig)
		if err != nil {
			// Return a literal nil so callers never see a typed-nil net.Conn.
			return nil, err
		}
		return tlsConn, nil

	case "kcp":
		// Validate the address before dialing so a bad address cannot leak a KCP conn.
		host, _, err := net.SplitHostPort(serverAddr)
		if err != nil {
			return nil, fmt.Errorf("tun2: invalid server address %q: %w", serverAddr, err)
		}

		kc, err := kcp.Dial(serverAddr)
		if err != nil {
			return nil, err
		}

		// Clone returns nil for a nil receiver; fall back to an empty config
		// rather than dereferencing nil below.
		tc := c.cfg.TLSConfig.Clone()
		if tc == nil {
			tc = &tls.Config{}
		}
		// Mirror tls.Dial: derive ServerName from the address only when the
		// caller did not set one explicitly.
		if tc.ServerName == "" {
			tc.ServerName = host
		}
		return tls.Client(kc, tc), nil

	default:
		return nil, fmt.Errorf("tun2: unsupported connection type %q", c.cfg.ConnType)
	}
}

// smuxListener adapts an smux session to net.Listener so that an http.Server
// can serve each accepted smux stream as an individual connection.
type smuxListener struct {
	conn    net.Conn      // connection the session runs over; supplies Addr
	session *smux.Session // session whose accepted streams are handed out by Accept
}

// Compile-time assertion that *smuxListener satisfies net.Listener.
var _ net.Listener = (*smuxListener)(nil)

// Accept blocks until the next stream is accepted on the smux session and
// returns it as a net.Conn. On failure it returns a literal nil net.Conn
// together with the error, rather than a non-nil interface wrapping a nil
// *smux.Stream.
func (sl *smuxListener) Accept() (net.Conn, error) {
	stream, err := sl.session.AcceptStream()
	if err != nil {
		return nil, err
	}
	return stream, nil
}

// Addr reports the local address of the connection the smux session runs over.
func (sl *smuxListener) Addr() net.Addr {
	underlying := sl.conn
	return underlying.LocalAddr()
}

// Close closes the smux session backing the listener and returns its error.
func (sl *smuxListener) Close() error {
	sess := sl.session
	return sess.Close()
}