diff --git a/cmd/routed/main.go b/cmd/routed/main.go index 1f96c5c..30327a4 100644 --- a/cmd/routed/main.go +++ b/cmd/routed/main.go @@ -7,6 +7,9 @@ import ( "math/rand" "net" "net/http" + "os" + "os/signal" + "sync" "time" _ "git.xeserv.us/xena/route/internal" @@ -27,13 +30,17 @@ func main() { flag.Parse() flagenv.Parse() rand.Seed(time.Now().Unix()) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - certKey, _ := routecrypto.ParseKey(*sslCertKey) + certKey, err := routecrypto.ParseKey(*sslCertKey) + if err != nil { + ln.FatalErr(ctx, err, ln.Action("parse cert key")) + } scfg := Config{} - err := env.Parse(&scfg) + err = env.Parse(&scfg) if err != nil { ln.FatalErr(ctx, err, ln.Action("parsing environment for config")) } @@ -44,9 +51,27 @@ func main() { ln.FatalErr(ctx, err, ln.Action("create server instance")) } - go setupQuic(s, scfg) - go setupTLS(s, scfg) + wg := &sync.WaitGroup{} + go setupQuic(ctx, wg, s, scfg) + go setupTLS(ctx, wg, s, scfg) + go setupHTTP(ctx, wg, s, scfg) + + ch := make(chan os.Signal, 2) + + go func() { + val := <-ch + + ln.Log(ctx, ln.F{"signal": val.String()}, ln.Action("signal recieved")) + cancel() + }() + + signal.Notify(ch, os.Interrupt) + + wg.Wait() +} + +func setupHTTP(ctx context.Context, wg *sync.WaitGroup, s *Server, scfg Config) { // listen on HTTP listener l, err := net.Listen("tcp", scfg.WebAddr) if err != nil { @@ -59,10 +84,24 @@ func main() { Addr: scfg.WebAddr, } - hs.Serve(l) + go ln.FatalErr(ctx, hs.Serve(l)) + + for { + select { + case <-ctx.Done(): + hs.SetKeepAlivesEnabled(false) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + hs.Shutdown(ctx) + wg.Done() + + return + } + } } -func setupQuic(s *Server, scfg Config) { +func setupQuic(ctx context.Context, wg *sync.WaitGroup, s *Server, scfg Config) { qs := &h2quic.Server{ Server: &http.Server{ Handler: middleware.Trace(s), @@ -78,12 +117,25 @@ func setupQuic(s *Server, scfg Config) { s.QuicServer = qs + go ln.FatalErr(context.Background(), qs.ListenAndServe()) + wg.Add(1) + for { - ln.FatalErr(context.Background(), qs.ListenAndServe()) + select { + case <-ctx.Done(): + qs.SetKeepAlivesEnabled(false) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + qs.Shutdown(ctx) + wg.Done() + + return + } } } -func setupTLS(s *Server, scfg Config) { +func setupTLS(ctx context.Context, wg *sync.WaitGroup, s *Server, scfg Config) { hs := &http.Server{ Handler: middleware.Trace(s), Addr: scfg.SSLAddr, @@ -95,7 +147,20 @@ func setupTLS(s *Server, scfg Config) { ReadHeaderTimeout: time.Second, } + go ln.FatalErr(context.Background(), hs.ListenAndServeTLS("", "")) + wg.Add(1) + for { - ln.FatalErr(context.Background(), hs.ListenAndServeTLS("", "")) + select { + case <-ctx.Done(): + hs.SetKeepAlivesEnabled(false) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + hs.Shutdown(ctx) + wg.Done() + + return + } } }