diff --git a/go.mod b/go.mod index de1ba89..0bebfe1 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.13 require ( github.com/gorilla/mux v1.7.3 - github.com/mattn/go-sqlite3 v2.0.1+incompatible + github.com/mattn/go-sqlite3 v2.0.2+incompatible // indirect mastodon v0.0.0-00010101000000-000000000000 ) diff --git a/go.mod.old b/go.mod.old new file mode 100644 index 0000000..e633126 --- /dev/null +++ b/go.mod.old @@ -0,0 +1,10 @@ +module web + +go 1.13 + +require ( + github.com/gorilla/mux v1.7.3 + mastodon v0.0.0-00010101000000-000000000000 +) + +replace mastodon => ./mastodon diff --git a/go.sum b/go.sum index 236732d..7a53570 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,7 @@ github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= -github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U= +github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= diff --git a/kv/kv.go b/kv/kv.go new file mode 100644 index 0000000..2cfcd60 --- /dev/null +++ b/kv/kv.go @@ -0,0 +1,92 @@ +package kv + +import ( + "errors" + "io/ioutil" + "os" + "path/filepath" + "strings" + "sync" +) + +var ( + errInvalidKey = errors.New("invalid key") + errNoSuchKey = errors.New("no such key") +) + +type Database struct { + data map[string][]byte + basedir string + m sync.RWMutex +} + +func NewDatabse(basedir string) (db *Database, err error) { + err = os.Mkdir(basedir, 0755) + if err != nil && !os.IsExist(err) { + return + } + + return &Database{ + data: make(map[string][]byte), + basedir: basedir, + }, nil +} + +func (db *Database) Set(key string, val []byte) (err error) { + if len(key) < 1 { + return errInvalidKey + } + + db.m.Lock() + defer func() { + if err != nil { + delete(db.data, key) + } + db.m.Unlock() + }() + + db.data[key] = val + + err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644) + + return +} + +func (db *Database) Get(key string) (val []byte, err error) { + if len(key) < 1 { + return nil, errInvalidKey + } + + db.m.RLock() + defer db.m.RUnlock() + + data, ok := db.data[key] + if !ok { + data, err = ioutil.ReadFile(filepath.Join(db.basedir, key)) + if err != nil { + err = errNoSuchKey + return nil, err + } + + db.data[key] = data + } + + val = make([]byte, len(data)) + copy(val, data) + + return +} + +func (db *Database) Remove(key string) { + if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { + return + } + + db.m.Lock() + defer db.m.Unlock() + + delete(db.data, key) + os.Remove(filepath.Join(db.basedir, key)) + + return +} diff --git a/main.go b/main.go index d726fed..ad62976 100644 --- a/main.go +++ b/main.go @@ -1,19 +1,18 @@ package main import ( - "database/sql" "log" "math/rand" "net/http" "os" + "path/filepath" "time" "web/config" + "web/kv" "web/renderer" "web/repository" "web/service" - - _ "github.com/mattn/go-sqlite3" ) func init() { @@ -35,22 +34,24 @@ func main() { log.Fatal(err) } - db, err := sql.Open("sqlite3", config.DatabasePath) - if err != nil { + err = os.Mkdir(config.DatabasePath, 0755) + if err != nil && !os.IsExist(err) { log.Fatal(err) } - defer db.Close() - sessionRepo, err := repository.NewSessionRepository(db) + sessionDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "session")) if err != nil { log.Fatal(err) } - appRepo, err := repository.NewAppRepository(db) + appDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "app")) if err != nil { log.Fatal(err) } + sessionRepo := repository.NewSessionRepository(sessionDB) + appRepo := repository.NewAppRepository(appDB) + var logger *log.Logger if len(config.Logfile) < 1 { logger = log.New(os.Stdout, "", log.LstdFlags) diff --git a/model/app.go b/model/app.go index 52ebdf5..89d656d 100644 --- a/model/app.go +++ b/model/app.go @@ -1,19 +1,40 @@ package model -import "errors" +import ( + "errors" + "strings" +) var ( ErrAppNotFound = errors.New("app not found") ) type App struct { - InstanceURL string - ClientID string - ClientSecret string + InstanceDomain string + InstanceURL string + ClientID string + ClientSecret string } type AppRepository interface { Add(app App) (err error) - Update(instanceURL string, clientID string, clientSecret string) (err error) - Get(instanceURL string) (app App, err error) + Get(instanceDomain string) (app App, err error) +} + +func (a *App) Marshal() []byte { + str := a.InstanceURL + "\n" + a.ClientID + "\n" + a.ClientSecret + return []byte(str) +} + +func (a *App) Unmarshal(instanceDomain string, data []byte) error { + str := string(data) + lines := strings.Split(str, "\n") + if len(lines) != 3 { + return errors.New("invalid data") + } + a.InstanceDomain = instanceDomain + a.InstanceURL = lines[0] + a.ClientID = lines[1] + a.ClientSecret = lines[2] + return nil } diff --git a/model/session.go b/model/session.go index 43628ee..94f527b 100644 --- a/model/session.go +++ b/model/session.go @@ -1,15 +1,18 @@ package model -import "errors" +import ( + "errors" + "strings" +) var ( ErrSessionNotFound = errors.New("session not found") ) type Session struct { - ID string - InstanceURL string - AccessToken string + ID string + InstanceDomain string + AccessToken string } type SessionRepository interface { @@ -21,3 +24,26 @@ type SessionRepository interface { func (s Session) IsLoggedIn() bool { return len(s.AccessToken) > 0 } + +func (s *Session) Marshal() []byte { + str := s.InstanceDomain + "\n" + s.AccessToken + return []byte(str) +} + +func (s *Session) Unmarshal(id string, data []byte) error { + str := string(data) + lines := strings.Split(str, "\n") + + size := len(lines) + if size == 1 { + s.InstanceDomain = lines[0] + } else if size == 2 { + s.InstanceDomain = lines[0] + s.AccessToken = lines[1] + } else { + return errors.New("invalid data") + } + + s.ID = id + return nil +} diff --git a/repository/appRepository.go b/repository/appRepository.go index 1a8f204..00ef64d 100644 --- a/repository/appRepository.go +++ b/repository/appRepository.go @@ -1,54 +1,33 @@ package repository import ( - "database/sql" - + "web/kv" "web/model" ) type appRepository struct { - db *sql.DB + db *kv.Database } -func NewAppRepository(db *sql.DB) (*appRepository, error) { - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS app - (instance_url varchar, client_id varchar, client_secret varchar)`, - ) - if err != nil { - return nil, err - } - +func NewAppRepository(db *kv.Database) *appRepository { return &appRepository{ db: db, - }, nil + } } func (repo *appRepository) Add(a model.App) (err error) { - _, err = repo.db.Exec("INSERT INTO app VALUES (?, ?, ?)", a.InstanceURL, a.ClientID, a.ClientSecret) + err = repo.db.Set(a.InstanceDomain, a.Marshal()) return } -func (repo *appRepository) Update(instanceURL string, clientID string, clientSecret string) (err error) { - _, err = repo.db.Exec("UPDATE app SET client_id = ?, client_secret = ? where instance_url = ?", clientID, clientSecret, instanceURL) - return -} - -func (repo *appRepository) Get(instanceURL string) (a model.App, err error) { - rows, err := repo.db.Query("SELECT * FROM app WHERE instance_url = ?", instanceURL) +func (repo *appRepository) Get(instanceDomain string) (a model.App, err error) { + data, err := repo.db.Get(instanceDomain) if err != nil { - return - } - defer rows.Close() - - if !rows.Next() { err = model.ErrAppNotFound return } - err = rows.Scan(&a.InstanceURL, &a.ClientID, &a.ClientSecret) - if err != nil { - return - } + err = a.Unmarshal(instanceDomain, data) return } diff --git a/repository/sessionRepository.go b/repository/sessionRepository.go index 2a88b40..6c26313 100644 --- a/repository/sessionRepository.go +++ b/repository/sessionRepository.go @@ -1,54 +1,50 @@ package repository import ( - "database/sql" - + "web/kv" "web/model" ) type sessionRepository struct { - db *sql.DB + db *kv.Database } -func NewSessionRepository(db *sql.DB) (*sessionRepository, error) { - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS session - (id varchar, instance_url varchar, access_token varchar)`, - ) - if err != nil { - return nil, err - } - +func NewSessionRepository(db *kv.Database) *sessionRepository { return &sessionRepository{ db: db, - }, nil + } } func (repo *sessionRepository) Add(s model.Session) (err error) { - _, err = repo.db.Exec("INSERT INTO session VALUES (?, ?, ?)", s.ID, s.InstanceURL, s.AccessToken) + err = repo.db.Set(s.ID, s.Marshal()) return } -func (repo *sessionRepository) Update(sessionID string, accessToken string) (err error) { - _, err = repo.db.Exec("UPDATE session SET access_token = ? where id = ?", accessToken, sessionID) - return -} - -func (repo *sessionRepository) Get(id string) (s model.Session, err error) { - rows, err := repo.db.Query("SELECT * FROM session WHERE id = ?", id) +func (repo *sessionRepository) Update(id string, accessToken string) (err error) { + data, err := repo.db.Get(id) if err != nil { return } - defer rows.Close() - if !rows.Next() { + var s model.Session + err = s.Unmarshal(id, data) + if err != nil { + return + } + + s.AccessToken = accessToken + + return repo.db.Set(id, s.Marshal()) +} + +func (repo *sessionRepository) Get(id string) (s model.Session, err error) { + data, err := repo.db.Get(id) + if err != nil { err = model.ErrSessionNotFound return } - err = rows.Scan(&s.ID, &s.InstanceURL, &s.AccessToken) - if err != nil { - return - } + err = s.Unmarshal(id, data) return } diff --git a/service/auth.go b/service/auth.go index e9bec38..38c0a43 100644 --- a/service/auth.go +++ b/service/auth.go @@ -40,12 +40,12 @@ func (s *authService) getClient(ctx context.Context) (c *mastodon.Client, err er if err != nil { return nil, ErrInvalidSession } - client, err := s.appRepo.Get(session.InstanceURL) + client, err := s.appRepo.Get(session.InstanceDomain) if err != nil { return } c = mastodon.NewClient(&mastodon.Config{ - Server: session.InstanceURL, + Server: client.InstanceURL, ClientID: client.ClientID, ClientSecret: client.ClientSecret, AccessToken: session.AccessToken, diff --git a/service/service.go b/service/service.go index 5181475..bb03c26 100644 --- a/service/service.go +++ b/service/service.go @@ -9,7 +9,6 @@ import ( "mime/multipart" "net/http" "net/url" - "path" "strings" "mastodon" @@ -64,14 +63,18 @@ func NewService(clientName string, clientScope string, clientWebsite string, func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( redirectUrl string, sessionID string, err error) { - if !strings.HasPrefix(instance, "https://") { - instance = "https://" + instance + var instanceURL string + if strings.HasPrefix(instance, "https://") { + instanceURL = instance + instance = strings.TrimPrefix(instance, "https://") + } else { + instanceURL = "https://" + instance } sessionID = util.NewSessionId() err = svc.sessionRepo.Add(model.Session{ - ID: sessionID, - InstanceURL: instance, + ID: sessionID, + InstanceDomain: instance, }) if err != nil { return @@ -85,7 +88,7 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( var mastoApp *mastodon.Application mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{ - Server: instance, + Server: instanceURL, ClientName: svc.clientName, Scopes: svc.clientScope, Website: svc.clientWebsite, @@ -96,9 +99,10 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( } app = model.App{ - InstanceURL: instance, - ClientID: mastoApp.ClientID, - ClientSecret: mastoApp.ClientSecret, + InstanceDomain: instance, + InstanceURL: instanceURL, + ClientID: mastoApp.ClientID, + ClientSecret: mastoApp.ClientSecret, } err = svc.appRepo.Add(app) @@ -136,7 +140,7 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *masto return } - app, err := svc.appRepo.Get(session.InstanceURL) + app, err := svc.appRepo.Get(session.InstanceDomain) if err != nil { return }