iconia/server.go

243 lines
4.6 KiB
Go

package main
import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"html"
	"io"
	"io/ioutil"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/hashicorp/yamux"
	"within.website/ln"
	"within.website/x/localca"
)
// Config holds the listener ports and the domain suffix the server is
// configured with. Ports are kept as strings exactly as supplied.
type Config struct {
	HTTPPort     string
	HTTPSPort    string
	YamuxPort    string
	StatusPort   string
	DomainSuffix string
}
// F returns the configuration as structured logging fields, one entry per
// Config field, keyed by a kebab-case name.
func (c Config) F() ln.F {
	fields := make(ln.F, 5)
	fields["http-port"] = c.HTTPPort
	fields["https-port"] = c.HTTPSPort
	fields["yamux-port"] = c.YamuxPort
	fields["status-port"] = c.StatusPort
	fields["domain-suffix"] = c.DomainSuffix
	return fields
}
// Server is the iconia gateway server. Agents register yamux sessions keyed
// by the TLS server name they connect with, and incoming TLS connections are
// forwarded over those sessions by their own TLS server name.
type Server struct {
	Config

	// clients maps a TLS server name to the agent sessions registered for it.
	// Accessed under clientsLock.
	clients     map[string][]*yamux.Session
	clientsLock *sync.RWMutex

	// certManager supplies certificates through its GetCertificate method.
	certManager localca.Manager

	plainServer  *http.Server
	statusServer *http.Server

	tlsListener   net.Listener
	yamuxListener net.Listener

	// tokenInfo maps a TLS server name to the token an agent must offer as
	// an ALPN protocol to register for it. Accessed under tokensLock.
	tokenInfo  map[string]string
	tokensLock *sync.Mutex
}
// goAwayClients sends a yamux GoAway on every registered agent session while
// holding the clients write lock. Errors returned by GoAway are discarded.
func (s *Server) goAwayClients() {
	s.clientsLock.Lock()
	defer s.clientsLock.Unlock()

	for domain := range s.clients {
		sessions := s.clients[domain]
		for i := range sessions {
			_ = sessions[i].GoAway()
		}
	}
}
// handleYamuxClientHello picks the TLS configuration for an agent
// connection. The agent must list the token registered for its requested
// server name among its ALPN protocols; otherwise an error is returned and
// the handshake is refused. On success the returned config negotiates exactly
// that token as the ALPN protocol and serves certificates from certManager.
func (s *Server) handleYamuxClientHello(chi *tls.ClientHelloInfo) (*tls.Config, error) {
	domain := chi.ServerName

	s.tokensLock.Lock()
	want := s.tokenInfo[domain]
	s.tokensLock.Unlock()

	matched := false
	for i := 0; i < len(chi.SupportedProtos) && !matched; i++ {
		matched = chi.SupportedProtos[i] == want
	}
	if !matched {
		return nil, fmt.Errorf("unknown token for domain %s", domain)
	}

	return &tls.Config{
		GetCertificate: s.certManager.GetCertificate,
		NextProtos:     []string{want},
		ServerName:     domain,
	}, nil
}
// gen502Page builds a complete "502 Bad Gateway" HTTP/1.1 response carrying a
// small HTML page that names the failure reason (why) and the requested host.
// The response is marked Close and has no associated Request; callers write it
// straight onto a client connection with (*http.Response).Write.
//
// Both values are HTML-escaped before substitution: host comes from the
// client's TLS SNI extension, so it is attacker-controlled and must not be
// able to inject markup or script into the page.
func gen502Page(host, why string) *http.Response {
	const template = `<html><head><title>${WHY}</title></head><body><h1>${WHY}</h1><p>Please ensure a backend is running for ${HOST}.</p></body></html>`

	safeHost := html.EscapeString(host)
	safeWhy := html.EscapeString(why)

	resbody := []byte(os.Expand(template, func(name string) string {
		switch name {
		case "HOST":
			return safeHost
		case "WHY":
			return safeWhy
		}
		return "<unknown>"
	}))

	reshdr := http.Header{}
	reshdr.Set("Content-Type", "text/html; charset=utf-8")

	return &http.Response{
		Status:        fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)),
		StatusCode:    http.StatusBadGateway,
		Body:          ioutil.NopCloser(bytes.NewReader(resbody)),
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        reshdr,
		ContentLength: int64(len(resbody)),
		Close:         true,
		Request:       nil,
	}
}
// yamuxHandler accepts agent connections from l and handles each on its own
// goroutine via handleYamuxClient. It runs until Accept fails and then
// returns that error, wrapped with %w so callers can still inspect it with
// errors.Is / errors.As.
func (s *Server) yamuxHandler(l net.Listener) error {
	for {
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("error accepting yamux connection: %w", err)
		}
		go s.handleYamuxClient(c)
	}
}
// handleYamuxClient completes the TLS handshake on an accepted agent
// connection, wraps it in a yamux server session, and registers the session
// in s.clients under the TLS server name the agent requested. A background
// goroutine unregisters the session once it closes.
//
// c must be a *tls.Conn; anything else is treated as a programming error and
// panics. A connection whose handshake fails is logged and closed without
// being registered.
func (s *Server) handleYamuxClient(c net.Conn) {
	tlsConn, ok := c.(*tls.Conn)
	if !ok {
		panic("no, this should really be impossible")
	}

	remote := c.RemoteAddr().String()

	// The handshake is where the listener's TLS configuration vets the agent
	// (presumably via handleYamuxClientHello — confirm against the listener
	// setup). Registering a session after a failed handshake would expose an
	// unauthenticated, broken connection to handleTLSClient.
	if err := tlsConn.Handshake(); err != nil {
		ln.Error(context.Background(), err, ln.F{"remote-host": remote})
		c.Close()
		return
	}

	sName := tlsConn.ConnectionState().ServerName
	ctx := ln.WithF(context.Background(), ln.F{
		"domain-name": sName,
		"remote-host": remote,
	})

	sesh, err := yamux.Server(c, &yamux.Config{
		AcceptBacklog:          1,
		EnableKeepAlive:        true,
		KeepAliveInterval:      time.Minute,
		ConnectionWriteTimeout: 100 * time.Millisecond,
		MaxStreamWindowSize:    262144 * 16,
		Logger:                 log.New(os.Stderr, sName+": ", log.LstdFlags),
	})
	if err != nil {
		ln.Error(ctx, err)
		c.Close()
		return
	}

	s.clientsLock.Lock()
	s.clients[sName] = append(s.clients[sName], sesh)
	s.clientsLock.Unlock()
	ln.Log(ctx, ln.Info("agent registered"))

	go func() {
		<-sesh.CloseChan()
		ln.Log(ctx, ln.Info("client closed"))

		s.clientsLock.Lock()
		defer s.clientsLock.Unlock()

		// Remove this session by identity. An index recorded at registration
		// time goes stale as soon as any earlier session is removed; it would
		// then drop the wrong session or panic with slice-out-of-range.
		//
		// Build a fresh slice rather than compacting in place, so that any
		// slice already handed to a reader keeps its contents unchanged.
		current := s.clients[sName]
		remaining := make([]*yamux.Session, 0, len(current))
		for _, other := range current {
			if other != sesh {
				remaining = append(remaining, other)
			}
		}
		if len(remaining) == 0 {
			delete(s.clients, sName)
			return
		}
		s.clients[sName] = remaining
	}()
}
// tlsForward accepts client connections from l and hands each to
// handleTLSClient on its own goroutine. It runs until Accept fails and then
// returns that error wrapped with %w.
func (s *Server) tlsForward(l net.Listener) error {
	for {
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("error accepting connection: %w", err)
		}
		go s.handleTLSClient(c)
	}
}
// handleTLSClient routes one incoming TLS client connection to a backend
// agent. After the handshake it opens a yamux stream on a randomly chosen
// session registered for the client's TLS server name and copies bytes in
// both directions.
//
// If no session is registered, or no stream can be opened after a few
// attempts, a 502 page is written to the client instead. Both the connection
// and the stream are closed when proxying ends.
func (s *Server) handleTLSClient(c net.Conn) {
	tlsConn, ok := c.(*tls.Conn)
	if !ok {
		gen502Page("unknown", "this should be impossible").Write(c)
		c.Close()
		return
	}

	// A failed handshake leaves no server name to route on and no usable
	// channel to send an error page over, so just drop the connection.
	if err := tlsConn.Handshake(); err != nil {
		c.Close()
		return
	}
	sName := tlsConn.ConnectionState().ServerName

	// Copy the session list while the read lock is held. Other code may
	// modify the shared slice under the write lock once the lock is released.
	s.clientsLock.RLock()
	sessions := append([]*yamux.Session(nil), s.clients[sName]...)
	s.clientsLock.RUnlock()
	if len(sessions) == 0 {
		gen502Page(sName, "no backends connected").Write(c)
		c.Close()
		return
	}

	// Try up to five randomly chosen sessions, since a session may have died
	// without being unregistered yet.
	const maxAttempts = 5
	var stream *yamux.Stream
	for attempt := 0; attempt < maxAttempts && stream == nil; attempt++ {
		sesh := sessions[rand.Intn(len(sessions))]
		if st, err := sesh.OpenStream(); err == nil {
			stream = st
		}
	}
	if stream == nil {
		gen502Page(sName, "no working session").Write(c)
		c.Close()
		return
	}

	// Proxy both directions. When either copy finishes, tear down both sides.
	//  - If the backend finishes first, closing c unblocks the client->backend
	//    copy below.
	//  - If the client finishes first, closing the stream ends the
	//    backend->client copy.
	// This way neither the connection, the stream, nor the goroutine leaks.
	go func() {
		defer c.Close()
		io.Copy(c, stream)
	}()
	defer stream.Close()
	io.Copy(stream, c)
}