diff --git a/main.go b/main.go index 9584470..b5cc1b3 100644 --- a/main.go +++ b/main.go @@ -10,14 +10,15 @@ import ( "strings" "time" + "github.com/turnage/graw" "github.com/turnage/graw/reddit" - "github.com/turnage/graw/streams" ) // UA is the user agent for Reddit const UA = `NixOS:tulpa.dev/cadey/snoo2nebby:v0.1.0 (by /u/shadowh511)` var ( + agentFile = flag.String("agent-file", "./var/agent.yml", "the path to the bot's reddit agent config") webhookFile = flag.String("webhook-file", "./var/webhook.txt", "where the Discord webhook file is located") subreddit = flag.String("subreddit", "tulpas", "the subreddit to monitor") pokeFreq = flag.Duration("poke-frequency", 5*time.Minute, "how often the bot should poke the feed") @@ -36,13 +37,49 @@ func clampLen(data string) string { return sb.String() } +type postReplicatingBot struct { + whURL string + bot reddit.Bot +} + +func (p *postReplicatingBot) Post(post *reddit.Post) error { + log.Printf("got new post: by /u/%s: %q %s, NSFW: %v", post.Author, post.URL, post.Title, post.NSFW) + + if post.NSFW && !*allowNSFW { + return nil + } + + wh := Webhook{ + Embeds: []Embed{ + { + Title: post.Title, + URL: post.URL, + Description: clampLen(post.SelfText), + Footer: EmbedFooter{ + Text: "by /u/" + post.Author, + }, + }, + }, + } + + req := Send(p.whURL, wh) + req.Header.Set("User-Agent", UA) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("can't send webhook: %w", err) + } + err = Validate(resp) + if err != nil { + return fmt.Errorf("can't validate response: %w", err) + } + + return nil +} + func main() { flag.Parse() - script, err := reddit.NewScript(UA, *pokeFreq) - if err != nil { - log.Fatal(err) - } + script, err := reddit.NewBotFromAgentFile(*agentFile, 5*time.Minute) whSlc, err := ioutil.ReadFile(*webhookFile) if err != nil { @@ -50,53 +87,11 @@ func main() { } whURL := string(bytes.TrimSpace(whSlc)) - kill := make(chan bool) - errs := make(chan error) - - stream, err := streams.Subreddits(script, kill, errs, *subreddit) - if err != nil { - log.Fatal(err) - } - - go func(errs chan error) { - for err := range errs { - log.Printf("%v", err) - } - }(errs) - log.Printf("listening for new posts on /r/%s", *subreddit) - for post := range stream { - log.Printf("got new post: by /u/%s: %q %s, NSFW: %v", post.Author, post.URL, post.Title, post.NSFW) - - if post.NSFW && !*allowNSFW { - continue - } - - wh := Webhook{ - Embeds: []Embed{ - { - Title: post.Title, - URL: post.URL, - Description: clampLen(post.SelfText), - Footer: EmbedFooter{ - Text: "by /u/" + post.Author, - }, - }, - }, - } - - req := Send(whURL, wh) - req.Header.Set("User-Agent", UA) - resp, err := http.DefaultClient.Do(req) - if err != nil { - errs <- fmt.Errorf("can't send webhook: %w", err) - continue - } - err = Validate(resp) - if err != nil { - errs <- fmt.Errorf("can't validate response: %w", err) - continue - } + if _, wait, err := graw.Run(&postReplicatingBot{whURL, script}, script, graw.Config{Subreddits: []string{*subreddit}}); err != nil { + log.Fatalf("can't connect to reddit: %v", err) + } else { + log.Fatalf("graw run failed: %v", wait()) } }