diff --git a/cmd/gotosocial/action/admin/account/account.go b/cmd/gotosocial/action/admin/account/account.go
index c5ac1073e..f2cce57b5 100644
--- a/cmd/gotosocial/action/admin/account/account.go
+++ b/cmd/gotosocial/action/admin/account/account.go
@@ -26,9 +26,7 @@
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/validate"
"golang.org/x/crypto/bcrypt"
)
@@ -92,8 +90,8 @@
return err
}
- u := >smodel.User{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ u, err := dbConn.GetUserByAccountID(ctx, a.ID)
+ if err != nil {
return err
}
@@ -130,8 +128,8 @@
return err
}
- u := >smodel.User{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ u, err := dbConn.GetUserByAccountID(ctx, a.ID)
+ if err != nil {
return err
}
@@ -139,7 +137,7 @@
admin := true
u.Admin = &admin
u.UpdatedAt = time.Now()
- if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
+ if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}
@@ -166,8 +164,8 @@
return err
}
- u := >smodel.User{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ u, err := dbConn.GetUserByAccountID(ctx, a.ID)
+ if err != nil {
return err
}
@@ -175,7 +173,7 @@
admin := false
u.Admin = &admin
u.UpdatedAt = time.Now()
- if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
+ if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}
@@ -202,8 +200,8 @@
return err
}
- u := >smodel.User{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ u, err := dbConn.GetUserByAccountID(ctx, a.ID)
+ if err != nil {
return err
}
@@ -211,7 +209,7 @@
disabled := true
u.Disabled = &disabled
u.UpdatedAt = time.Now()
- if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
+ if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}
@@ -252,8 +250,8 @@
return err
}
- u := >smodel.User{}
- if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ u, err := dbConn.GetUserByAccountID(ctx, a.ID)
+ if err != nil {
return err
}
@@ -265,7 +263,7 @@
updatingColumns := []string{"encrypted_password", "updated_at"}
u.EncryptedPassword = string(pw)
u.UpdatedAt = time.Now()
- if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
+ if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index 83cddd9b5..b345f9b01 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}
- user := >smodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
@@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return
}
- user := >smodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go
index eab893416..fcc4b8caa 100644
--- a/internal/api/client/auth/authorize_test.go
+++ b/internal/api/client/auth/authorize_test.go
@@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
- user := suite.testUsers["unconfirmed_account"]
- account := suite.testAccounts["unconfirmed_account"]
+ user := >smodel.User{}
+ account := >smodel.Account{}
+
+ *user = *suite.testUsers["unconfirmed_account"]
+ *account = *suite.testAccounts["unconfirmed_account"]
testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
@@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
updatingColumns = append(updatingColumns, "updated_at")
- user.UpdatedAt = time.Now()
- err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
+ _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index 96a73a52f..daee2ae31 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i
// see if we already have a user for this email address
// if so, we don't need to continue + create one
- user := >smodel.User{}
- err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
+ user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
if err == nil {
return user, nil
}
diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go
index 58f3fad7e..06b601b10 100644
--- a/internal/api/client/auth/signin.go
+++ b/internal/api/client/auth/signin.go
@@ -28,9 +28,7 @@
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt"
)
@@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err)
}
- user := >smodel.User{}
- if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
+ user, err := m.db.GetUserByEmailAddress(ctx, email)
+ if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err)
}
diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go
index 3df7ee943..9f2b7f36e 100644
--- a/internal/api/security/tokencheck.go
+++ b/internal/api/security/tokencheck.go
@@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())
// fetch user for this token
- user := >smodel.User{}
- if err := m.db.GetByID(ctx, userID, user); err != nil {
+ user, err := m.db.GetUserByID(ctx, userID)
+ if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for user with id %s: %s", userID, err)
return
@@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)
// fetch account for this token
- acct, err := m.db.GetAccountByID(ctx, user.AccountID)
- if err != nil {
- if err != db.ErrNoEntries {
- log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
+ if user.Account == nil {
+ acct, err := m.db.GetAccountByID(ctx, user.AccountID)
+ if err != nil {
+ if err != db.ErrNoEntries {
+ log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
+ return
+ }
+ log.Warnf("no account found for userID %s", userID)
return
}
- log.Warnf("no account found for userID %s", userID)
- return
+ user.Account = acct
}
- if !acct.SuspendedAt.IsZero() {
+ if !user.Account.SuspendedAt.IsZero() {
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return
}
- c.Set(oauth.SessionAuthorizedAccount, acct)
+ c.Set(oauth.SessionAuthorizedAccount, user.Account)
}
// check for application token
diff --git a/internal/cache/user.go b/internal/cache/user.go
new file mode 100644
index 000000000..23bf0b7e9
--- /dev/null
+++ b/internal/cache/user.go
@@ -0,0 +1,141 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package cache
+
+import (
+ "time"
+
+ "codeberg.org/gruf/go-cache/v2"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// UserCache is a cache wrapper to provide lookups for gtsmodel.User
+type UserCache struct {
+ cache cache.LookupCache[string, string, *gtsmodel.User]
+}
+
+// NewUserCache returns a new instantiated UserCache object
+func NewUserCache() *UserCache {
+ c := &UserCache{}
+ c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
+ RegisterLookups: func(lm *cache.LookupMap[string, string]) {
+ lm.RegisterLookup("accountid")
+ lm.RegisterLookup("email")
+ lm.RegisterLookup("unconfirmedemail")
+ lm.RegisterLookup("confirmationtoken")
+ },
+
+ AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
+ lm.Set("accountid", user.AccountID, user.ID)
+ if email := user.Email; email != "" {
+ lm.Set("email", email, user.ID)
+ }
+ if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
+ lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
+ }
+ if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
+ lm.Set("confirmationtoken", confirmationToken, user.ID)
+ }
+ },
+
+ DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
+ lm.Delete("accountid", user.AccountID)
+ if email := user.Email; email != "" {
+ lm.Delete("email", email)
+ }
+ if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
+ lm.Delete("unconfirmedemail", unconfirmedEmail)
+ }
+ if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
+ lm.Delete("confirmationtoken", confirmationToken)
+ }
+ },
+ })
+ c.cache.SetTTL(time.Minute*5, false)
+ c.cache.Start(time.Second * 10)
+ return c
+}
+
+// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
+func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
+ return c.cache.Get(id)
+}
+
+// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
+func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
+ return c.cache.GetBy("accountid", accountID)
+}
+
+// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
+func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
+ return c.cache.GetBy("email", email)
+}
+
+// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
+func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
+ return c.cache.GetBy("confirmationtoken", token)
+}
+
+// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
+func (c *UserCache) Put(user *gtsmodel.User) {
+ if user == nil || user.ID == "" {
+ panic("invalid user")
+ }
+ c.cache.Set(user.ID, copyUser(user))
+}
+
+// Invalidate invalidates one user from the cache using the ID of the user as key.
+func (c *UserCache) Invalidate(userID string) {
+ c.cache.Invalidate(userID)
+}
+
+func copyUser(user *gtsmodel.User) *gtsmodel.User {
+ return >smodel.User{
+ ID: user.ID,
+ CreatedAt: user.CreatedAt,
+ UpdatedAt: user.UpdatedAt,
+ Email: user.Email,
+ AccountID: user.AccountID,
+ Account: nil,
+ EncryptedPassword: user.EncryptedPassword,
+ SignUpIP: user.SignUpIP,
+ CurrentSignInAt: user.CurrentSignInAt,
+ CurrentSignInIP: user.CurrentSignInIP,
+ LastSignInAt: user.LastSignInAt,
+ LastSignInIP: user.LastSignInIP,
+ SignInCount: user.SignInCount,
+ InviteID: user.InviteID,
+ ChosenLanguages: user.ChosenLanguages,
+ FilteredLanguages: user.FilteredLanguages,
+ Locale: user.Locale,
+ CreatedByApplicationID: user.CreatedByApplicationID,
+ CreatedByApplication: nil,
+ LastEmailedAt: user.LastEmailedAt,
+ ConfirmationToken: user.ConfirmationToken,
+ ConfirmationSentAt: user.ConfirmationSentAt,
+ ConfirmedAt: user.ConfirmedAt,
+ UnconfirmedEmail: user.UnconfirmedEmail,
+ Moderator: copyBoolPtr(user.Moderator),
+ Admin: copyBoolPtr(user.Admin),
+ Disabled: copyBoolPtr(user.Disabled),
+ Approved: copyBoolPtr(user.Approved),
+ ResetPasswordToken: user.ResetPasswordToken,
+ ResetPasswordSentAt: user.ResetPasswordSentAt,
+ }
+}
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index f66ed0294..9fa78eca0 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -30,6 +30,7 @@
"time"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -40,7 +41,8 @@
)
type adminDB struct {
- conn *DBConn
+ conn *DBConn
+ userCache *cache.UserCache
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
+ a.userCache.Put(u)
return u, nil
}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 1579fae76..70a44d4c1 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -87,6 +87,7 @@ type DBService struct {
db.Session
db.Status
db.Timeline
+ db.User
conn *DBConn
}
@@ -181,13 +182,15 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
notifCache.SetTTL(time.Minute*5, false)
notifCache.Start(time.Second * 10)
- // Prepare domain block cache
+ // Prepare other caches
blockCache := cache.NewDomainBlockCache()
+ userCache := cache.NewUserCache()
ps := &DBService{
Account: accounts,
Admin: &adminDB{
- conn: conn,
+ conn: conn,
+ userCache: userCache,
},
Basic: &basicDB{
conn: conn,
@@ -219,7 +222,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
},
Status: status,
Timeline: timeline,
- conn: conn,
+ User: &userDB{
+ conn: conn,
+ cache: userCache,
+ },
+ conn: conn,
}
// we can confidently return this useable service now
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
new file mode 100644
index 000000000..46f24c4b2
--- /dev/null
+++ b/internal/db/bundb/user.go
@@ -0,0 +1,151 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type userDB struct {
+ conn *DBConn
+ cache *cache.UserCache
+}
+
+func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery {
+ return u.conn.
+ NewSelect().
+ Model(user).
+ Relation("Account")
+}
+
+func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) {
+ // Attempt to fetch cached user
+ user, cached := cacheGet()
+
+ if !cached {
+ user = >smodel.User{}
+
+ // Not cached! Perform database query
+ err := dbQuery(user)
+ if err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ // Place in the cache
+ u.cache.Put(user)
+ }
+
+ return user, nil
+}
+
+func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
+ return u.getUser(
+ ctx,
+ func() (*gtsmodel.User, bool) {
+ return u.cache.GetByID(id)
+ },
+ func(user *gtsmodel.User) error {
+ return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
+ },
+ )
+}
+
+func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
+ return u.getUser(
+ ctx,
+ func() (*gtsmodel.User, bool) {
+ return u.cache.GetByAccountID(accountID)
+ },
+ func(user *gtsmodel.User) error {
+ return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
+ },
+ )
+}
+
+func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
+ return u.getUser(
+ ctx,
+ func() (*gtsmodel.User, bool) {
+ return u.cache.GetByEmail(emailAddress)
+ },
+ func(user *gtsmodel.User) error {
+ return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
+ },
+ )
+}
+
+func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
+ return u.getUser(
+ ctx,
+ func() (*gtsmodel.User, bool) {
+ return u.cache.GetByConfirmationToken(confirmationToken)
+ },
+ func(user *gtsmodel.User) error {
+ return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
+ },
+ )
+}
+
+func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) {
+ if _, err := u.conn.
+ NewInsert().
+ Model(user).
+ Exec(ctx); err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ u.cache.Put(user)
+ return user, nil
+}
+
+func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) {
+ // Update the user's last-updated
+ user.UpdatedAt = time.Now()
+
+ if _, err := u.conn.
+ NewUpdate().
+ Model(user).
+ WherePK().
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ u.cache.Invalidate(user.ID)
+ return user, nil
+}
+
+func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
+ if _, err := u.conn.
+ NewDelete().
+ Model(>smodel.User{ID: userID}).
+ WherePK().
+ Exec(ctx); err != nil {
+ return u.conn.ProcessError(err)
+ }
+
+ u.cache.Invalidate(userID)
+ return nil
+}
diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go
new file mode 100644
index 000000000..6ad59fc8e
--- /dev/null
+++ b/internal/db/bundb/user_test.go
@@ -0,0 +1,73 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type UserTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *UserTestSuite) TestGetUser() {
+ user, err := suite.db.GetUserByID(context.Background(), suite.testUsers["local_account_1"].ID)
+ suite.NoError(err)
+ suite.NotNil(user)
+}
+
+func (suite *UserTestSuite) TestGetUserByEmailAddress() {
+ user, err := suite.db.GetUserByEmailAddress(context.Background(), suite.testUsers["local_account_1"].Email)
+ suite.NoError(err)
+ suite.NotNil(user)
+}
+
+func (suite *UserTestSuite) TestGetUserByAccountID() {
+ user, err := suite.db.GetUserByAccountID(context.Background(), suite.testAccounts["local_account_1"].ID)
+ suite.NoError(err)
+ suite.NotNil(user)
+}
+
+func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {
+ testUser := suite.testUsers["local_account_1"]
+ user := >smodel.User{
+ ID: testUser.ID,
+ Email: "whatever",
+ Locale: "es",
+ }
+
+ user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale")
+ suite.NoError(err)
+ suite.NotNil(user)
+
+ dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)
+ suite.NoError(err)
+ suite.NotNil(dbUser)
+ suite.Equal("whatever", dbUser.Email)
+ suite.Equal("es", dbUser.Locale)
+ suite.Equal(testUser.AccountID, dbUser.AccountID)
+}
+
+func TestUserTestSuite(t *testing.T) {
+ suite.Run(t, new(UserTestSuite))
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index 0c1f2602a..52a76ecdb 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -44,6 +44,7 @@ type DB interface {
Session
Status
Timeline
+ User
/*
USEFUL CONVERSION FUNCTIONS
diff --git a/internal/db/user.go b/internal/db/user.go
new file mode 100644
index 000000000..a4d48db56
--- /dev/null
+++ b/internal/db/user.go
@@ -0,0 +1,42 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// User contains functions related to user getting/setting/creation.
+type User interface {
+ // GetUserByID returns one user with the given ID, or an error if something goes wrong.
+ GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error)
+ // GetUserByAccountID returns one user by its account ID, or an error if something goes wrong.
+ GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error)
+ // GetUserByID returns one user with the given email address, or an error if something goes wrong.
+ GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
+ // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
+ GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
+ // UpdateUser updates one user by its primary key. If columns is set, only given columns
+ // will be updated. If not set, all columns will be updated.
+ UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error)
+ // DeleteUserByID deletes one user by its ID.
+ DeleteUserByID(ctx context.Context, userID string) Error
+}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index 3a5a9c622..3758a4000 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -70,13 +70,14 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 1. Delete account's application(s), clients, and oauth tokens
// we only need to do this step for local account since remote ones won't have any tokens or applications on our server
+ var user *gtsmodel.User
if account.Domain == "" {
// see if we can get a user for this account
- u := >smodel.User{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
+ var err error
+ if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {
// we got one! select all tokens with the user's ID
tokens := []*gtsmodel.Token{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
// we have some tokens to delete
for _, t := range tokens {
// delete client(s) associated with this token
@@ -240,9 +241,11 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO
// 16. Delete account's user
- l.Debug("deleting account user")
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil {
- return gtserror.NewErrorInternalError(err)
+ if user != nil {
+ l.Debug("deleting account user")
+ if err := p.db.DeleteUserByID(ctx, user.ID); err != nil {
+ return gtserror.NewErrorInternalError(err)
+ }
}
// 17. Delete account's timeline
@@ -288,8 +291,8 @@ func (p *processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
if form.DeleteOriginID == account.ID {
// the account owner themself has requested deletion via the API, get their user from the db
- user := >smodel.User{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil {
+ user, err := p.db.GetUserByAccountID(ctx, account.ID)
+ if err != nil {
return gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go
index d7c9c5d82..a688e3732 100644
--- a/internal/processing/fromclientapi.go
+++ b/internal/processing/fromclientapi.go
@@ -29,7 +29,6 @@
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
@@ -138,8 +137,8 @@ func (p *processor) processCreateAccountFromClientAPI(ctx context.Context, clien
}
// get the user this account belongs to
- user := >smodel.User{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil {
+ user, err := p.db.GetUserByAccountID(ctx, account.ID)
+ if err != nil {
return err
}
diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go
index 9337482c4..22d0ba9f4 100644
--- a/internal/processing/fromfederator_test.go
+++ b/internal/processing/fromfederator_test.go
@@ -370,7 +370,7 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
// no statuses from foss satan should be left in the database
if !testrig.WaitFor(func() bool {
s, err := suite.db.GetAccountStatuses(ctx, deletedAccount.ID, 0, false, false, "", "", false, false, false)
- return s == nil && err == db.ErrNoEntries
+ return s == nil && err == db.ErrNoEntries
}) {
suite.FailNow("timeout waiting for statuses to be deleted")
}
diff --git a/internal/processing/instance.go b/internal/processing/instance.go
index b7418659a..32a4de6f0 100644
--- a/internal/processing/instance.go
+++ b/internal/processing/instance.go
@@ -142,8 +142,8 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
}
// make sure it has a user associated with it
- contactUser := >smodel.User{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
+ contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID)
+ if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
}
// suspended accounts cannot be contact accounts
diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go
index 70e4741e1..cb152b676 100644
--- a/internal/processing/streaming/authorize.go
+++ b/internal/processing/streaming/authorize.go
@@ -40,8 +40,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s
return nil, gtserror.NewErrorUnauthorized(err)
}
- user := >smodel.User{}
- if err := p.db.GetByID(ctx, uid, user); err != nil {
+ user, err := p.db.GetUserByID(ctx, uid)
+ if err != nil {
if err == db.ErrNoEntries {
err := fmt.Errorf("no user found for validated uid %s", uid)
return nil, gtserror.NewErrorUnauthorized(err)
diff --git a/internal/processing/user/emailconfirm.go b/internal/processing/user/emailconfirm.go
index 6bffce7d9..5a68383b8 100644
--- a/internal/processing/user/emailconfirm.go
+++ b/internal/processing/user/emailconfirm.go
@@ -89,8 +89,8 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U
return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))
}
- user := >smodel.User{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "confirmation_token", Value: token}}, user); err != nil {
+ user, err := p.db.GetUserByConfirmationToken(ctx, token)
+ if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
}
diff --git a/internal/typeutils/internaltofrontend_test.go b/internal/typeutils/internaltofrontend_test.go
index 6028344b4..a13e5255c 100644
--- a/internal/typeutils/internaltofrontend_test.go
+++ b/internal/typeutils/internaltofrontend_test.go
@@ -46,9 +46,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontend() {
func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() {
testAccount := suite.testAccounts["local_account_1"] // take zork for this test
testEmoji := suite.testEmojis["rainbow"]
-
+
testAccount.Emojis = []*gtsmodel.Emoji{testEmoji}
-
+
apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)
suite.NoError(err)
suite.NotNil(apiAccount)
@@ -61,9 +61,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct()
func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiIDs() {
testAccount := suite.testAccounts["local_account_1"] // take zork for this test
testEmoji := suite.testEmojis["rainbow"]
-
+
testAccount.EmojiIDs = []string{testEmoji.ID}
-
+
apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)
suite.NoError(err)
suite.NotNil(apiAccount)
diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go
index 15d8544ad..c62ebb0af 100644
--- a/internal/visibility/statusvisible.go
+++ b/internal/visibility/statusvisible.go
@@ -68,8 +68,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
// if the target user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users
if targetAccount.Domain == "" {
- targetUser := >smodel.User{}
- if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
+ targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID)
+ if err != nil {
l.Debug("target user could not be selected")
if err == db.ErrNoEntries {
return false, nil
@@ -98,8 +98,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
// if the requesting user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users
if requestingAccount.Domain == "" {
- requestingUser := >smodel.User{}
- if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
+ requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID)
+ if err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected")
if err == db.ErrNoEntries {