diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go
index f8978de92..4977ae4f2 100644
--- a/internal/api/auth/authorize.go
+++ b/internal/api/auth/authorize.go
@@ -75,8 +75,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}
- app := >smodel.Application{}
- if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {
+ app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode
diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go
index 8871fd2dc..6a6eecd8c 100644
--- a/internal/api/auth/callback.go
+++ b/internal/api/auth/callback.go
@@ -107,8 +107,8 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
return
}
- app := >smodel.Application{}
- if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {
+ app, err := m.db.GetApplicationByClientID(c.Request.Context(), sessionClientID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
index f120bcf4e..8d7ebcd98 100644
--- a/internal/cache/gts.go
+++ b/internal/cache/gts.go
@@ -32,6 +32,7 @@
type GTSCaches struct {
account *result.Cache[*gtsmodel.Account]
accountNote *result.Cache[*gtsmodel.AccountNote]
+ application *result.Cache[*gtsmodel.Application]
block *result.Cache[*gtsmodel.Block]
blockIDs *SliceCache[string]
boostOfIDs *SliceCache[string]
@@ -67,6 +68,7 @@ type GTSCaches struct {
func (c *GTSCaches) Init() {
c.initAccount()
c.initAccountNote()
+ c.initApplication()
c.initBlock()
c.initBlockIDs()
c.initBoostOfIDs()
@@ -117,6 +119,11 @@ func (c *GTSCaches) AccountNote() *result.Cache[*gtsmodel.AccountNote] {
return c.accountNote
}
+// Application provides access to the gtsmodel Application database cache.
+func (c *GTSCaches) Application() *result.Cache[*gtsmodel.Application] {
+ return c.application
+}
+
// Block provides access to the gtsmodel Block (account) database cache.
func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block
@@ -303,6 +310,26 @@ func (c *GTSCaches) initAccountNote() {
c.accountNote.IgnoreErrors(ignoreErrors)
}
+func (c *GTSCaches) initApplication() {
+ // Calculate maximum cache size.
+ cap := calculateResultCacheMax(
+ sizeofApplication(), // model in-mem size.
+ config.GetCacheApplicationMemRatio(),
+ )
+ log.Infof(nil, "Application cache size = %d", cap)
+
+ c.application = result.New([]result.Lookup{
+ {Name: "ID"},
+ {Name: "ClientID"},
+ }, func(a1 *gtsmodel.Application) *gtsmodel.Application {
+ a2 := new(gtsmodel.Application)
+ *a2 = *a1
+ return a2
+ }, cap)
+
+ c.application.IgnoreErrors(ignoreErrors)
+}
+
func (c *GTSCaches) initBlock() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
diff --git a/internal/cache/size.go b/internal/cache/size.go
index ec7c554c0..34586b0b1 100644
--- a/internal/cache/size.go
+++ b/internal/cache/size.go
@@ -155,6 +155,7 @@ func totalOfRatios() float64 {
return 0 +
config.GetCacheAccountMemRatio() +
config.GetCacheAccountNoteMemRatio() +
+ config.GetCacheApplicationMemRatio() +
config.GetCacheBlockMemRatio() +
config.GetCacheBlockIDsMemRatio() +
config.GetCacheBoostOfIDsMemRatio() +
@@ -217,7 +218,7 @@ func sizeofAccount() uintptr {
SilencedAt: time.Now(),
SuspendedAt: time.Now(),
HideCollections: func() *bool { ok := true; return &ok }(),
- SuspensionOrigin: "",
+ SuspensionOrigin: exampleID,
EnableRSS: func() *bool { ok := true; return &ok }(),
}))
}
@@ -231,6 +232,20 @@ func sizeofAccountNote() uintptr {
}))
}
+func sizeofApplication() uintptr {
+ return uintptr(size.Of(>smodel.Application{
+ ID: exampleID,
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ Name: exampleUsername,
+ Website: exampleURI,
+ RedirectURI: exampleURI,
+ ClientID: exampleID,
+ ClientSecret: exampleID,
+ Scopes: exampleTextSmall,
+ }))
+}
+
func sizeofBlock() uintptr {
return uintptr(size.Of(>smodel.Block{
ID: exampleID,
@@ -500,5 +515,31 @@ func sizeofVisibility() uintptr {
}
func sizeofUser() uintptr {
- return uintptr(size.Of(>smodel.User{}))
+ return uintptr(size.Of(>smodel.User{
+ ID: exampleID,
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ Email: exampleURI,
+ AccountID: exampleID,
+ EncryptedPassword: exampleTextSmall,
+ CurrentSignInAt: time.Now(),
+ LastSignInAt: time.Now(),
+ InviteID: exampleID,
+ ChosenLanguages: []string{"en", "fr", "jp"},
+ FilteredLanguages: []string{"en", "fr", "jp"},
+ Locale: "en",
+ CreatedByApplicationID: exampleID,
+ LastEmailedAt: time.Now(),
+ ConfirmationToken: exampleTextSmall,
+ ConfirmationSentAt: time.Now(),
+ ConfirmedAt: time.Now(),
+ UnconfirmedEmail: exampleURI,
+ Moderator: func() *bool { ok := true; return &ok }(),
+ Admin: func() *bool { ok := true; return &ok }(),
+ Disabled: func() *bool { ok := true; return &ok }(),
+ Approved: func() *bool { ok := true; return &ok }(),
+ ResetPasswordToken: exampleTextSmall,
+ ResetPasswordSentAt: time.Now(),
+ ExternalID: exampleID,
+ }))
}
diff --git a/internal/config/config.go b/internal/config/config.go
index ef79d4e12..5a26222ed 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -178,6 +178,7 @@ type CacheConfiguration struct {
MemoryTarget bytesize.Size `name:"memory-target"`
AccountMemRatio float64 `name:"account-mem-ratio"`
AccountNoteMemRatio float64 `name:"account-note-mem-ratio"`
+ ApplicationMemRatio float64 `name:"application-mem-ratio"`
BlockMemRatio float64 `name:"block-mem-ratio"`
BlockIDsMemRatio float64 `name:"block-mem-ratio"`
BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"`
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index 2bc95f6f1..b78362973 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -147,6 +147,7 @@
// be able to make some more sense :D
AccountMemRatio: 18,
AccountNoteMemRatio: 0.1,
+ ApplicationMemRatio: 0.1,
BlockMemRatio: 3,
BlockIDsMemRatio: 3,
BoostOfIDsMemRatio: 3,
@@ -170,7 +171,7 @@
StatusFaveIDsMemRatio: 3,
TagMemRatio: 3,
TombstoneMemRatio: 2,
- UserMemRatio: 0.1,
+ UserMemRatio: 0.25,
WebfingerMemRatio: 0.1,
VisibilityMemRatio: 2,
},
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index 0a299e7d0..03411853f 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -2499,6 +2499,31 @@ func GetCacheAccountNoteMemRatio() float64 { return global.GetCacheAccountNoteMe
// SetCacheAccountNoteMemRatio safely sets the value for global configuration 'Cache.AccountNoteMemRatio' field
func SetCacheAccountNoteMemRatio(v float64) { global.SetCacheAccountNoteMemRatio(v) }
+// GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field
+func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) {
+ st.mutex.RLock()
+ v = st.config.Cache.ApplicationMemRatio
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheApplicationMemRatio safely sets the Configuration value for state's 'Cache.ApplicationMemRatio' field
+func (st *ConfigState) SetCacheApplicationMemRatio(v float64) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.ApplicationMemRatio = v
+ st.reloadToViper()
+}
+
+// CacheApplicationMemRatioFlag returns the flag name for the 'Cache.ApplicationMemRatio' field
+func CacheApplicationMemRatioFlag() string { return "cache-application-mem-ratio" }
+
+// GetCacheApplicationMemRatio safely fetches the value for global configuration 'Cache.ApplicationMemRatio' field
+func GetCacheApplicationMemRatio() float64 { return global.GetCacheApplicationMemRatio() }
+
+// SetCacheApplicationMemRatio safely sets the value for global configuration 'Cache.ApplicationMemRatio' field
+func SetCacheApplicationMemRatio(v float64) { global.SetCacheApplicationMemRatio(v) }
+
// GetCacheBlockMemRatio safely fetches the Configuration value for state's 'Cache.BlockMemRatio' field
func (st *ConfigState) GetCacheBlockMemRatio() (v float64) {
st.mutex.RLock()
diff --git a/internal/db/application.go b/internal/db/application.go
new file mode 100644
index 000000000..34a857d3f
--- /dev/null
+++ b/internal/db/application.go
@@ -0,0 +1,38 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type Application interface {
+ // GetApplicationByID fetches the application from the database with corresponding ID value.
+ GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error)
+
+ // GetApplicationByClientID fetches the application from the database with corresponding client_id value.
+ GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error)
+
+ // PutApplication places the new application in the database, erroring on non-unique ID or client_id.
+ PutApplication(ctx context.Context, app *gtsmodel.Application) error
+
+ // DeleteApplicationByClientID deletes the application with corresponding client_id value from the database.
+ DeleteApplicationByClientID(ctx context.Context, clientID string) error
+}
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
new file mode 100644
index 000000000..b53d2c0b0
--- /dev/null
+++ b/internal/db/bundb/application.go
@@ -0,0 +1,97 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type applicationDB struct {
+ db *WrappedDB
+ state *state.State
+}
+
+func (a *applicationDB) GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) {
+ return a.getApplication(
+ ctx,
+ "ID",
+ func(app *gtsmodel.Application) error {
+ return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("id"), id).Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) {
+ return a.getApplication(
+ ctx,
+ "ClientID",
+ func(app *gtsmodel.Application) error {
+ return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("client_id"), clientID).Scan(ctx)
+ },
+ clientID,
+ )
+}
+
+func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) {
+ return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) {
+ var app gtsmodel.Application
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&app); err != nil {
+ return nil, a.db.ProcessError(err)
+ }
+
+ return &app, nil
+ }, keyParts...)
+}
+
+func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error {
+ return a.state.Caches.GTS.Application().Store(app, func() error {
+ _, err := a.db.NewInsert().Model(app).Exec(ctx)
+ return a.db.ProcessError(err)
+ })
+}
+
+func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientID string) error {
+ // Attempt to delete application.
+ if _, err := a.db.NewDelete().
+ Table("applications").
+ Where("? = ?", bun.Ident("client_id"), clientID).
+ Exec(ctx); err != nil {
+ return a.db.ProcessError(err)
+ }
+
+ // NOTE about further side effects:
+ //
+ // We don't need to handle updating any statuses or users
+ // (both of which may contain refs to applications), as
+ // DeleteApplication__() is only ever called during an
+ // account deletion, which handles deletion of the user
+ // and all their statuses already.
+ //
+
+ // Clear application from the cache.
+ a.state.Caches.GTS.Application().Invalidate("ClientID", clientID)
+
+ return nil
+}
diff --git a/internal/db/bundb/application_test.go b/internal/db/bundb/application_test.go
new file mode 100644
index 000000000..d2ab05ebd
--- /dev/null
+++ b/internal/db/bundb/application_test.go
@@ -0,0 +1,128 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type ApplicationTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *ApplicationTestSuite) TestGetApplicationBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 application models are equal.
+ isEqual := func(a1, a2 gtsmodel.Application) bool {
+ // Clear database-set fields.
+ a1.CreatedAt = time.Time{}
+ a2.CreatedAt = time.Time{}
+ a1.UpdatedAt = time.Time{}
+ a2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(a1, a2)
+ }
+
+ for _, app := range suite.testApplications {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Application, error){
+ "id": func() (*gtsmodel.Application, error) {
+ return suite.db.GetApplicationByID(ctx, app.ID)
+ },
+
+ "client_id": func() (*gtsmodel.Application, error) {
+ return suite.db.GetApplicationByClientID(ctx, app.ClientID)
+ },
+ } {
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkApp, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received application data.
+ if !isEqual(*checkApp, *app) {
+ t.Errorf("application does not contain expected data: %+v", checkApp)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *ApplicationTestSuite) TestDeleteApplicationBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, app := range suite.testApplications {
+ for lookup, dbfunc := range map[string]func() error{
+ "client_id": func() error {
+ return suite.db.DeleteApplicationByClientID(ctx, app.ClientID)
+ },
+ } {
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ err := dbfunc()
+ if err != nil {
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Ensure this application has been deleted and cache cleared.
+ if _, err := suite.db.GetApplicationByID(ctx, app.ID); err != db.ErrNoEntries {
+ t.Errorf("application does not appear to have been deleted %q: %v", lookup, err)
+ continue
+ }
+ }
+ }
+}
+
+func TestApplicationTestSuite(t *testing.T) {
+ suite.Run(t, new(ApplicationTestSuite))
+}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 8387bb8d1..26b31ff28 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -60,6 +60,7 @@
type DBService struct {
db.Account
db.Admin
+ db.Application
db.Basic
db.Domain
db.Emoji
@@ -168,6 +169,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ Application: &applicationDB{
+ db: db,
+ state: state,
+ },
Basic: &basicDB{
db: db,
},
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index c6091e2c9..311732299 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -37,19 +37,12 @@ type statusDB struct {
state *state.State
}
-func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
- return s.db.
- NewSelect().
- Model(status).
- Relation("CreatedWithApplication")
-}
-
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
"ID",
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
+ return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
},
id,
)
@@ -78,7 +71,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
ctx,
"URI",
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
+ return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
},
uri,
)
@@ -89,7 +82,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
ctx,
"URL",
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
+ return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
},
url,
)
@@ -100,7 +93,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou
ctx,
"BoostOfID.AccountID",
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).
+ return s.db.NewSelect().Model(status).
Where("status.boost_of_id = ?", boostOfID).
Where("status.account_id = ?", byAccountID).
@@ -264,6 +257,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
}
+ if status.CreatedWithApplicationID != "" && status.CreatedWithApplication == nil {
+ // Populate the status' expected CreatedWithApplication (not always set).
+ status.CreatedWithApplication, err = s.state.DB.GetApplicationByID(
+ ctx, // these are already barebones
+ status.CreatedWithApplicationID,
+ )
+ if err != nil {
+ errs.Appendf("error populating status application: %w", err)
+ }
+ }
+
return errs.Combine()
}
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index 4b38d48fa..9df05596e 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -24,6 +24,7 @@
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
@@ -35,107 +36,125 @@ type userDB struct {
}
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) {
- return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) {
- var user gtsmodel.User
+ return u.getUser(
+ ctx,
+ "ID",
+ func(user *gtsmodel.User) error {
+ return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("id"), id).Scan(ctx)
+ },
+ id,
+ )
+}
- q := u.db.
- NewSelect().
- Model(&user).
- Relation("Account").
- Where("? = ?", bun.Ident("user.id"), id)
+func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) {
+ var (
+ users = make([]*gtsmodel.User, 0, len(ids))
- if err := q.Scan(ctx); err != nil {
- return nil, u.db.ProcessError(err)
+ // Collect errors instead of
+ // returning early on any.
+ errs gtserror.MultiError
+ )
+
+ for _, id := range ids {
+ // Attempt to fetch user from DB.
+ user, err := u.GetUserByID(ctx, id)
+ if err != nil {
+ errs.Appendf("error getting user %s: %w", id, err)
+ continue
}
- return &user, nil
- }, id)
+ // Append user to return slice.
+ users = append(users, user)
+ }
+
+ return users, errs.Combine()
}
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) {
- return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) {
- var user gtsmodel.User
-
- q := u.db.
- NewSelect().
- Model(&user).
- Relation("Account").
- Where("? = ?", bun.Ident("user.account_id"), accountID)
-
- if err := q.Scan(ctx); err != nil {
- return nil, u.db.ProcessError(err)
- }
-
- return &user, nil
- }, accountID)
+ return u.getUser(
+ ctx,
+ "AccountID",
+ func(user *gtsmodel.User) error {
+ return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("account_id"), accountID).Scan(ctx)
+ },
+ accountID,
+ )
}
-func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) {
- return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) {
- var user gtsmodel.User
-
- q := u.db.
- NewSelect().
- Model(&user).
- Relation("Account").
- Where("? = ?", bun.Ident("user.email"), emailAddress)
-
- if err := q.Scan(ctx); err != nil {
- return nil, u.db.ProcessError(err)
- }
-
- return &user, nil
- }, emailAddress)
+func (u *userDB) GetUserByEmailAddress(ctx context.Context, email string) (*gtsmodel.User, error) {
+ return u.getUser(
+ ctx,
+ "Email",
+ func(user *gtsmodel.User) error {
+ return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("email"), email).Scan(ctx)
+ },
+ email,
+ )
}
func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) {
- return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) {
- var user gtsmodel.User
-
- q := u.db.
- NewSelect().
- Model(&user).
- Relation("Account").
- Where("? = ?", bun.Ident("user.external_id"), id)
-
- if err := q.Scan(ctx); err != nil {
- return nil, u.db.ProcessError(err)
- }
-
- return &user, nil
- }, id)
+ return u.getUser(
+ ctx,
+ "ExternalID",
+ func(user *gtsmodel.User) error {
+ return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("external_id"), id).Scan(ctx)
+ },
+ id,
+ )
}
-func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) {
- return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) {
+func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (*gtsmodel.User, error) {
+ return u.getUser(
+ ctx,
+ "ConfirmationToken",
+ func(user *gtsmodel.User) error {
+ return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("confirmation_token"), token).Scan(ctx)
+ },
+ token,
+ )
+}
+
+func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {
+ // Fetch user from database cache with loader callback.
+ user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) {
var user gtsmodel.User
- q := u.db.
- NewSelect().
- Model(&user).
- Relation("Account").
- Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
-
- if err := q.Scan(ctx); err != nil {
+ // Not cached! perform database query.
+ if err := dbQuery(&user); err != nil {
return nil, u.db.ProcessError(err)
}
return &user, nil
- }, confirmationToken)
+ }, keyParts...)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch the related account model for this user.
+ user.Account, err = u.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ user.AccountID,
+ )
+ if err != nil {
+ return nil, gtserror.Newf("error populating user account: %w", err)
+ }
+
+ return user, nil
}
func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
- var users []*gtsmodel.User
- q := u.db.
- NewSelect().
- Model(&users).
- Relation("Account")
+ var userIDs []string
- if err := q.Scan(ctx); err != nil {
+ // Scan all user IDs into slice.
+ if err := u.db.NewSelect().
+ Table("users").
+ Column("id").
+ Scan(ctx, &userIDs); err != nil {
return nil, u.db.ProcessError(err)
}
- return users, nil
+ // Transform user IDs into user slice.
+ return u.GetUsersByIDs(ctx, userIDs)
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
diff --git a/internal/db/db.go b/internal/db/db.go
index 7c00050ff..567551c73 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -26,6 +26,7 @@
type DB interface {
Account
Admin
+ Application
Basic
Domain
Emoji
diff --git a/internal/middleware/tokencheck.go b/internal/middleware/tokencheck.go
index 1496363af..d2570c3f0 100644
--- a/internal/middleware/tokencheck.go
+++ b/internal/middleware/tokencheck.go
@@ -22,7 +22,6 @@
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/oauth2/v4"
@@ -125,8 +124,8 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2.
log.Tracef(ctx, "authenticated client %s with bearer token, scope is %s", clientID, ti.GetScope())
// fetch app for this token
- app := >smodel.Application{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "client_id", Value: clientID}}, app); err != nil {
+ app, err := dbConn.GetApplicationByClientID(ctx, clientID)
+ if err != nil {
if err != db.ErrNoEntries {
log.Errorf(ctx, "database error looking for application with clientID %s: %s", clientID, err)
return
@@ -134,6 +133,7 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2.
log.Warnf(ctx, "no app found for client %s", clientID)
return
}
+
c.Set(oauth.SessionAuthorizedApplication, app)
}
}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index d2483aeb1..da13eb20e 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -46,12 +46,6 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
}...)
l.Trace("beginning account delete process")
- if account.IsLocal() {
- if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil {
- return gtserror.NewErrorInternalError(err)
- }
- }
-
if err := p.deleteAccountFollows(ctx, account); err != nil {
return gtserror.NewErrorInternalError(err)
}
@@ -72,6 +66,14 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
return gtserror.NewErrorInternalError(err)
}
+ if account.IsLocal() {
+ // we tokens, applications and clients for account as one of the last
+ // stages during deletion, as other database models rely on these.
+ if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil {
+ return gtserror.NewErrorInternalError(err)
+ }
+ }
+
// To prevent the account being created again,
// stubbify it and update it in the db.
// The account will not be deleted, but it
@@ -129,7 +131,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
}
// Delete any OAuth applications associated with this token.
- if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil {
+ if err := p.state.DB.DeleteApplicationByClientID(ctx, t.ClientID); err != nil {
return gtserror.Newf("db error deleting application: %w", err)
}
@@ -305,7 +307,17 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel
statusLoop:
for {
// Page through account's statuses.
- statuses, err = p.state.DB.GetAccountStatuses(ctx, account.ID, deleteSelectLimit, false, false, maxID, "", false, false)
+ statuses, err = p.state.DB.GetAccountStatuses(
+ ctx,
+ account.ID,
+ deleteSelectLimit,
+ false,
+ false,
+ maxID,
+ "",
+ false,
+ false,
+ )
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Make sure we don't have a real error.
return err
diff --git a/internal/processing/app.go b/internal/processing/app.go
index 07739ce92..d4a923e8a 100644
--- a/internal/processing/app.go
+++ b/internal/processing/app.go
@@ -61,7 +61,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
}
// chuck it in the db
- if err := p.state.DB.Put(ctx, app); err != nil {
+ if err := p.state.DB.PutApplication(ctx, app); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go
index 2dc0e4dd5..ab04f6ccc 100644
--- a/internal/typeutils/internaltofrontend.go
+++ b/internal/typeutils/internaltofrontend.go
@@ -699,8 +699,8 @@ func (c *converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r
}
if appID := s.CreatedWithApplicationID; appID != "" {
- app := >smodel.Application{}
- if err := c.db.GetByID(ctx, appID, app); err != nil {
+ app, err := c.db.GetApplicationByID(ctx, appID)
+ if err != nil {
return nil, fmt.Errorf("error getting application %s: %w", appID, err)
}
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 59c69e1b5..d9270e7f6 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -20,6 +20,7 @@ EXPECT=$(cat << "EOF"
"cache": {
"account-mem-ratio": 18,
"account-note-mem-ratio": 0.1,
+ "application-mem-ratio": 0.1,
"block-mem-ratio": 3,
"boost-of-ids-mem-ratio": 3,
"emoji-category-mem-ratio": 0.1,
@@ -43,7 +44,7 @@ EXPECT=$(cat << "EOF"
"status-mem-ratio": 18,
"tag-mem-ratio": 3,
"tombstone-mem-ratio": 2,
- "user-mem-ratio": 0.1,
+ "user-mem-ratio": 0.25,
"visibility-mem-ratio": 2,
"webfinger-mem-ratio": 0.1
},