diff --git a/migrations/csrfToken/main.go b/migrations/csrfToken/main.go new file mode 100644 index 0000000..fcd49f2 --- /dev/null +++ b/migrations/csrfToken/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "log" + "math/rand" + "os" + "path/filepath" + "time" + + "bloat/config" + "bloat/kv" + "bloat/repository" + "bloat/util" +) + +var ( + configFile = "bloat.conf" +) + +func init() { + rand.Seed(time.Now().Unix()) +} + +func getKeys(sessionRepoPath string) (keys []string, err error) { + f, err := os.Open(sessionRepoPath) + if err != nil { + return + } + return f.Readdirnames(0) +} + +func main() { + opts, _, err := util.Getopts(os.Args, "f:") + if err != nil { + log.Fatal(err) + } + + for _, opt := range opts { + switch opt.Option { + case 'f': + configFile = opt.Value + } + } + + config, err := config.ParseFile(configFile) + if err != nil { + log.Fatal(err) + } + + if !config.IsValid() { + log.Fatal("invalid config") + } + + sessionRepoPath := filepath.Join(config.DatabasePath, "session") + sessionDB, err := kv.NewDatabse(sessionRepoPath) + if err != nil { + log.Fatal(err) + } + + sessionRepo := repository.NewSessionRepository(sessionDB) + + sessionIds, err := getKeys(sessionRepoPath) + if err != nil { + log.Fatal(err) + } + + for _, id := range sessionIds { + s, err := sessionRepo.Get(id) + if err != nil { + log.Fatal(err) + } + s.CSRFToken = util.NewCSRFToken() + err = sessionRepo.Add(s) + if err != nil { + log.Fatal(err) + } + } + +} diff --git a/model/session.go b/model/session.go index fce6173..6bc8a63 100644 --- a/model/session.go +++ b/model/session.go @@ -12,6 +12,7 @@ type Session struct { ID string `json:"id"` InstanceDomain string `json:"instance_domain"` AccessToken string `json:"access_token"` + CSRFToken string `json:"csrf_token"` Settings Settings `json:"settings"` } diff --git a/renderer/model.go b/renderer/model.go index cc0a6ce..25fa0c6 100644 --- a/renderer/model.go +++ b/renderer/model.go @@ -10,12 +10,14 @@ type Context struct { FluorideMode bool ThreadInNewTab bool DarkMode bool + CSRFToken string } type HeaderData struct { Title string NotificationCount int CustomCSS string + CSRFToken string } type NavbarData struct { diff --git a/service/auth.go b/service/auth.go index e517383..909a9a2 100644 --- a/service/auth.go +++ b/service/auth.go @@ -11,7 +11,8 @@ import ( ) var ( - ErrInvalidSession = errors.New("invalid session") + ErrInvalidSession = errors.New("invalid session") + ErrInvalidCSRFToken = errors.New("invalid csrf token") ) type authService struct { @@ -47,6 +48,14 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error return c, nil } +func checkCSRF(ctx context.Context, c *model.Client) (err error) { + csrfToken, ok := ctx.Value("csrf_token").(string) + if !ok || csrfToken != c.Session.CSRFToken { + return ErrInvalidCSRFToken + } + return nil +} + func (s *authService) GetAuthUrl(ctx context.Context, instance string) ( redirectUrl string, sessionID string, err error) { return s.Service.GetAuthUrl(ctx, instance) @@ -184,6 +193,10 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.SaveSettings(ctx, client, c, settings) } @@ -192,6 +205,10 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Like(ctx, client, c, id) } @@ -200,6 +217,10 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnLike(ctx, client, c, id) } @@ -208,6 +229,10 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Retweet(ctx, client, c, id) } @@ -216,6 +241,10 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnRetweet(ctx, client, c, id) } @@ -224,6 +253,10 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) } @@ -232,6 +265,10 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Follow(ctx, client, c, id) } @@ -240,5 +277,9 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnFollow(ctx, client, c, id) } diff --git a/service/service.go b/service/service.go index bfacf80..db851f7 100644 --- a/service/service.go +++ b/service/service.go @@ -78,12 +78,21 @@ func NewService(clientName string, clientScope string, clientWebsite string, } } -func getRendererContext(s model.Settings) *renderer.Context { +func getRendererContext(c *model.Client) *renderer.Context { + var settings model.Settings + var session model.Session + if c != nil { + settings = c.Session.Settings + session = c.Session + } else { + settings = *model.NewSettings() + } return &renderer.Context{ - MaskNSFW: s.MaskNSFW, - ThreadInNewTab: s.ThreadInNewTab, - FluorideMode: s.FluorideMode, - DarkMode: s.DarkMode, + MaskNSFW: settings.MaskNSFW, + ThreadInNewTab: settings.ThreadInNewTab, + FluorideMode: settings.FluorideMode, + DarkMode: settings.DarkMode, + CSRFToken: session.CSRFToken, } } @@ -98,9 +107,11 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( } sessionID = util.NewSessionId() + csrfToken := util.NewCSRFToken() session := model.Session{ ID: sessionID, InstanceDomain: instance, + CSRFToken: csrfToken, Settings: *model.NewSettings(), } err = svc.sessionRepo.Add(session) @@ -199,13 +210,6 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model if err != nil { return } - /* - err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback") - if err != nil { - return - } - err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx)) - */ return res.AccessToken, nil } @@ -226,13 +230,7 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod Error: errStr, } - var s model.Settings - if c != nil { - s = c.Session.Settings - } else { - s = *model.NewSettings() - } - rCtx := getRendererContext(s) + rCtx := getRendererContext(c) svc.renderer.RenderErrorPage(rCtx, client, data) } @@ -247,7 +245,7 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err CommonData: commonData, } - rCtx := getRendererContext(*model.NewSettings()) + rCtx := getRendererContext(nil) return svc.renderer.RenderSigninPage(rCtx, client, data) } @@ -334,7 +332,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, PostContext: postContext, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderTimelinePage(rCtx, client, data) if err != nil { @@ -416,7 +414,7 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo ReplyMap: replyMap, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderThreadPage(rCtx, client, data) if err != nil { @@ -478,7 +476,7 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, NextLink: nextLink, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderNotificationPage(rCtx, client, data) if err != nil { @@ -525,7 +523,7 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode NextLink: nextLink, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderUserPage(rCtx, client, data) if err != nil { @@ -544,7 +542,7 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod data := &renderer.AboutData{ CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderAboutPage(rCtx, client, data) if err != nil { @@ -569,7 +567,7 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod Emojis: emojis, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderEmojiPage(rCtx, client, data) if err != nil { @@ -594,7 +592,7 @@ func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *m CommonData: commonData, Users: likers, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderLikedByPage(rCtx, client, data) if err != nil { @@ -619,7 +617,7 @@ func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, CommonData: commonData, Users: retweeters, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderRetweetedByPage(rCtx, client, data) if err != nil { @@ -660,7 +658,7 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderFollowingPage(rCtx, client, data) if err != nil { @@ -701,7 +699,7 @@ func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderFollowersPage(rCtx, client, data) if err != nil { @@ -750,7 +748,7 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderSearchPage(rCtx, client, data) if err != nil { @@ -770,7 +768,7 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c * CommonData: commonData, Settings: &c.Session.Settings, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderSettingsPage(rCtx, client, data) if err != nil { @@ -828,6 +826,7 @@ func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *mode } data.HeaderData.NotificationCount = notificationCount + data.HeaderData.CSRFToken = c.Session.CSRFToken } return diff --git a/service/transport.go b/service/transport.go index 8cca4f5..e878f8d 100644 --- a/service/transport.go +++ b/service/transport.go @@ -160,6 +160,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -179,6 +181,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -198,6 +202,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -217,6 +223,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -236,6 +244,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.Like(ctx, w, nil, id) if err != nil { @@ -252,6 +262,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.UnLike(ctx, w, nil, id) if err != nil { @@ -268,6 +280,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.Retweet(ctx, w, nil, id) if err != nil { @@ -284,6 +298,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.UnRetweet(ctx, w, nil, id) if err != nil { @@ -299,14 +315,16 @@ func NewHandler(s Service, staticDir string) http.Handler { }).Methods(http.MethodPost) r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - err := req.ParseMultipartForm(4 << 20) if err != nil { s.ServeErrorPage(ctx, w, nil, err) return } + ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", + getMultipartFormValue(req.MultipartForm, "csrf_token")) + content := getMultipartFormValue(req.MultipartForm, "content") replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") format := getMultipartFormValue(req.MultipartForm, "format") @@ -358,6 +376,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] @@ -373,6 +392,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] @@ -442,6 +462,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) visibility := req.FormValue("visibility") copyScope := req.FormValue("copy_scope") == "true" diff --git a/static/fluoride.js b/static/fluoride.js index 6a1b5fb..3c0d7f2 100644 --- a/static/fluoride.js +++ b/static/fluoride.js @@ -16,7 +16,14 @@ var reverseActions = { "unretweet": "retweet" }; -function http(method, url, success, error) { +function getCSRFToken() { + var tag = document.querySelector("meta[name='csrf_token']") + if (tag) + return tag.getAttribute("content"); + return ""; +} + +function http(method, url, body, type, success, error) { var req = new XMLHttpRequest(); req.onload = function() { if (this.status === 200 && typeof success === "function") { @@ -31,14 +38,15 @@ function http(method, url, success, error) { } }; req.open(method, url); - req.send(); + req.setRequestHeader("Content-Type", type); + req.send(body); } function updateActionForm(id, f, action) { if (Array.from(document.body.classList).indexOf("dark") > -1) { - f.children[1].src = actionIcons["dark-" + action]; + f.querySelector(".icon").src = actionIcons["dark-" + action]; } else { - f.children[1].src = actionIcons[action]; + f.querySelector(".icon").src = actionIcons[action]; } f.action = "/" + action + "/" + id; f.dataset.action = action; @@ -54,7 +62,9 @@ function handleLikeForm(id, f) { updateActionForm(id, f, reverseActions[action]); }); - http("POST", "/fluoride/" + action + "/" + id, function(res, type) { + var body = "csrf_token=" + encodeURIComponent(getCSRFToken()); + var contentType = "application/x-www-form-urlencoded"; + http("POST", "/fluoride/" + action + "/" + id, body, contentType, function(res, type) { var data = JSON.parse(res); var count = data.data; if (count === 0) { @@ -82,7 +92,9 @@ function handleRetweetForm(id, f) { updateActionForm(id, f, reverseActions[action]); }); - http("POST", "/fluoride/" + action + "/" + id, function(res, type) { + var body = "csrf_token=" + encodeURIComponent(getCSRFToken()); + var contentType = "application/x-www-form-urlencoded"; + http("POST", "/fluoride/" + action + "/" + id, body, contentType, function(res, type) { var data = JSON.parse(res); var count = data.data; if (count === 0) { diff --git a/templates/header.tmpl b/templates/header.tmpl index 571008a..e6e7f0d 100644 --- a/templates/header.tmpl +++ b/templates/header.tmpl @@ -4,6 +4,9 @@ + {{if .CSRFToken}} + + {{end}} {{if gt .NotificationCount 0}}({{.NotificationCount}}) {{end}}{{.Title}} {{if .CustomCSS}} diff --git a/templates/postform.tmpl b/templates/postform.tmpl index 0b83d2c..ff2dfd9 100644 --- a/templates/postform.tmpl +++ b/templates/postform.tmpl @@ -1,5 +1,6 @@ {{with .Data}}
+ {{if .ReplyContext}} diff --git a/templates/settings.tmpl b/templates/settings.tmpl index a32a1b0..e7d49e9 100644 --- a/templates/settings.tmpl +++ b/templates/settings.tmpl @@ -4,6 +4,7 @@
Settings
+
+ + {{else}}
- + +
{{end}} @@ -126,12 +128,14 @@
{{if .Favourited}}
- + +
{{else}}
- + +
{{end}} diff --git a/templates/user.tmpl b/templates/user.tmpl index bbbce32..abf22ec 100644 --- a/templates/user.tmpl +++ b/templates/user.tmpl @@ -22,17 +22,20 @@ {{if .User.Pleroma.Relationship.FollowedBy}} follows you - {{end}} {{if .User.Pleroma.Relationship.Following}}
- + +
{{end}} {{if .User.Pleroma.Relationship.Requested}}
- + +
{{end}} {{if not .User.Pleroma.Relationship.Following}}
- + +
{{end}}
diff --git a/util/rand.go b/util/rand.go index 8502521..212d6d3 100644 --- a/util/rand.go +++ b/util/rand.go @@ -20,3 +20,7 @@ func NewRandId(n int) string { func NewSessionId() string { return NewRandId(24) } + +func NewCSRFToken() string { + return NewRandId(24) +}