diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go index 0769b8878..79361375b 100644 --- a/cmd/gotosocial/action/testrig/testrig.go +++ b/cmd/gotosocial/action/testrig/testrig.go @@ -98,8 +98,8 @@ testrig.StandardStorageSetup(state.Storage, "./testrig/media") // Initialize workers. - state.Workers.Start() - defer state.Workers.Stop() + testrig.StartNoopWorkers(&state) + defer testrig.StopWorkers(&state) // build backend handlers transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { diff --git a/internal/api/auth/token.go b/internal/api/auth/token.go index cab9352fa..d9f0d8154 100644 --- a/internal/api/auth/token.go +++ b/internal/api/auth/token.go @@ -49,7 +49,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { form := &tokenRequestForm{} if err := c.ShouldBind(form); err != nil { - apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) + apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, err.Error())) return } @@ -98,7 +98,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { } if len(help) != 0 { - apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) + apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, help...)) return } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 3aa21cdd0..d35162172 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -59,6 +59,7 @@ func (c *Caches) Init() { c.initBlock() c.initBlockIDs() c.initBoostOfIDs() + c.initClient() c.initDomainAllow() c.initDomainBlock() c.initEmoji() @@ -85,9 +86,10 @@ func (c *Caches) Init() { c.initReport() c.initStatus() c.initStatusFave() + c.initStatusFaveIDs() c.initTag() c.initThreadMute() - c.initStatusFaveIDs() + c.initToken() c.initTombstone() c.initUser() c.initWebfinger() diff --git a/internal/cache/db.go b/internal/cache/db.go index c383ed6c7..cb0ed6712 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -58,6 +58,9 @@ type GTSCaches struct { // BoostOfIDs provides access to the boost of IDs list database cache. BoostOfIDs SliceCache[string] + // Client provides access to the gtsmodel Client database cache. + Client StructCache[*gtsmodel.Client] + // DomainAllow provides access to the domain allow database cache. DomainAllow *domain.Cache @@ -150,6 +153,9 @@ type GTSCaches struct { // Tag provides access to the gtsmodel Tag database cache. Tag StructCache[*gtsmodel.Tag] + // Token provides access to the gtsmodel Token database cache. + Token StructCache[*gtsmodel.Token] + // Tombstone provides access to the gtsmodel Tombstone database cache. Tombstone StructCache[*gtsmodel.Tombstone] @@ -309,9 +315,10 @@ func (c *Caches) initApplication() { {Fields: "ID"}, {Fields: "ClientID"}, }, - MaxSize: cap, - IgnoreErr: ignoreErrors, - Copy: copyF, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: copyF, + Invalidate: c.OnInvalidateApplication, }) } @@ -374,6 +381,32 @@ func (c *Caches) initBoostOfIDs() { c.GTS.BoostOfIDs.Init(0, cap) } +func (c *Caches) initClient() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofClient(), // model in-mem size. + config.GetCacheClientMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client { + c2 := new(gtsmodel.Client) + *c2 = *c1 + return c2 + } + + c.GTS.Client.Init(structr.CacheConfig[*gtsmodel.Client]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: copyF, + Invalidate: c.OnInvalidateClient, + }) +} + func (c *Caches) initDomainAllow() { c.GTS.DomainAllow = new(domain.Cache) } @@ -1135,7 +1168,7 @@ func (c *Caches) initTag() { func (c *Caches) initThreadMute() { cap := calculateResultCacheMax( - sizeOfThreadMute(), // model in-mem size. + sizeofThreadMute(), // model in-mem size. config.GetCacheThreadMuteMemRatio(), ) @@ -1160,6 +1193,35 @@ func (c *Caches) initThreadMute() { }) } +func (c *Caches) initToken() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofToken(), // model in-mem size. + config.GetCacheTokenMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(t1 *gtsmodel.Token) *gtsmodel.Token { + t2 := new(gtsmodel.Token) + *t2 = *t1 + return t2 + } + + c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "Code"}, + {Fields: "Access"}, + {Fields: "Refresh"}, + {Fields: "ClientID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: copyF, + }) +} + func (c *Caches) initTombstone() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index 746d8c7e7..547015eac 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -60,6 +60,11 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { c.GTS.Move.Invalidate("TargetURI", account.URI) } +func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) { + // Invalidate cached client of this application. + c.GTS.Client.Invalidate("ID", app.ClientID) +} + func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { // Invalidate block origin account ID cached visibility. c.Visibility.Invalidate("ItemID", block.AccountID) @@ -73,6 +78,11 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { c.GTS.BlockIDs.Invalidate(block.AccountID) } +func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) { + // Invalidate any tokens under this client. + c.GTS.Token.Invalidate("ClientID", client.ID) +} + func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) { // Invalidate any emoji in this category. c.GTS.Emoji.Invalidate("CategoryID", category.ID) diff --git a/internal/cache/size.go b/internal/cache/size.go index 83b0da046..9c1a82abc 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -176,6 +176,7 @@ func totalOfRatios() float64 { config.GetCacheBlockMemRatio() + config.GetCacheBlockIDsMemRatio() + config.GetCacheBoostOfIDsMemRatio() + + config.GetCacheClientMemRatio() + config.GetCacheEmojiMemRatio() + config.GetCacheEmojiCategoryMemRatio() + config.GetCacheFollowMemRatio() + @@ -198,6 +199,7 @@ func totalOfRatios() float64 { config.GetCacheStatusFaveIDsMemRatio() + config.GetCacheTagMemRatio() + config.GetCacheThreadMuteMemRatio() + + config.GetCacheTokenMemRatio() + config.GetCacheTombstoneMemRatio() + config.GetCacheUserMemRatio() + config.GetCacheWebfingerMemRatio() + @@ -287,6 +289,17 @@ func sizeofBlock() uintptr { })) } +func sizeofClient() uintptr { + return uintptr(size.Of(>smodel.Client{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + Secret: exampleID, + Domain: exampleURI, + UserID: exampleID, + })) +} + func sizeofEmoji() uintptr { return uintptr(size.Of(>smodel.Emoji{ ID: exampleID, @@ -591,7 +604,7 @@ func sizeofTag() uintptr { })) } -func sizeOfThreadMute() uintptr { +func sizeofThreadMute() uintptr { return uintptr(size.Of(>smodel.ThreadMute{ ID: exampleID, CreatedAt: exampleTime, @@ -601,6 +614,29 @@ func sizeOfThreadMute() uintptr { })) } +func sizeofToken() uintptr { + return uintptr(size.Of(>smodel.Token{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + ClientID: exampleID, + UserID: exampleID, + RedirectURI: exampleURI, + Scope: "r:w", + Code: "", // TODO + CodeChallenge: "", // TODO + CodeChallengeMethod: "", // TODO + CodeCreateAt: exampleTime, + CodeExpiresAt: exampleTime, + Access: exampleID + exampleID, + AccessCreateAt: exampleTime, + AccessExpiresAt: exampleTime, + Refresh: "", // TODO: clients don't really support this very well yet + RefreshCreateAt: exampleTime, + RefreshExpiresAt: exampleTime, + })) +} + func sizeofTombstone() uintptr { return uintptr(size.Of(>smodel.Tombstone{ ID: exampleID, diff --git a/internal/config/config.go b/internal/config/config.go index dee9e99de..3cd67525f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -199,6 +199,7 @@ type CacheConfiguration struct { BlockMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"` BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` + ClientMemRatio float64 `name:"client-mem-ratio"` EmojiMemRatio float64 `name:"emoji-mem-ratio"` EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"` FilterMemRatio float64 `name:"filter-mem-ratio"` @@ -226,6 +227,7 @@ type CacheConfiguration struct { StatusFaveIDsMemRatio float64 `name:"status-fave-ids-mem-ratio"` TagMemRatio float64 `name:"tag-mem-ratio"` ThreadMuteMemRatio float64 `name:"thread-mute-mem-ratio"` + TokenMemRatio float64 `name:"token-mem-ratio"` TombstoneMemRatio float64 `name:"tombstone-mem-ratio"` UserMemRatio float64 `name:"user-mem-ratio"` WebfingerMemRatio float64 `name:"webfinger-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 64fff366a..f5f8fb6ac 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -163,6 +163,7 @@ BlockMemRatio: 2, BlockIDsMemRatio: 3, BoostOfIDsMemRatio: 3, + ClientMemRatio: 0.1, EmojiMemRatio: 3, EmojiCategoryMemRatio: 0.1, FilterMemRatio: 0.5, @@ -190,6 +191,7 @@ StatusFaveIDsMemRatio: 3, TagMemRatio: 2, ThreadMuteMemRatio: 0.2, + TokenMemRatio: 0.75, TombstoneMemRatio: 0.5, UserMemRatio: 0.25, WebfingerMemRatio: 0.1, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 39d26d13e..a8c919834 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2925,6 +2925,31 @@ func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemR // SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } +// GetCacheClientMemRatio safely fetches the Configuration value for state's 'Cache.ClientMemRatio' field +func (st *ConfigState) GetCacheClientMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.ClientMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheClientMemRatio safely sets the Configuration value for state's 'Cache.ClientMemRatio' field +func (st *ConfigState) SetCacheClientMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.ClientMemRatio = v + st.reloadToViper() +} + +// CacheClientMemRatioFlag returns the flag name for the 'Cache.ClientMemRatio' field +func CacheClientMemRatioFlag() string { return "cache-client-mem-ratio" } + +// GetCacheClientMemRatio safely fetches the value for global configuration 'Cache.ClientMemRatio' field +func GetCacheClientMemRatio() float64 { return global.GetCacheClientMemRatio() } + +// SetCacheClientMemRatio safely sets the value for global configuration 'Cache.ClientMemRatio' field +func SetCacheClientMemRatio(v float64) { global.SetCacheClientMemRatio(v) } + // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) { st.mutex.RLock() @@ -3600,6 +3625,31 @@ func GetCacheThreadMuteMemRatio() float64 { return global.GetCacheThreadMuteMemR // SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) } +// GetCacheTokenMemRatio safely fetches the Configuration value for state's 'Cache.TokenMemRatio' field +func (st *ConfigState) GetCacheTokenMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.TokenMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheTokenMemRatio safely sets the Configuration value for state's 'Cache.TokenMemRatio' field +func (st *ConfigState) SetCacheTokenMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.TokenMemRatio = v + st.reloadToViper() +} + +// CacheTokenMemRatioFlag returns the flag name for the 'Cache.TokenMemRatio' field +func CacheTokenMemRatioFlag() string { return "cache-token-mem-ratio" } + +// GetCacheTokenMemRatio safely fetches the value for global configuration 'Cache.TokenMemRatio' field +func GetCacheTokenMemRatio() float64 { return global.GetCacheTokenMemRatio() } + +// SetCacheTokenMemRatio safely sets the value for global configuration 'Cache.TokenMemRatio' field +func SetCacheTokenMemRatio(v float64) { global.SetCacheTokenMemRatio(v) } + // GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/application.go b/internal/db/application.go index 34a857d3f..b71e593c2 100644 --- a/internal/db/application.go +++ b/internal/db/application.go @@ -35,4 +35,40 @@ type Application interface { // DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. DeleteApplicationByClientID(ctx context.Context, clientID string) error + + // GetClientByID ... + GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) + + // PutClient ... + PutClient(ctx context.Context, client *gtsmodel.Client) error + + // DeleteClientByID ... + DeleteClientByID(ctx context.Context, id string) error + + // GetAllTokens ... + GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) + + // GetTokenByCode ... + GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) + + // GetTokenByAccess ... + GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) + + // GetTokenByRefresh ... + GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) + + // PutToken ... + PutToken(ctx context.Context, token *gtsmodel.Token) error + + // DeleteTokenByID ... + DeleteTokenByID(ctx context.Context, id string) error + + // DeleteTokenByCode ... + DeleteTokenByCode(ctx context.Context, code string) error + + // DeleteTokenByAccess ... + DeleteTokenByAccess(ctx context.Context, access string) error + + // DeleteTokenByRefresh ... + DeleteTokenByRefresh(ctx context.Context, refresh string) error } diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index e52467b9b..e9191b7c7 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -397,7 +397,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error { } // Store it. - return a.state.DB.Put(ctx, oc) + return a.state.DB.PutClient(ctx, oc) } func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index f02632793..c2a957c93 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -22,6 +22,7 @@ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI return nil } + +func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) { + return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) { + var client gtsmodel.Client + + if err := a.db.NewSelect(). + Model(&client). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil { + return nil, err + } + + return &client, nil + }, id) +} + +func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error { + return a.state.Caches.GTS.Client.Store(client, func() error { + _, err := a.db.NewInsert().Model(client).Exec(ctx) + return err + }) +} + +func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error { + _, err := a.db.NewDelete(). + Table("clients"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + if err != nil { + return err + } + + a.state.Caches.GTS.Client.Invalidate("ID", id) + return nil +} + +func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) { + var tokenIDs []string + + // Select ALL token IDs. + if err := a.db.NewSelect(). + Table("tokens"). + Column("id"). + Scan(ctx, &tokenIDs); err != nil { + return nil, err + } + + // Load all input token IDs via cache loader callback. + tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID", + tokenIDs, + func(uncached []string) ([]*gtsmodel.Token, error) { + // Preallocate expected length of uncached tokens. + tokens := make([]*gtsmodel.Token, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) token IDs. + if err := a.db.NewSelect(). + Model(tokens). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return tokens, nil + }, + ) + if err != nil { + return nil, err + } + + // Reoroder the tokens by their + // IDs to ensure in correct order. + getID := func(t *gtsmodel.Token) string { return t.ID } + util.OrderBy(tokens, tokenIDs, getID) + + return tokens, nil +} + +func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { + return a.getTokenBy( + "Code", + func(t *gtsmodel.Token) error { + return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx) + }, + code, + ) +} + +func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) { + return a.getTokenBy( + "Access", + func(t *gtsmodel.Token) error { + return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx) + }, + access, + ) +} + +func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) { + return a.getTokenBy( + "Refresh", + func(t *gtsmodel.Token) error { + return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx) + }, + refresh, + ) +} + +func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) { + return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) { + var token gtsmodel.Token + + if err := dbQuery(&token); err != nil { + return nil, err + } + + return &token, nil + }, keyParts...) +} + +func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error { + return a.state.Caches.GTS.Token.Store(token, func() error { + _, err := a.db.NewInsert().Model(token).Exec(ctx) + return err + }) +} + +func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { + _, err := a.db.NewDelete(). + Table("tokens"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + if err != nil { + return err + } + + a.state.Caches.GTS.Token.Invalidate("ID", id) + return nil +} + +func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { + _, err := a.db.NewDelete(). + Table("tokens"). + Where("? = ?", bun.Ident("code"), code). + Exec(ctx) + if err != nil { + return err + } + + a.state.Caches.GTS.Token.Invalidate("Code", code) + return nil +} + +func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { + _, err := a.db.NewDelete(). + Table("tokens"). + Where("? = ?", bun.Ident("access"), access). + Exec(ctx) + if err != nil { + return err + } + + a.state.Caches.GTS.Token.Invalidate("Access", access) + return nil +} + +func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { + _, err := a.db.NewDelete(). + Table("tokens"). + Where("? = ?", bun.Ident("refresh"), refresh). + Exec(ctx) + if err != nil { + return err + } + + a.state.Caches.GTS.Token.Invalidate("Refresh", refresh) + return nil +} diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index 5bb600e70..bddb30b1b 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -27,11 +27,11 @@ ) type clientStore struct { - db db.Basic + db db.DB } // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. -func NewClientStore(db db.Basic) oauth2.ClientStore { +func NewClientStore(db db.DB) oauth2.ClientStore { pts := &clientStore{ db: db, } @@ -39,26 +39,27 @@ func NewClientStore(db db.Basic) oauth2.ClientStore { } func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { - poc := >smodel.Client{} - if err := cs.db.GetByID(ctx, clientID, poc); err != nil { + client, err := cs.db.GetClientByID(ctx, clientID) + if err != nil { return nil, err } - return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil + return models.New( + client.ID, + client.Secret, + client.Domain, + client.UserID, + ), nil } func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { - poc := >smodel.Client{ + return cs.db.PutClient(ctx, >smodel.Client{ ID: cli.GetID(), Secret: cli.GetSecret(), Domain: cli.GetDomain(), UserID: cli.GetUserID(), - } - return cs.db.Put(ctx, poc) + }) } func (cs *clientStore) Delete(ctx context.Context, id string) error { - poc := >smodel.Client{ - ID: id, - } - return cs.db.DeleteByID(ctx, id, poc) + return cs.db.DeleteClientByID(ctx, id) } diff --git a/internal/oauth/errors.go b/internal/oauth/errors.go index dd61be28c..b16143e5c 100644 --- a/internal/oauth/errors.go +++ b/internal/oauth/errors.go @@ -19,7 +19,5 @@ import "github.com/superseriousbusiness/oauth2/v4/errors" -// InvalidRequest returns an oauth spec compliant 'invalid_request' error. -func InvalidRequest() error { - return errors.New("invalid_request") -} +// ErrInvalidRequest is an oauth spec compliant 'invalid_request' error. +var ErrInvalidRequest = errors.New("invalid_request") diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 3e4519479..4f2ed509b 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -75,7 +75,7 @@ type s struct { } // New returns a new oauth server that implements the Server interface -func New(ctx context.Context, database db.Basic) Server { +func New(ctx context.Context, database db.DB) Server { ts := newTokenStore(ctx, database) cs := NewClientStore(database) diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 3658f0aa9..14b91fa06 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -20,7 +20,6 @@ import ( "context" "errors" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -34,14 +33,14 @@ // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. type tokenStore struct { oauth2.TokenStore - db db.Basic + db db.DB } // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. -func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { +func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { ts := &tokenStore{ db: db, } @@ -69,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { func (ts *tokenStore) sweep(ctx context.Context) error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - tokens := new([]*gtsmodel.Token) - if err := ts.db.GetAll(ctx, tokens); err != nil { + tokens, err := ts.db.GetAllTokens(ctx) + if err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, dbt := range *tokens { + for _, dbt := range tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { - if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { + if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { return err } } @@ -107,67 +106,49 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { dbt.ID = dbtID } - if err := ts.db.Put(ctx, dbt); err != nil { - return fmt.Errorf("error in tokenstore create: %s", err) - } - return nil + return ts.db.PutToken(ctx, dbt) } // RemoveByCode deletes a token from the DB based on the Code field func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, >smodel.Token{}) + return ts.db.DeleteTokenByCode(ctx, code) } // RemoveByAccess deletes a token from the DB based on the Access field func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, >smodel.Token{}) + return ts.db.DeleteTokenByAccess(ctx, access) } // RemoveByRefresh deletes a token from the DB based on the Refresh field func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, >smodel.Token{}) + return ts.db.DeleteTokenByRefresh(ctx, refresh) } // GetByCode selects a token from the DB based on the Code field func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - if code == "" { - return nil, nil - } - dbt := >smodel.Token{ - Code: code, - } - if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { + token, err := ts.db.GetTokenByCode(ctx, code) + if err != nil { return nil, err } - return DBTokenToToken(dbt), nil + return DBTokenToToken(token), nil } // GetByAccess selects a token from the DB based on the Access field func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - if access == "" { - return nil, nil - } - dbt := >smodel.Token{ - Access: access, - } - if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { + token, err := ts.db.GetTokenByAccess(ctx, access) + if err != nil { return nil, err } - return DBTokenToToken(dbt), nil + return DBTokenToToken(token), nil } // GetByRefresh selects a token from the DB based on the Refresh field func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - if refresh == "" { - return nil, nil - } - dbt := >smodel.Token{ - Refresh: refresh, - } - if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { + token, err := ts.db.GetTokenByRefresh(ctx, refresh) + if err != nil { return nil, err } - return DBTokenToToken(dbt), nil + return DBTokenToToken(token), nil } /* diff --git a/internal/processing/app.go b/internal/processing/app.go index eef4fae0d..d492b3bc4 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -75,7 +75,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api } // chuck it in the db - if err := p.state.DB.Put(ctx, oc); err != nil { + if err := p.state.DB.PutClient(ctx, oc); err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/test/envparsing.sh b/test/envparsing.sh index 19b86a818..a379750c0 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -29,6 +29,7 @@ EXPECT=$(cat << "EOF" "application-mem-ratio": 0.1, "block-mem-ratio": 3, "boost-of-ids-mem-ratio": 3, + "client-mem-ratio": 0.1, "emoji-category-mem-ratio": 0.1, "emoji-mem-ratio": 3, "filter-keyword-mem-ratio": 0.5, @@ -57,6 +58,7 @@ EXPECT=$(cat << "EOF" "status-mem-ratio": 5, "tag-mem-ratio": 2, "thread-mute-mem-ratio": 0.2, + "token-mem-ratio": 0.75, "tombstone-mem-ratio": 0.5, "user-mem-ratio": 0.25, "visibility-mem-ratio": 2,