diff --git a/lib/tun2/client.go b/lib/tun2/client.go new file mode 100644 index 0000000..54d110c --- /dev/null +++ b/lib/tun2/client.go @@ -0,0 +1,126 @@ +package tun2 + +import ( + "crypto/tls" + "encoding/json" + "errors" + "net" + "net/http" + "net/http/httputil" + "net/url" + + kcp "github.com/xtaci/kcp-go" + "github.com/xtaci/smux" +) + +type Client struct { + cfg *ClientConfig +} + +type ClientConfig struct { + TLSConfig *tls.Config + ConnType string + ServerAddr string + Token string + Domain string + BackendURL string +} + +func NewClient(cfg *ClientConfig) (*Client, error) { + if cfg == nil { + return nil, errors.New("tun2: client config needed") + } + + c := &Client{ + cfg: cfg, + } + + return c, nil +} + +func (c *Client) connect(serverAddr string) error { + target, err := url.Parse(c.cfg.BackendURL) + if err != nil { + return err + } + + s := &http.Server{ + Handler: httputil.NewSingleHostReverseProxy(target), + } + + var conn net.Conn + + switch c.cfg.ConnType { + case "tcp": + conn, err = tls.Dial("tcp", serverAddr, c.cfg.TLSConfig) + if err != nil { + return err + } + + case "kcp": + kc, err := kcp.Dial(serverAddr) + if err != nil { + return err + } + defer kc.Close() + + serverHost, _, _ := net.SplitHostPort(serverAddr) + + tc := c.cfg.TLSConfig.Clone() + tc.ServerName = serverHost + conn = tls.Client(kc, tc) + } + defer conn.Close() + + session, err := smux.Client(conn, smux.DefaultConfig()) + if err != nil { + return err + } + defer session.Close() + + controlStream, err := session.AcceptStream() + if err != nil { + return err + } + + authData, err := json.Marshal(&Auth{ + Token: c.cfg.Token, + Domain: c.cfg.Domain, + }) + if err != nil { + return err + } + + _, err = controlStream.Write(authData) + if err != nil { + return err + } + + err = s.Serve(&smuxListener{ + conn: conn, + session: session, + }) + + if err != nil { + return err + } + + return nil +} + +type smuxListener struct { + conn net.Conn + session *smux.Session +} + +func (sl *smuxListener) Accept() (net.Conn, error) { + return sl.session.AcceptStream() +} + +func (sl *smuxListener) Addr() net.Addr { + return sl.conn.LocalAddr() +} + +func (sl *smuxListener) Close() error { + return sl.session.Close() +} diff --git a/lib/tun2/server.go b/lib/tun2/server.go new file mode 100644 index 0000000..29d568e --- /dev/null +++ b/lib/tun2/server.go @@ -0,0 +1,271 @@ +package tun2 + +import ( + "bufio" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "net/http" + "sync" + + "git.xeserv.us/xena/route/database" + "github.com/Xe/ln" + "github.com/mtneug/pkg/ulid" + kcp "github.com/xtaci/kcp-go" + "github.com/xtaci/smux" +) + +type ServerConfig struct { + TCPAddr string + KCPAddr string + TLSConfig *tls.Config + + SmuxConf *smux.Config + Storage Storage +} + +type Storage interface { + GetRouteForHost(name string) (*database.Route, error) + //ValidateToken(token string) (username string, ok bool, err error) // XXX RIP implement when users are implemented +} + +type Server struct { + cfg *ServerConfig + + connlock sync.Mutex + conns map[net.Conn]*Connection + + domainlock sync.Mutex + domains map[string][]*Connection +} + +type Connection struct { + id string + conn net.Conn + isKCP bool + session *smux.Session + controlStream *smux.Stream + user string + domain string +} + +func NewServer(cfg *ServerConfig) (*Server, error) { + if cfg == nil { + return nil, errors.New("tun2: config must be specified") + } + + if cfg.SmuxConf == nil { + cfg.SmuxConf = smux.DefaultConfig() + } + + server := &Server{ + cfg: cfg, + + conns: map[net.Conn]*Connection{}, + domains: map[string][]*Connection{}, + } + + return server, nil +} + +func (s *Server) ListenAndServe() error { + if s.cfg.TCPAddr != "" { + go func() { + l, err := tls.Listen("tcp", s.cfg.TCPAddr, s.cfg.TLSConfig) + if err != nil { + panic(err) + } + + for { + conn, err := l.Accept() + if err != nil { + ln.Error(err, ln.F{"kind": "tcp", "addr": l.Addr().String()}) + continue + } + + go s.HandleConn(conn, false) + } + }() + } + + if s.cfg.KCPAddr != "" { + go func() { + l, err := kcp.Listen(s.cfg.KCPAddr) + if err != nil { + panic(err) + } + + for { + conn, err := l.Accept() + if err != nil { + ln.Error(err, ln.F{"kind": "kcp", "addr": l.Addr().String()}) + } + + tc := tls.Server(conn, s.cfg.TLSConfig) + + go s.HandleConn(tc, true) + } + }() + } + + return nil +} + +func (s *Server) HandleConn(c net.Conn, isKCP bool) { + session, err := smux.Server(c, s.cfg.SmuxConf) + if err != nil { + ln.Error(err, ln.F{ + "action": "session_failure", + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + }) + + c.Close() + + return + } + + controlStream, err := session.OpenStream() + if err != nil { + ln.Error(err, ln.F{ + "action": "control_stream_failure", + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + }) + + session.Close() + c.Close() + + return + } + + csd := json.NewDecoder(controlStream) + auth := &Auth{} + err = csd.Decode(auth) + if err != nil { + ln.Error(err, ln.F{ + "action": "control_stream_auth_decoding_failure", + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + }) + + controlStream.Close() + session.Close() + c.Close() + + return + } + + route, err := s.cfg.Storage.GetRouteForHost(auth.Domain) + if err != nil { + ln.Error(err, ln.F{ + "action": "nosuch_domain", + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + }) + + controlStream.Close() + session.Close() + c.Close() + + return + } + + if route.Token != auth.Token { + ln.Error(err, ln.F{ + "action": "bad_token", + "local": c.LocalAddr().String(), + "remote": c.RemoteAddr().String(), + }) + + fmt.Fprintln(controlStream, "bad token") + + controlStream.Close() + session.Close() + c.Close() + + return + } + + connection := &Connection{ + id: ulid.New().String(), + conn: c, + isKCP: isKCP, + session: session, + user: defaultUser, // XXX RIP replace this with the actual token user once users are implemented + domain: auth.Domain, + } + + s.connlock.Lock() + s.conns[c] = connection + s.connlock.Unlock() + + s.domainlock.Lock() + s.domains[auth.Domain] = append(s.domains[auth.Domain], connection) + s.domainlock.Unlock() +} + +func (s *Server) RoundTrip(req *http.Request) (*http.Response, error) { + s.domainlock.Lock() + conns, ok := s.domains[req.Host] + s.domainlock.Unlock() + if !ok { + return nil, errors.New("domain not found") + } + + c := conns[rand.Intn(len(conns))] + + stream, err := c.session.OpenStream() + if err != nil { + ln.Error(err, ln.F{ + "action": "opening_session_stream", + "backend": c.conn.RemoteAddr().String(), + "remote_addr": req.RemoteAddr, + "host": req.Host, + "uri": req.RequestURI, + }) + + return nil, err + } + defer stream.Close() + + err = req.Write(stream) + if err != nil { + ln.Error(err, ln.F{ + "action": "request_writing", + "backend": c.conn.RemoteAddr().String(), + "remote_addr": req.RemoteAddr, + "host": req.Host, + "uri": req.RequestURI, + }) + + return nil, err + } + + buf := bufio.NewReader(stream) + + resp, err := http.ReadResponse(buf, req) + if err != nil { + ln.Error(err, ln.F{ + "action": "response_reading", + "backend": c.conn.RemoteAddr().String(), + "remote_addr": req.RemoteAddr, + "host": req.Host, + "uri": req.RequestURI, + }) + + return nil, err + } + + return resp, nil +} + +type Auth struct { + Token string `json:"token"` + Domain string `json:"domain"` +} + +const defaultUser = "Cadey"