package main import ( "bytes" "context" "crypto/tls" "fmt" "io" "io/ioutil" "log" "math/rand" "net" "net/http" "os" "sync" "time" "github.com/hashicorp/yamux" "within.website/ln" "within.website/x/localca" ) // Config uration for the server. type Config struct { HTTPPort, HTTPSPort, YamuxPort, StatusPort, DomainSuffix string } // F ields for logging. func (c Config) F() ln.F { return ln.F{ "http-port": c.HTTPPort, "https-port": c.HTTPSPort, "yamux-port": c.YamuxPort, "status-port": c.StatusPort, "domain-suffix": c.DomainSuffix, } } // Server is the iconia gateway server. type Server struct { Config clients map[string][]*yamux.Session clientsLock *sync.RWMutex certManager localca.Manager plainServer, statusServer *http.Server tlsListener, yamuxListener net.Listener tokenInfo map[string]string tokensLock *sync.Mutex } func (s *Server) goAwayClients() { s.clientsLock.Lock() defer s.clientsLock.Unlock() for _, set := range s.clients { for _, sesh := range set { sesh.GoAway() } } } func (s *Server) killClients() { s.clientsLock.RLock() defer s.clientsLock.RUnlock() for _, set := range s.clients { for _, sesh := range set { _ = sesh.Close() } } } func (s *Server) handleYamuxClientHello(chi *tls.ClientHelloInfo) (*tls.Config, error) { var found bool s.tokensLock.Lock() var token = s.tokenInfo[chi.ServerName] s.tokensLock.Unlock() for _, proto := range chi.SupportedProtos { if proto == token { found = true break } } if !found { return nil, fmt.Errorf("unknown token for domain %s", chi.ServerName) } tc := &tls.Config{ GetCertificate: s.certManager.GetCertificate, NextProtos: []string{token}, ServerName: chi.ServerName, } return tc, nil } func gen502Page(host, why string) *http.Response { template := `${WHY}

${WHY}

Please ensure a backend is running for ${HOST}.

` resbody := []byte(os.Expand(template, func(in string) string { switch in { case "HOST": return host case "WHY": return why } return "" })) reshdr := http.Header{} reshdr.Set("Content-Type", "text/html; charset=utf-8") resp := &http.Response{ Status: fmt.Sprintf("%d Bad Gateway", http.StatusBadGateway), StatusCode: http.StatusBadGateway, Body: ioutil.NopCloser(bytes.NewBuffer(resbody)), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: reshdr, ContentLength: int64(len(resbody)), Close: true, Request: nil, } return resp } func (s *Server) yamuxHandler(l net.Listener) error { for { c, err := l.Accept() if err != nil { return err } go s.handleYamuxClient(c) } panic("unexpected state") } func (s *Server) handleYamuxClient(c net.Conn) { tlsConn, ok := c.(*tls.Conn) if !ok { panic("no, this should really be impossible") } tlsConn.Handshake() sName := tlsConn.ConnectionState().ServerName ctx := context.Background() ctx = ln.WithF(ctx, ln.F{ "domain-name": sName, "remote-host": c.RemoteAddr().String(), }) sesh, err := yamux.Server(c, &yamux.Config{ AcceptBacklog: 1, EnableKeepAlive: true, KeepAliveInterval: time.Minute, ConnectionWriteTimeout: 100 * time.Millisecond, MaxStreamWindowSize: 262144 * 16, Logger: log.New(os.Stderr, sName+": ", log.LstdFlags), }) if err != nil { ln.Error(ctx, err) c.Close() return } s.clientsLock.Lock() s.clients[sName] = append(s.clients[sName], sesh) i := len(s.clients[sName]) - 1 s.clientsLock.Unlock() ln.Log(ctx, ln.Info("agent registered")) go func() { <-sesh.CloseChan() ln.Log(ctx, ln.Info("client closed")) s.clientsLock.Lock() s.clients[sName] = append(s.clients[sName][:i], s.clients[sName][i+1:]...) s.clientsLock.Unlock() }() } func (s *Server) tlsForward(l net.Listener) error { for { c, err := l.Accept() if err != nil { return fmt.Errorf("error accepting connection: %w", err) } go s.handleTLSClient(c) } panic("unexpected state") } func (s *Server) handleTLSClient(c net.Conn) { tlsConn, ok := c.(*tls.Conn) if !ok { gen502Page("unknown", "this should be impossible").Write(c) c.Close() return } tlsConn.Handshake() sName := tlsConn.ConnectionState().ServerName s.clientsLock.RLock() set, ok := s.clients[sName] s.clientsLock.RUnlock() if !ok || len(set) == 0 { gen502Page(sName, "no backends connected").Write(c) c.Close() return } var ( sesh *yamux.Session stream *yamux.Stream count int err error ) retry: sesh = set[rand.Intn(len(set))] stream, err = sesh.OpenStream() if err != nil { if count > 3 { gen502Page(sName, "no working session").Write(c) c.Close() return } count++ goto retry } go io.Copy(c, stream) io.Copy(stream, c) }