diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go index f8978de92..4977ae4f2 100644 --- a/internal/api/auth/authorize.go +++ b/internal/api/auth/authorize.go @@ -75,8 +75,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { return } - app := >smodel.Application{} - if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { + app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) var errWithCode gtserror.WithCode diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go index 8871fd2dc..6a6eecd8c 100644 --- a/internal/api/auth/callback.go +++ b/internal/api/auth/callback.go @@ -107,8 +107,8 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { return } - app := >smodel.Application{} - if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { + app, err := m.db.GetApplicationByClientID(c.Request.Context(), sessionClientID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) var errWithCode gtserror.WithCode diff --git a/internal/cache/gts.go b/internal/cache/gts.go index f120bcf4e..8d7ebcd98 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -32,6 +32,7 @@ type GTSCaches struct { account *result.Cache[*gtsmodel.Account] accountNote *result.Cache[*gtsmodel.AccountNote] + application *result.Cache[*gtsmodel.Application] block *result.Cache[*gtsmodel.Block] blockIDs *SliceCache[string] boostOfIDs *SliceCache[string] @@ -67,6 +68,7 @@ type GTSCaches struct { func (c *GTSCaches) Init() { c.initAccount() c.initAccountNote() + c.initApplication() c.initBlock() c.initBlockIDs() c.initBoostOfIDs() @@ -117,6 +119,11 @@ func (c *GTSCaches) AccountNote() *result.Cache[*gtsmodel.AccountNote] { return c.accountNote } +// Application provides access to the gtsmodel Application database cache. +func (c *GTSCaches) Application() *result.Cache[*gtsmodel.Application] { + return c.application +} + // Block provides access to the gtsmodel Block (account) database cache. func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] { return c.block @@ -303,6 +310,26 @@ func (c *GTSCaches) initAccountNote() { c.accountNote.IgnoreErrors(ignoreErrors) } +func (c *GTSCaches) initApplication() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofApplication(), // model in-mem size. + config.GetCacheApplicationMemRatio(), + ) + log.Infof(nil, "Application cache size = %d", cap) + + c.application = result.New([]result.Lookup{ + {Name: "ID"}, + {Name: "ClientID"}, + }, func(a1 *gtsmodel.Application) *gtsmodel.Application { + a2 := new(gtsmodel.Application) + *a2 = *a1 + return a2 + }, cap) + + c.application.IgnoreErrors(ignoreErrors) +} + func (c *GTSCaches) initBlock() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/size.go b/internal/cache/size.go index ec7c554c0..34586b0b1 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -155,6 +155,7 @@ func totalOfRatios() float64 { return 0 + config.GetCacheAccountMemRatio() + config.GetCacheAccountNoteMemRatio() + + config.GetCacheApplicationMemRatio() + config.GetCacheBlockMemRatio() + config.GetCacheBlockIDsMemRatio() + config.GetCacheBoostOfIDsMemRatio() + @@ -217,7 +218,7 @@ func sizeofAccount() uintptr { SilencedAt: time.Now(), SuspendedAt: time.Now(), HideCollections: func() *bool { ok := true; return &ok }(), - SuspensionOrigin: "", + SuspensionOrigin: exampleID, EnableRSS: func() *bool { ok := true; return &ok }(), })) } @@ -231,6 +232,20 @@ func sizeofAccountNote() uintptr { })) } +func sizeofApplication() uintptr { + return uintptr(size.Of(>smodel.Application{ + ID: exampleID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Name: exampleUsername, + Website: exampleURI, + RedirectURI: exampleURI, + ClientID: exampleID, + ClientSecret: exampleID, + Scopes: exampleTextSmall, + })) +} + func sizeofBlock() uintptr { return uintptr(size.Of(>smodel.Block{ ID: exampleID, @@ -500,5 +515,31 @@ func sizeofVisibility() uintptr { } func sizeofUser() uintptr { - return uintptr(size.Of(>smodel.User{})) + return uintptr(size.Of(>smodel.User{ + ID: exampleID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Email: exampleURI, + AccountID: exampleID, + EncryptedPassword: exampleTextSmall, + CurrentSignInAt: time.Now(), + LastSignInAt: time.Now(), + InviteID: exampleID, + ChosenLanguages: []string{"en", "fr", "jp"}, + FilteredLanguages: []string{"en", "fr", "jp"}, + Locale: "en", + CreatedByApplicationID: exampleID, + LastEmailedAt: time.Now(), + ConfirmationToken: exampleTextSmall, + ConfirmationSentAt: time.Now(), + ConfirmedAt: time.Now(), + UnconfirmedEmail: exampleURI, + Moderator: func() *bool { ok := true; return &ok }(), + Admin: func() *bool { ok := true; return &ok }(), + Disabled: func() *bool { ok := true; return &ok }(), + Approved: func() *bool { ok := true; return &ok }(), + ResetPasswordToken: exampleTextSmall, + ResetPasswordSentAt: time.Now(), + ExternalID: exampleID, + })) } diff --git a/internal/config/config.go b/internal/config/config.go index ef79d4e12..5a26222ed 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -178,6 +178,7 @@ type CacheConfiguration struct { MemoryTarget bytesize.Size `name:"memory-target"` AccountMemRatio float64 `name:"account-mem-ratio"` AccountNoteMemRatio float64 `name:"account-note-mem-ratio"` + ApplicationMemRatio float64 `name:"application-mem-ratio"` BlockMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"` BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 2bc95f6f1..b78362973 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -147,6 +147,7 @@ // be able to make some more sense :D AccountMemRatio: 18, AccountNoteMemRatio: 0.1, + ApplicationMemRatio: 0.1, BlockMemRatio: 3, BlockIDsMemRatio: 3, BoostOfIDsMemRatio: 3, @@ -170,7 +171,7 @@ StatusFaveIDsMemRatio: 3, TagMemRatio: 3, TombstoneMemRatio: 2, - UserMemRatio: 0.1, + UserMemRatio: 0.25, WebfingerMemRatio: 0.1, VisibilityMemRatio: 2, }, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 0a299e7d0..03411853f 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2499,6 +2499,31 @@ func GetCacheAccountNoteMemRatio() float64 { return global.GetCacheAccountNoteMe // SetCacheAccountNoteMemRatio safely sets the value for global configuration 'Cache.AccountNoteMemRatio' field func SetCacheAccountNoteMemRatio(v float64) { global.SetCacheAccountNoteMemRatio(v) } +// GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field +func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.ApplicationMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheApplicationMemRatio safely sets the Configuration value for state's 'Cache.ApplicationMemRatio' field +func (st *ConfigState) SetCacheApplicationMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.ApplicationMemRatio = v + st.reloadToViper() +} + +// CacheApplicationMemRatioFlag returns the flag name for the 'Cache.ApplicationMemRatio' field +func CacheApplicationMemRatioFlag() string { return "cache-application-mem-ratio" } + +// GetCacheApplicationMemRatio safely fetches the value for global configuration 'Cache.ApplicationMemRatio' field +func GetCacheApplicationMemRatio() float64 { return global.GetCacheApplicationMemRatio() } + +// SetCacheApplicationMemRatio safely sets the value for global configuration 'Cache.ApplicationMemRatio' field +func SetCacheApplicationMemRatio(v float64) { global.SetCacheApplicationMemRatio(v) } + // GetCacheBlockMemRatio safely fetches the Configuration value for state's 'Cache.BlockMemRatio' field func (st *ConfigState) GetCacheBlockMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/application.go b/internal/db/application.go new file mode 100644 index 000000000..34a857d3f --- /dev/null +++ b/internal/db/application.go @@ -0,0 +1,38 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type Application interface { + // GetApplicationByID fetches the application from the database with corresponding ID value. + GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) + + // GetApplicationByClientID fetches the application from the database with corresponding client_id value. + GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) + + // PutApplication places the new application in the database, erroring on non-unique ID or client_id. + PutApplication(ctx context.Context, app *gtsmodel.Application) error + + // DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. + DeleteApplicationByClientID(ctx context.Context, clientID string) error +} diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go new file mode 100644 index 000000000..b53d2c0b0 --- /dev/null +++ b/internal/db/bundb/application.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type applicationDB struct { + db *WrappedDB + state *state.State +} + +func (a *applicationDB) GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) { + return a.getApplication( + ctx, + "ID", + func(app *gtsmodel.Application) error { + return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("id"), id).Scan(ctx) + }, + id, + ) +} + +func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) { + return a.getApplication( + ctx, + "ClientID", + func(app *gtsmodel.Application) error { + return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("client_id"), clientID).Scan(ctx) + }, + clientID, + ) +} + +func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) { + return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) { + var app gtsmodel.Application + + // Not cached! Perform database query. + if err := dbQuery(&app); err != nil { + return nil, a.db.ProcessError(err) + } + + return &app, nil + }, keyParts...) +} + +func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { + return a.state.Caches.GTS.Application().Store(app, func() error { + _, err := a.db.NewInsert().Model(app).Exec(ctx) + return a.db.ProcessError(err) + }) +} + +func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientID string) error { + // Attempt to delete application. + if _, err := a.db.NewDelete(). + Table("applications"). + Where("? = ?", bun.Ident("client_id"), clientID). + Exec(ctx); err != nil { + return a.db.ProcessError(err) + } + + // NOTE about further side effects: + // + // We don't need to handle updating any statuses or users + // (both of which may contain refs to applications), as + // DeleteApplication__() is only ever called during an + // account deletion, which handles deletion of the user + // and all their statuses already. + // + + // Clear application from the cache. + a.state.Caches.GTS.Application().Invalidate("ClientID", clientID) + + return nil +} diff --git a/internal/db/bundb/application_test.go b/internal/db/bundb/application_test.go new file mode 100644 index 000000000..d2ab05ebd --- /dev/null +++ b/internal/db/bundb/application_test.go @@ -0,0 +1,128 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "errors" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type ApplicationTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *ApplicationTestSuite) TestGetApplicationBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 application models are equal. + isEqual := func(a1, a2 gtsmodel.Application) bool { + // Clear database-set fields. + a1.CreatedAt = time.Time{} + a2.CreatedAt = time.Time{} + a1.UpdatedAt = time.Time{} + a2.UpdatedAt = time.Time{} + + return reflect.DeepEqual(a1, a2) + } + + for _, app := range suite.testApplications { + for lookup, dbfunc := range map[string]func() (*gtsmodel.Application, error){ + "id": func() (*gtsmodel.Application, error) { + return suite.db.GetApplicationByID(ctx, app.ID) + }, + + "client_id": func() (*gtsmodel.Application, error) { + return suite.db.GetApplicationByClientID(ctx, app.ClientID) + }, + } { + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkApp, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received application data. + if !isEqual(*checkApp, *app) { + t.Errorf("application does not contain expected data: %+v", checkApp) + continue + } + } + } +} + +func (suite *ApplicationTestSuite) TestDeleteApplicationBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + for _, app := range suite.testApplications { + for lookup, dbfunc := range map[string]func() error{ + "client_id": func() error { + return suite.db.DeleteApplicationByClientID(ctx, app.ClientID) + }, + } { + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + err := dbfunc() + if err != nil { + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Ensure this application has been deleted and cache cleared. + if _, err := suite.db.GetApplicationByID(ctx, app.ID); err != db.ErrNoEntries { + t.Errorf("application does not appear to have been deleted %q: %v", lookup, err) + continue + } + } + } +} + +func TestApplicationTestSuite(t *testing.T) { + suite.Run(t, new(ApplicationTestSuite)) +} diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 8387bb8d1..26b31ff28 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -60,6 +60,7 @@ type DBService struct { db.Account db.Admin + db.Application db.Basic db.Domain db.Emoji @@ -168,6 +169,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + Application: &applicationDB{ + db: db, + state: state, + }, Basic: &basicDB{ db: db, }, diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index c6091e2c9..311732299 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -37,19 +37,12 @@ type statusDB struct { state *state.State } -func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { - return s.db. - NewSelect(). - Model(status). - Relation("CreatedWithApplication") -} - func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) { return s.getStatus( ctx, "ID", func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) + return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) }, id, ) @@ -78,7 +71,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St ctx, "URI", func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) + return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) }, uri, ) @@ -89,7 +82,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St ctx, "URL", func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) + return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) }, url, ) @@ -100,7 +93,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou ctx, "BoostOfID.AccountID", func(status *gtsmodel.Status) error { - return s.newStatusQ(status). + return s.db.NewSelect().Model(status). Where("status.boost_of_id = ?", boostOfID). Where("status.account_id = ?", byAccountID). @@ -264,6 +257,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } } + if status.CreatedWithApplicationID != "" && status.CreatedWithApplication == nil { + // Populate the status' expected CreatedWithApplication (not always set). + status.CreatedWithApplication, err = s.state.DB.GetApplicationByID( + ctx, // these are already barebones + status.CreatedWithApplicationID, + ) + if err != nil { + errs.Appendf("error populating status application: %w", err) + } + } + return errs.Combine() } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 4b38d48fa..9df05596e 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -24,6 +24,7 @@ "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/uptrace/bun" @@ -35,107 +36,125 @@ type userDB struct { } func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) { - return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) { - var user gtsmodel.User + return u.getUser( + ctx, + "ID", + func(user *gtsmodel.User) error { + return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("id"), id).Scan(ctx) + }, + id, + ) +} - q := u.db. - NewSelect(). - Model(&user). - Relation("Account"). - Where("? = ?", bun.Ident("user.id"), id) +func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) { + var ( + users = make([]*gtsmodel.User, 0, len(ids)) - if err := q.Scan(ctx); err != nil { - return nil, u.db.ProcessError(err) + // Collect errors instead of + // returning early on any. + errs gtserror.MultiError + ) + + for _, id := range ids { + // Attempt to fetch user from DB. + user, err := u.GetUserByID(ctx, id) + if err != nil { + errs.Appendf("error getting user %s: %w", id, err) + continue } - return &user, nil - }, id) + // Append user to return slice. + users = append(users, user) + } + + return users, errs.Combine() } func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { - return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) { - var user gtsmodel.User - - q := u.db. - NewSelect(). - Model(&user). - Relation("Account"). - Where("? = ?", bun.Ident("user.account_id"), accountID) - - if err := q.Scan(ctx); err != nil { - return nil, u.db.ProcessError(err) - } - - return &user, nil - }, accountID) + return u.getUser( + ctx, + "AccountID", + func(user *gtsmodel.User) error { + return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("account_id"), accountID).Scan(ctx) + }, + accountID, + ) } -func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) { - return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) { - var user gtsmodel.User - - q := u.db. - NewSelect(). - Model(&user). - Relation("Account"). - Where("? = ?", bun.Ident("user.email"), emailAddress) - - if err := q.Scan(ctx); err != nil { - return nil, u.db.ProcessError(err) - } - - return &user, nil - }, emailAddress) +func (u *userDB) GetUserByEmailAddress(ctx context.Context, email string) (*gtsmodel.User, error) { + return u.getUser( + ctx, + "Email", + func(user *gtsmodel.User) error { + return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("email"), email).Scan(ctx) + }, + email, + ) } func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) { - return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) { - var user gtsmodel.User - - q := u.db. - NewSelect(). - Model(&user). - Relation("Account"). - Where("? = ?", bun.Ident("user.external_id"), id) - - if err := q.Scan(ctx); err != nil { - return nil, u.db.ProcessError(err) - } - - return &user, nil - }, id) + return u.getUser( + ctx, + "ExternalID", + func(user *gtsmodel.User) error { + return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("external_id"), id).Scan(ctx) + }, + id, + ) } -func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) { - return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) { +func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (*gtsmodel.User, error) { + return u.getUser( + ctx, + "ConfirmationToken", + func(user *gtsmodel.User) error { + return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("confirmation_token"), token).Scan(ctx) + }, + token, + ) +} + +func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) { + // Fetch user from database cache with loader callback. + user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.db. - NewSelect(). - Model(&user). - Relation("Account"). - Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) - - if err := q.Scan(ctx); err != nil { + // Not cached! perform database query. + if err := dbQuery(&user); err != nil { return nil, u.db.ProcessError(err) } return &user, nil - }, confirmationToken) + }, keyParts...) + if err != nil { + return nil, err + } + + // Fetch the related account model for this user. + user.Account, err = u.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + user.AccountID, + ) + if err != nil { + return nil, gtserror.Newf("error populating user account: %w", err) + } + + return user, nil } func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { - var users []*gtsmodel.User - q := u.db. - NewSelect(). - Model(&users). - Relation("Account") + var userIDs []string - if err := q.Scan(ctx); err != nil { + // Scan all user IDs into slice. + if err := u.db.NewSelect(). + Table("users"). + Column("id"). + Scan(ctx, &userIDs); err != nil { return nil, u.db.ProcessError(err) } - return users, nil + // Transform user IDs into user slice. + return u.GetUsersByIDs(ctx, userIDs) } func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { diff --git a/internal/db/db.go b/internal/db/db.go index 7c00050ff..567551c73 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -26,6 +26,7 @@ type DB interface { Account Admin + Application Basic Domain Emoji diff --git a/internal/middleware/tokencheck.go b/internal/middleware/tokencheck.go index 1496363af..d2570c3f0 100644 --- a/internal/middleware/tokencheck.go +++ b/internal/middleware/tokencheck.go @@ -22,7 +22,6 @@ "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/oauth2/v4" @@ -125,8 +124,8 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2. log.Tracef(ctx, "authenticated client %s with bearer token, scope is %s", clientID, ti.GetScope()) // fetch app for this token - app := >smodel.Application{} - if err := dbConn.GetWhere(ctx, []db.Where{{Key: "client_id", Value: clientID}}, app); err != nil { + app, err := dbConn.GetApplicationByClientID(ctx, clientID) + if err != nil { if err != db.ErrNoEntries { log.Errorf(ctx, "database error looking for application with clientID %s: %s", clientID, err) return @@ -134,6 +133,7 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2. log.Warnf(ctx, "no app found for client %s", clientID) return } + c.Set(oauth.SessionAuthorizedApplication, app) } } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index d2483aeb1..da13eb20e 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -46,12 +46,6 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi }...) l.Trace("beginning account delete process") - if account.IsLocal() { - if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil { - return gtserror.NewErrorInternalError(err) - } - } - if err := p.deleteAccountFollows(ctx, account); err != nil { return gtserror.NewErrorInternalError(err) } @@ -72,6 +66,14 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi return gtserror.NewErrorInternalError(err) } + if account.IsLocal() { + // we tokens, applications and clients for account as one of the last + // stages during deletion, as other database models rely on these. + if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil { + return gtserror.NewErrorInternalError(err) + } + } + // To prevent the account being created again, // stubbify it and update it in the db. // The account will not be deleted, but it @@ -129,7 +131,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account * } // Delete any OAuth applications associated with this token. - if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil { + if err := p.state.DB.DeleteApplicationByClientID(ctx, t.ClientID); err != nil { return gtserror.Newf("db error deleting application: %w", err) } @@ -305,7 +307,17 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel statusLoop: for { // Page through account's statuses. - statuses, err = p.state.DB.GetAccountStatuses(ctx, account.ID, deleteSelectLimit, false, false, maxID, "", false, false) + statuses, err = p.state.DB.GetAccountStatuses( + ctx, + account.ID, + deleteSelectLimit, + false, + false, + maxID, + "", + false, + false, + ) if err != nil && !errors.Is(err, db.ErrNoEntries) { // Make sure we don't have a real error. return err diff --git a/internal/processing/app.go b/internal/processing/app.go index 07739ce92..d4a923e8a 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -61,7 +61,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api } // chuck it in the db - if err := p.state.DB.Put(ctx, app); err != nil { + if err := p.state.DB.PutApplication(ctx, app); err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 2dc0e4dd5..ab04f6ccc 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -699,8 +699,8 @@ func (c *converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r } if appID := s.CreatedWithApplicationID; appID != "" { - app := >smodel.Application{} - if err := c.db.GetByID(ctx, appID, app); err != nil { + app, err := c.db.GetApplicationByID(ctx, appID) + if err != nil { return nil, fmt.Errorf("error getting application %s: %w", appID, err) } diff --git a/test/envparsing.sh b/test/envparsing.sh index 59c69e1b5..d9270e7f6 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -20,6 +20,7 @@ EXPECT=$(cat << "EOF" "cache": { "account-mem-ratio": 18, "account-note-mem-ratio": 0.1, + "application-mem-ratio": 0.1, "block-mem-ratio": 3, "boost-of-ids-mem-ratio": 3, "emoji-category-mem-ratio": 0.1, @@ -43,7 +44,7 @@ EXPECT=$(cat << "EOF" "status-mem-ratio": 18, "tag-mem-ratio": 3, "tombstone-mem-ratio": 2, - "user-mem-ratio": 0.1, + "user-mem-ratio": 0.25, "visibility-mem-ratio": 2, "webfinger-mem-ratio": 0.1 },