[chore] update database caching library (#1040)

* convert most of the caches to use result.Cache{}

* add caching of emojis

* fix issues causing failing tests

* update go-cache/v2 instances with v3

* fix getnotification

* add a note about the left-in StatusCreate comment

* update EmojiCategory db access to use new result.Cache{}

* fix possible panic in getstatusparents

* further proof that kim is not stinky
This commit is contained in:
kim 2022-11-15 18:45:15 +00:00 committed by GitHub
parent 9ab60136dd
commit 8598dea98b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 725 additions and 2289 deletions

View file

@ -149,11 +149,10 @@
return err return err
} }
updatingColumns := []string{"admin", "updated_at"}
admin := true admin := true
u.Admin = &admin u.Admin = &admin
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { if err := dbConn.UpdateUser(ctx, u); err != nil {
return err return err
} }
@ -185,11 +184,10 @@
return err return err
} }
updatingColumns := []string{"admin", "updated_at"}
admin := false admin := false
u.Admin = &admin u.Admin = &admin
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { if err := dbConn.UpdateUser(ctx, u); err != nil {
return err return err
} }
@ -221,11 +219,10 @@
return err return err
} }
updatingColumns := []string{"disabled", "updated_at"}
disabled := true disabled := true
u.Disabled = &disabled u.Disabled = &disabled
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { if err := dbConn.UpdateUser(ctx, u); err != nil {
return err return err
} }
@ -270,10 +267,9 @@
return fmt.Errorf("error hashing password: %s", err) return fmt.Errorf("error hashing password: %s", err)
} }
updatingColumns := []string{"encrypted_password", "updated_at"}
u.EncryptedPassword = string(pw) u.EncryptedPassword = string(pw)
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil { if err := dbConn.UpdateUser(ctx, u); err != nil {
return err return err
} }

1
go.mod
View file

@ -5,7 +5,6 @@ go 1.19
require ( require (
codeberg.org/gruf/go-bytesize v1.0.0 codeberg.org/gruf/go-bytesize v1.0.0
codeberg.org/gruf/go-byteutil v1.0.2 codeberg.org/gruf/go-byteutil v1.0.2
codeberg.org/gruf/go-cache/v2 v2.1.4
codeberg.org/gruf/go-cache/v3 v3.1.8 codeberg.org/gruf/go-cache/v3 v3.1.8
codeberg.org/gruf/go-debug v1.2.0 codeberg.org/gruf/go-debug v1.2.0
codeberg.org/gruf/go-errors/v2 v2.0.2 codeberg.org/gruf/go-errors/v2 v2.0.2

2
go.sum
View file

@ -69,8 +69,6 @@ codeberg.org/gruf/go-bytesize v1.0.0/go.mod h1:n/GU8HzL9f3UNp/mUKyr1qVmTlj7+xacp
codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-byteutil v1.0.2 h1:OesVyK5VKWeWdeDR00zRJ+Oy8hjXx1pBhn7WVvcZWVE= codeberg.org/gruf/go-byteutil v1.0.2 h1:OesVyK5VKWeWdeDR00zRJ+Oy8hjXx1pBhn7WVvcZWVE=
codeberg.org/gruf/go-byteutil v1.0.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= codeberg.org/gruf/go-byteutil v1.0.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-cache/v2 v2.1.4 h1:r+6wJiTHZn0qqf+p1VtAjGOgXXJl7s8txhPIwoSMZtI=
codeberg.org/gruf/go-cache/v2 v2.1.4/go.mod h1:j7teiz814lG0PfSfnUs+6HA+2/jcjTAR71Ou3Wbt2Xk=
codeberg.org/gruf/go-cache/v3 v3.1.8 h1:wbUef/QtRstEb7sSpQYHT5CtSFtKkeZr4ZhOTXqOpac= codeberg.org/gruf/go-cache/v3 v3.1.8 h1:wbUef/QtRstEb7sSpQYHT5CtSFtKkeZr4ZhOTXqOpac=
codeberg.org/gruf/go-cache/v3 v3.1.8/go.mod h1:h6im2UVGdrGtNt4IVKARVeoW4kAdok5ts7CbH15UWXs= codeberg.org/gruf/go-cache/v3 v3.1.8/go.mod h1:h6im2UVGdrGtNt4IVKARVeoW4kAdok5ts7CbH15UWXs=
codeberg.org/gruf/go-debug v1.2.0 h1:WBbTMnK1ArFKUmgv04aO2JiC/daTOB8zQGi521qb7OU= codeberg.org/gruf/go-debug v1.2.0 h1:WBbTMnK1ArFKUmgv04aO2JiC/daTOB8zQGi521qb7OU=

View file

@ -20,7 +20,7 @@ type AuthAuthorizeTestSuite struct {
type authorizeHandlerTestCase struct { type authorizeHandlerTestCase struct {
description string description string
mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) []string mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account)
expectedStatusCode int expectedStatusCode int
expectedLocationHeader string expectedLocationHeader string
} }
@ -29,44 +29,40 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
tests := []authorizeHandlerTestCase{ tests := []authorizeHandlerTestCase{
{ {
description: "user has their email unconfirmed", description: "user has their email unconfirmed",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
// nothing to do, weed_lord420 already has their email unconfirmed // nothing to do, weed_lord420 already has their email unconfirmed
return nil
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.CheckYourEmailPath, expectedLocationHeader: auth.CheckYourEmailPath,
}, },
{ {
description: "user has their email confirmed but is not approved", description: "user has their email confirmed but is not approved",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now() user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail user.Email = user.UnconfirmedEmail
return []string{"confirmed_at", "email"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.WaitForApprovalPath, expectedLocationHeader: auth.WaitForApprovalPath,
}, },
{ {
description: "user has their email confirmed and is approved, but User entity has been disabled", description: "user has their email confirmed and is approved, but User entity has been disabled",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now() user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail user.Email = user.UnconfirmedEmail
user.Approved = testrig.TrueBool() user.Approved = testrig.TrueBool()
user.Disabled = testrig.TrueBool() user.Disabled = testrig.TrueBool()
return []string{"confirmed_at", "email", "approved", "disabled"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath, expectedLocationHeader: auth.AccountDisabledPath,
}, },
{ {
description: "user has their email confirmed and is approved, but Account entity has been suspended", description: "user has their email confirmed and is approved, but Account entity has been suspended",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now() user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail user.Email = user.UnconfirmedEmail
user.Approved = testrig.TrueBool() user.Approved = testrig.TrueBool()
user.Disabled = testrig.FalseBool() user.Disabled = testrig.FalseBool()
account.SuspendedAt = time.Now() account.SuspendedAt = time.Now()
return []string{"confirmed_at", "email", "approved", "disabled"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath, expectedLocationHeader: auth.AccountDisabledPath,
@ -81,6 +77,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
*user = *suite.testUsers["unconfirmed_account"] *user = *suite.testUsers["unconfirmed_account"]
*account = *suite.testAccounts["unconfirmed_account"] *account = *suite.testAccounts["unconfirmed_account"]
user.SignInCount++ // cannot be 0 or fails NULL constraint
testSession := sessions.Default(ctx) testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID) testSession.Set(sessionUserID, user.ID)
@ -89,14 +86,13 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
panic(fmt.Errorf("failed on case %s: %w", testCase.description, err)) panic(fmt.Errorf("failed on case %s: %w", testCase.description, err))
} }
updatingColumns := testCase.mutateUserAccount(user, account) testCase.mutateUserAccount(user, account)
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
updatingColumns = append(updatingColumns, "updated_at") err := suite.db.UpdateUser(context.Background(), user)
_, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err) suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account) err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err) suite.NoError(err)
// call the handler // call the handler

View file

@ -90,6 +90,15 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) {
return return
} }
// DO NOT COMMIT THIS UNCOMMENTED, IT WILL CAUSE MASS CHAOS.
// this is being left in as an ode to kim's shitposting.
//
// user := authed.Account.DisplayName
// if user == "" {
// user = authed.Account.Username
// }
// form.Status += "\n\nsent from " + user + "'s iphone\n"
if err := validateCreateStatus(form); err != nil { if err := validateCreateStatus(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return

View file

@ -106,8 +106,9 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusMarkdown() {
// set default post language of account 1 to markdown // set default post language of account 1 to markdown
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]
testAccount.StatusFormat = "markdown" testAccount.StatusFormat = "markdown"
a := testAccount
a, err := suite.db.UpdateAccount(context.Background(), testAccount) err := suite.db.UpdateAccount(context.Background(), a)
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
@ -149,9 +150,8 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusMarkdown() {
func (suite *StatusCreateTestSuite) TestMentionUnknownAccount() { func (suite *StatusCreateTestSuite) TestMentionUnknownAccount() {
// first remove remote account 1 from the database so it gets looked up again // first remove remote account 1 from the database so it gets looked up again
remoteAccount := suite.testAccounts["remote_account_1"] remoteAccount := suite.testAccounts["remote_account_1"]
if err := suite.db.DeleteByID(context.Background(), remoteAccount.ID, &gtsmodel.Account{}); err != nil { err := suite.db.DeleteAccount(context.Background(), remoteAccount.ID)
panic(err) suite.NoError(err)
}
t := suite.testTokens["local_account_1"] t := suite.testTokens["local_account_1"]
oauthToken := oauth.DBTokenToToken(t) oauthToken := oauth.DBTokenToToken(t)

View file

@ -1,171 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// AccountCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Account
type AccountCache struct {
cache cache.LookupCache[string, string, *gtsmodel.Account]
}
// NewAccountCache returns a new instantiated AccountCache object
func NewAccountCache() *AccountCache {
c := &AccountCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Account]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("uri")
lm.RegisterLookup("url")
lm.RegisterLookup("pubkeyid")
lm.RegisterLookup("usernamedomain")
},
AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
if uri := acc.URI; uri != "" {
lm.Set("uri", uri, acc.ID)
}
if url := acc.URL; url != "" {
lm.Set("url", url, acc.ID)
}
lm.Set("pubkeyid", acc.PublicKeyURI, acc.ID)
lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)
},
DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
if uri := acc.URI; uri != "" {
lm.Delete("uri", uri)
}
if url := acc.URL; url != "" {
lm.Delete("url", url)
}
lm.Delete("pubkeyid", acc.PublicKeyURI)
lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety
func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) {
return c.cache.Get(id)
}
// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety
func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("url", url)
}
// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety
func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("uri", uri)
}
// GettByUsernameDomain attempts to fetch an account from the cache by its username@domain combo (or just username), you will receive a copy for thread-safety.
func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain))
}
// GetByPubkeyID attempts to fetch an account from the cache by its public key URI (ID), you will receive a copy for thread-safety.
func (c *AccountCache) GetByPubkeyID(id string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("pubkeyid", id)
}
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
func (c *AccountCache) Put(account *gtsmodel.Account) {
if account == nil || account.ID == "" {
panic("invalid account")
}
c.cache.Set(account.ID, copyAccount(account))
}
// Invalidate removes (invalidates) one account from the cache by its ID.
func (c *AccountCache) Invalidate(id string) {
c.cache.Invalidate(id)
}
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process
func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
return &gtsmodel.Account{
ID: account.ID,
Username: account.Username,
Domain: account.Domain,
AvatarMediaAttachmentID: account.AvatarMediaAttachmentID,
AvatarMediaAttachment: nil,
AvatarRemoteURL: account.AvatarRemoteURL,
HeaderMediaAttachmentID: account.HeaderMediaAttachmentID,
HeaderMediaAttachment: nil,
HeaderRemoteURL: account.HeaderRemoteURL,
DisplayName: account.DisplayName,
EmojiIDs: account.EmojiIDs,
Emojis: nil,
Fields: account.Fields,
Note: account.Note,
NoteRaw: account.NoteRaw,
Memorial: copyBoolPtr(account.Memorial),
MovedToAccountID: account.MovedToAccountID,
Bot: copyBoolPtr(account.Bot),
CreatedAt: account.CreatedAt,
UpdatedAt: account.UpdatedAt,
Reason: account.Reason,
Locked: copyBoolPtr(account.Locked),
Discoverable: copyBoolPtr(account.Discoverable),
Privacy: account.Privacy,
Sensitive: copyBoolPtr(account.Sensitive),
Language: account.Language,
StatusFormat: account.StatusFormat,
CustomCSS: account.CustomCSS,
URI: account.URI,
URL: account.URL,
LastWebfingeredAt: account.LastWebfingeredAt,
InboxURI: account.InboxURI,
SharedInboxURI: account.SharedInboxURI,
OutboxURI: account.OutboxURI,
FollowingURI: account.FollowingURI,
FollowersURI: account.FollowersURI,
FeaturedCollectionURI: account.FeaturedCollectionURI,
ActorType: account.ActorType,
AlsoKnownAs: account.AlsoKnownAs,
PrivateKey: account.PrivateKey,
PublicKey: account.PublicKey,
PublicKeyURI: account.PublicKeyURI,
SensitizedAt: account.SensitizedAt,
SilencedAt: account.SilencedAt,
SuspendedAt: account.SuspendedAt,
HideCollections: copyBoolPtr(account.HideCollections),
SuspensionOrigin: account.SuspensionOrigin,
EnableRSS: copyBoolPtr(account.EnableRSS),
}
}
func usernameDomainKey(username string, domain string) string {
u := "@" + username
if domain != "" {
return u + "@" + domain
}
return u
}

View file

@ -1,96 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountCacheTestSuite struct {
suite.Suite
data map[string]*gtsmodel.Account
cache *cache.AccountCache
}
func (suite *AccountCacheTestSuite) SetupSuite() {
suite.data = testrig.NewTestAccounts()
}
func (suite *AccountCacheTestSuite) SetupTest() {
suite.cache = cache.NewAccountCache()
}
func (suite *AccountCacheTestSuite) TearDownTest() {
suite.data = nil
suite.cache = nil
}
func (suite *AccountCacheTestSuite) TestAccountCache() {
for _, account := range suite.data {
// Place in the cache
suite.cache.Put(account)
}
for _, account := range suite.data {
var ok bool
var check *gtsmodel.Account
// Check we can retrieve
check, ok = suite.cache.GetByID(account.ID)
if !ok && !accountIs(account, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with ID: %s", account.ID))
}
check, ok = suite.cache.GetByURI(account.URI)
if account.URI != "" && !ok && !accountIs(account, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with URI: %s", account.URI))
}
check, ok = suite.cache.GetByURL(account.URL)
if account.URL != "" && !ok && !accountIs(account, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with URL: %s", account.URL))
}
check, ok = suite.cache.GetByPubkeyID(account.PublicKeyURI)
if account.PublicKeyURI != "" && !ok && !accountIs(account, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with public key URI: %s", account.PublicKeyURI))
}
check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain)
if !ok && !accountIs(account, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain))
}
}
}
func TestAccountCache(t *testing.T) {
suite.Run(t, &AccountCacheTestSuite{})
}
func accountIs(account1, account2 *gtsmodel.Account) bool {
if account1 == nil || account2 == nil {
return account1 == account2
}
return account1.ID == account2.ID &&
account1.URI == account2.URI &&
account1.URL == account2.URL &&
account1.PublicKeyURI == account2.PublicKeyURI
}

View file

@ -1,106 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// DomainCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Status
type DomainBlockCache struct {
cache cache.LookupCache[string, string, *gtsmodel.DomainBlock]
}
// NewStatusCache returns a new instantiated statusCache object
func NewDomainBlockCache() *DomainBlockCache {
c := &DomainBlockCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.DomainBlock]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("id")
},
AddLookups: func(lm *cache.LookupMap[string, string], block *gtsmodel.DomainBlock) {
// Block can be equal to nil when sentinel
if block != nil && block.ID != "" {
lm.Set("id", block.ID, block.Domain)
}
},
DeleteLookups: func(lm *cache.LookupMap[string, string], block *gtsmodel.DomainBlock) {
// Block can be equal to nil when sentinel
if block != nil && block.ID != "" {
lm.Delete("id", block.ID)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety
func (c *DomainBlockCache) GetByID(id string) (*gtsmodel.DomainBlock, bool) {
return c.cache.GetBy("id", id)
}
// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety
func (c *DomainBlockCache) GetByDomain(domain string) (*gtsmodel.DomainBlock, bool) {
return c.cache.Get(domain)
}
// Put places a status in the cache, ensuring that the object place is a copy for thread-safety
func (c *DomainBlockCache) Put(domain string, block *gtsmodel.DomainBlock) {
if domain == "" {
panic("invalid domain")
}
if block == nil {
// This is a sentinel value for (no block)
c.cache.Set(domain, nil)
} else {
// This is a valid domain block
c.cache.Set(domain, copyDomainBlock(block))
}
}
// InvalidateByDomain will invalidate a domain block from the cache by domain name.
func (c *DomainBlockCache) InvalidateByDomain(domain string) {
c.cache.Invalidate(domain)
}
// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process
func copyDomainBlock(block *gtsmodel.DomainBlock) *gtsmodel.DomainBlock {
return &gtsmodel.DomainBlock{
ID: block.ID,
CreatedAt: block.CreatedAt,
UpdatedAt: block.UpdatedAt,
Domain: block.Domain,
CreatedByAccountID: block.CreatedByAccountID,
CreatedByAccount: nil,
PrivateComment: block.PrivateComment,
PublicComment: block.PublicComment,
Obfuscate: block.Obfuscate,
SubscriptionID: block.SubscriptionID,
}
}

View file

@ -1,131 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// EmojiCache is a cache wrapper to provide ID and URI lookups for gtsmodel.Emoji
type EmojiCache struct {
cache cache.LookupCache[string, string, *gtsmodel.Emoji]
}
// NewEmojiCache returns a new instantiated EmojiCache object
func NewEmojiCache() *EmojiCache {
c := &EmojiCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Emoji]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("uri")
lm.RegisterLookup("shortcodedomain")
lm.RegisterLookup("imagestaticurl")
},
AddLookups: func(lm *cache.LookupMap[string, string], emoji *gtsmodel.Emoji) {
lm.Set("shortcodedomain", shortcodeDomainKey(emoji.Shortcode, emoji.Domain), emoji.ID)
if uri := emoji.URI; uri != "" {
lm.Set("uri", uri, emoji.ID)
}
if imageStaticURL := emoji.ImageStaticURL; imageStaticURL != "" {
lm.Set("imagestaticurl", imageStaticURL, emoji.ID)
}
},
DeleteLookups: func(lm *cache.LookupMap[string, string], emoji *gtsmodel.Emoji) {
lm.Delete("shortcodedomain", shortcodeDomainKey(emoji.Shortcode, emoji.Domain))
if uri := emoji.URI; uri != "" {
lm.Delete("uri", uri)
}
if imageStaticURL := emoji.ImageStaticURL; imageStaticURL != "" {
lm.Delete("imagestaticurl", imageStaticURL)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch an emoji from the cache by its ID, you will receive a copy for thread-safety
func (c *EmojiCache) GetByID(id string) (*gtsmodel.Emoji, bool) {
return c.cache.Get(id)
}
// GetByURI attempts to fetch an emoji from the cache by its URI, you will receive a copy for thread-safety
func (c *EmojiCache) GetByURI(uri string) (*gtsmodel.Emoji, bool) {
return c.cache.GetBy("uri", uri)
}
func (c *EmojiCache) GetByShortcodeDomain(shortcode string, domain string) (*gtsmodel.Emoji, bool) {
return c.cache.GetBy("shortcodedomain", shortcodeDomainKey(shortcode, domain))
}
func (c *EmojiCache) GetByImageStaticURL(imageStaticURL string) (*gtsmodel.Emoji, bool) {
return c.cache.GetBy("imagestaticurl", imageStaticURL)
}
// Put places an emoji in the cache, ensuring that the object place is a copy for thread-safety
func (c *EmojiCache) Put(emoji *gtsmodel.Emoji) {
if emoji == nil || emoji.ID == "" {
panic("invalid emoji")
}
c.cache.Set(emoji.ID, copyEmoji(emoji))
}
func (c *EmojiCache) Invalidate(emojiID string) {
c.cache.Invalidate(emojiID)
}
// copyEmoji performs a surface-level copy of emoji, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process
func copyEmoji(emoji *gtsmodel.Emoji) *gtsmodel.Emoji {
return &gtsmodel.Emoji{
ID: emoji.ID,
CreatedAt: emoji.CreatedAt,
UpdatedAt: emoji.UpdatedAt,
Shortcode: emoji.Shortcode,
Domain: emoji.Domain,
ImageRemoteURL: emoji.ImageRemoteURL,
ImageStaticRemoteURL: emoji.ImageStaticRemoteURL,
ImageURL: emoji.ImageURL,
ImageStaticURL: emoji.ImageStaticURL,
ImagePath: emoji.ImagePath,
ImageStaticPath: emoji.ImageStaticPath,
ImageContentType: emoji.ImageContentType,
ImageStaticContentType: emoji.ImageStaticContentType,
ImageFileSize: emoji.ImageFileSize,
ImageStaticFileSize: emoji.ImageStaticFileSize,
ImageUpdatedAt: emoji.ImageUpdatedAt,
Disabled: copyBoolPtr(emoji.Disabled),
URI: emoji.URI,
VisibleInPicker: copyBoolPtr(emoji.VisibleInPicker),
CategoryID: emoji.CategoryID,
}
}
func shortcodeDomainKey(shortcode string, domain string) string {
if domain != "" {
return shortcode + "@" + domain
}
return shortcode
}

View file

@ -1,84 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"strings"
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// EmojiCategoryCache is a cache wrapper to provide ID lookups for gtsmodel.EmojiCategory
type EmojiCategoryCache struct {
cache cache.LookupCache[string, string, *gtsmodel.EmojiCategory]
}
// NewEmojiCategoryCache returns a new instantiated EmojiCategoryCache object
func NewEmojiCategoryCache() *EmojiCategoryCache {
c := &EmojiCategoryCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.EmojiCategory]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("name")
},
AddLookups: func(lm *cache.LookupMap[string, string], emojiCategory *gtsmodel.EmojiCategory) {
lm.Set(("name"), strings.ToLower(emojiCategory.Name), emojiCategory.ID)
},
DeleteLookups: func(lm *cache.LookupMap[string, string], emojiCategory *gtsmodel.EmojiCategory) {
lm.Delete("name", strings.ToLower(emojiCategory.Name))
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch an emojiCategory from the cache by its ID, you will receive a copy for thread-safety
func (c *EmojiCategoryCache) GetByID(id string) (*gtsmodel.EmojiCategory, bool) {
return c.cache.Get(id)
}
// GetByName attempts to fetch an emojiCategory from the cache by its name, you will receive a copy for thread-safety
func (c *EmojiCategoryCache) GetByName(name string) (*gtsmodel.EmojiCategory, bool) {
return c.cache.GetBy("name", strings.ToLower(name))
}
// Put places an emojiCategory in the cache, ensuring that the object place is a copy for thread-safety
func (c *EmojiCategoryCache) Put(emoji *gtsmodel.EmojiCategory) {
if emoji == nil || emoji.ID == "" {
panic("invalid emoji")
}
c.cache.Set(emoji.ID, copyEmojiCategory(emoji))
}
func (c *EmojiCategoryCache) Invalidate(emojiID string) {
c.cache.Invalidate(emojiID)
}
func copyEmojiCategory(emojiCategory *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory {
return &gtsmodel.EmojiCategory{
ID: emojiCategory.ID,
CreatedAt: emojiCategory.CreatedAt,
UpdatedAt: emojiCategory.UpdatedAt,
Name: emojiCategory.Name,
}
}

View file

@ -1,138 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// StatusCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Status
type StatusCache struct {
cache cache.LookupCache[string, string, *gtsmodel.Status]
}
// NewStatusCache returns a new instantiated statusCache object
func NewStatusCache() *StatusCache {
c := &StatusCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Status]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("uri")
lm.RegisterLookup("url")
},
AddLookups: func(lm *cache.LookupMap[string, string], status *gtsmodel.Status) {
if uri := status.URI; uri != "" {
lm.Set("uri", uri, status.ID)
}
if url := status.URL; url != "" {
lm.Set("url", url, status.ID)
}
},
DeleteLookups: func(lm *cache.LookupMap[string, string], status *gtsmodel.Status) {
if uri := status.URI; uri != "" {
lm.Delete("uri", uri)
}
if url := status.URL; url != "" {
lm.Delete("url", url)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety
func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
return c.cache.Get(id)
}
// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety
func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
return c.cache.GetBy("url", url)
}
// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety
func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
return c.cache.GetBy("uri", uri)
}
// Put places a status in the cache, ensuring that the object place is a copy for thread-safety
func (c *StatusCache) Put(status *gtsmodel.Status) {
if status == nil || status.ID == "" {
panic("invalid status")
}
c.cache.Set(status.ID, copyStatus(status))
}
// Invalidate invalidates one status from the cache using the ID of the status as key.
func (c *StatusCache) Invalidate(statusID string) {
c.cache.Invalidate(statusID)
}
// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process
func copyStatus(status *gtsmodel.Status) *gtsmodel.Status {
return &gtsmodel.Status{
ID: status.ID,
URI: status.URI,
URL: status.URL,
Content: status.Content,
AttachmentIDs: status.AttachmentIDs,
Attachments: nil,
TagIDs: status.TagIDs,
Tags: nil,
MentionIDs: status.MentionIDs,
Mentions: nil,
EmojiIDs: status.EmojiIDs,
Emojis: nil,
Local: copyBoolPtr(status.Local),
CreatedAt: status.CreatedAt,
UpdatedAt: status.UpdatedAt,
AccountID: status.AccountID,
Account: nil,
AccountURI: status.AccountURI,
InReplyToID: status.InReplyToID,
InReplyTo: nil,
InReplyToURI: status.InReplyToURI,
InReplyToAccountID: status.InReplyToAccountID,
InReplyToAccount: nil,
BoostOfID: status.BoostOfID,
BoostOf: nil,
BoostOfAccountID: status.BoostOfAccountID,
BoostOfAccount: nil,
ContentWarning: status.ContentWarning,
Visibility: status.Visibility,
Sensitive: copyBoolPtr(status.Sensitive),
Language: status.Language,
CreatedWithApplicationID: status.CreatedWithApplicationID,
ActivityStreamsType: status.ActivityStreamsType,
Text: status.Text,
Pinned: copyBoolPtr(status.Pinned),
Federated: copyBoolPtr(status.Federated),
Boostable: copyBoolPtr(status.Boostable),
Replyable: copyBoolPtr(status.Replyable),
Likeable: copyBoolPtr(status.Likeable),
}
}

View file

@ -1,113 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusCacheTestSuite struct {
suite.Suite
data map[string]*gtsmodel.Status
cache *cache.StatusCache
}
func (suite *StatusCacheTestSuite) SetupSuite() {
suite.data = testrig.NewTestStatuses()
}
func (suite *StatusCacheTestSuite) SetupTest() {
suite.cache = cache.NewStatusCache()
}
func (suite *StatusCacheTestSuite) TearDownTest() {
suite.data = nil
suite.cache = nil
}
func (suite *StatusCacheTestSuite) TestStatusCache() {
for _, status := range suite.data {
// Place in the cache
suite.cache.Put(status)
}
for _, status := range suite.data {
var ok bool
var check *gtsmodel.Status
// Check we can retrieve
check, ok = suite.cache.GetByID(status.ID)
if !ok && !statusIs(status, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with ID: %s", status.ID))
}
check, ok = suite.cache.GetByURI(status.URI)
if status.URI != "" && !ok && !statusIs(status, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with URI: %s", status.URI))
}
check, ok = suite.cache.GetByURL(status.URL)
if status.URL != "" && !ok && !statusIs(status, check) {
suite.Fail(fmt.Sprintf("Failed to fetch expected account with URL: %s", status.URL))
}
}
}
func (suite *StatusCacheTestSuite) TestBoolPointerCopying() {
originalStatus := suite.data["local_account_1_status_1"]
// mark the status as pinned + cache it
pinned := true
originalStatus.Pinned = &pinned
suite.cache.Put(originalStatus)
// retrieve it
cachedStatus, ok := suite.cache.GetByID(originalStatus.ID)
if !ok {
suite.FailNow("status wasn't retrievable from cache")
}
// we should be able to change the original status values + cached
// values independently since they use different pointers
suite.True(*cachedStatus.Pinned)
*originalStatus.Pinned = false
suite.False(*originalStatus.Pinned)
suite.True(*cachedStatus.Pinned)
*originalStatus.Pinned = true
*cachedStatus.Pinned = false
suite.True(*originalStatus.Pinned)
suite.False(*cachedStatus.Pinned)
}
func TestStatusCache(t *testing.T) {
suite.Run(t, &StatusCacheTestSuite{})
}
func statusIs(status1, status2 *gtsmodel.Status) bool {
if status1 == nil || status2 == nil {
return status1 == status2
}
return status1.ID == status2.ID &&
status1.URI == status2.URI &&
status1.URL == status2.URL
}

141
internal/cache/user.go vendored
View file

@ -1,141 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// UserCache is a cache wrapper to provide lookups for gtsmodel.User
type UserCache struct {
cache cache.LookupCache[string, string, *gtsmodel.User]
}
// NewUserCache returns a new instantiated UserCache object
func NewUserCache() *UserCache {
c := &UserCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("accountid")
lm.RegisterLookup("email")
lm.RegisterLookup("unconfirmedemail")
lm.RegisterLookup("confirmationtoken")
},
AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Set("accountid", user.AccountID, user.ID)
if email := user.Email; email != "" {
lm.Set("email", email, user.ID)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Set("confirmationtoken", confirmationToken, user.ID)
}
},
DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Delete("accountid", user.AccountID)
if email := user.Email; email != "" {
lm.Delete("email", email)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Delete("unconfirmedemail", unconfirmedEmail)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Delete("confirmationtoken", confirmationToken)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
return c.cache.Get(id)
}
// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
return c.cache.GetBy("accountid", accountID)
}
// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
return c.cache.GetBy("email", email)
}
// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
return c.cache.GetBy("confirmationtoken", token)
}
// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
func (c *UserCache) Put(user *gtsmodel.User) {
if user == nil || user.ID == "" {
panic("invalid user")
}
c.cache.Set(user.ID, copyUser(user))
}
// Invalidate invalidates one user from the cache using the ID of the user as key.
func (c *UserCache) Invalidate(userID string) {
c.cache.Invalidate(userID)
}
func copyUser(user *gtsmodel.User) *gtsmodel.User {
return &gtsmodel.User{
ID: user.ID,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
AccountID: user.AccountID,
Account: nil,
EncryptedPassword: user.EncryptedPassword,
SignUpIP: user.SignUpIP,
CurrentSignInAt: user.CurrentSignInAt,
CurrentSignInIP: user.CurrentSignInIP,
LastSignInAt: user.LastSignInAt,
LastSignInIP: user.LastSignInIP,
SignInCount: user.SignInCount,
InviteID: user.InviteID,
ChosenLanguages: user.ChosenLanguages,
FilteredLanguages: user.FilteredLanguages,
Locale: user.Locale,
CreatedByApplicationID: user.CreatedByApplicationID,
CreatedByApplication: nil,
LastEmailedAt: user.LastEmailedAt,
ConfirmationToken: user.ConfirmationToken,
ConfirmationSentAt: user.ConfirmationSentAt,
ConfirmedAt: user.ConfirmedAt,
UnconfirmedEmail: user.UnconfirmedEmail,
Moderator: copyBoolPtr(user.Moderator),
Admin: copyBoolPtr(user.Admin),
Disabled: copyBoolPtr(user.Disabled),
Approved: copyBoolPtr(user.Approved),
ResetPasswordToken: user.ResetPasswordToken,
ResetPasswordSentAt: user.ResetPasswordSentAt,
}
}

View file

@ -1,31 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
// copyBoolPtr returns a bool pointer with the same value as the pointer passed into it.
//
// Useful when copying things from the cache to a caller.
func copyBoolPtr(in *bool) *bool {
if in == nil {
return nil
}
b := new(bool)
*b = *in
return b
}

View file

@ -43,10 +43,10 @@ type Account interface {
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// PutAccount puts one account in the database. // PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) PutAccount(ctx context.Context, account *gtsmodel.Account) Error
// UpdateAccount updates one account by ID. // UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) UpdateAccount(ctx context.Context, account *gtsmodel.Account) Error
// DeleteAccount deletes one account from the database by its ID. // DeleteAccount deletes one account from the database by its ID.
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the

View file

@ -24,7 +24,7 @@
"strings" "strings"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -35,10 +35,29 @@
type accountDB struct { type accountDB struct {
conn *DBConn conn *DBConn
cache *cache.AccountCache cache *result.Cache[*gtsmodel.Account]
status *statusDB status *statusDB
} }
func (a *accountDB) init() {
// Initialize account result cache
a.cache = result.NewSized([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "URL"},
{Name: "Username.Domain"},
{Name: "PublicKeyURI"},
}, func(a1 *gtsmodel.Account) *gtsmodel.Account {
a2 := new(gtsmodel.Account)
*a2 = *a1
return a2
}, 1000)
// Set cache TTL and start sweep routine
a.cache.SetTTL(time.Minute*5, false)
a.cache.Start(time.Second * 10)
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn. return a.conn.
NewSelect(). NewSelect().
@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
func() (*gtsmodel.Account, bool) { "ID",
return a.cache.GetByID(id)
},
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
}, },
id,
) )
} }
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
func() (*gtsmodel.Account, bool) { "URI",
return a.cache.GetByURI(uri)
},
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
}, },
uri,
) )
} }
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
func() (*gtsmodel.Account, bool) { "URL",
return a.cache.GetByURL(url)
},
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
}, },
url,
) )
} }
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
username = strings.ToLower(username)
return a.getAccount( return a.getAccount(
ctx, ctx,
func() (*gtsmodel.Account, bool) { "Username.Domain",
return a.cache.GetByUsernameDomain(username, domain)
},
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
q := a.newAccountQ(account) q := a.newAccountQ(account)
@ -97,49 +112,61 @@ func(account *gtsmodel.Account) error {
q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? = ?", bun.Ident("account.domain"), domain) q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} else { } else {
q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? IS NULL", bun.Ident("account.domain")) q = q.Where("? IS NULL", bun.Ident("account.domain"))
} }
return q.Scan(ctx) return q.Scan(ctx)
}, },
username,
domain,
) )
} }
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
func() (*gtsmodel.Account, bool) { "PublicKeyURI",
return a.cache.GetByPubkeyID(id)
},
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
}, },
id,
) )
} }
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
// Attempt to fetch cached account var username string
account, cached := cacheGet()
if !cached { if domain == "" {
account = &gtsmodel.Account{} // I.e. our local instance account
username = config.GetHost()
} else {
// A remote instance account
username = domain
}
return a.GetAccountByUsernameDomain(ctx, username, domain)
}
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) {
return a.cache.Load(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account
// Not cached! Perform database query // Not cached! Perform database query
err := dbQuery(account) if err := dbQuery(&account); err != nil {
if err != nil {
return nil, a.conn.ProcessError(err) return nil, a.conn.ProcessError(err)
} }
// Place in the cache return &account, nil
a.cache.Put(account) }, keyParts...)
} }
return account, nil func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
} return a.cache.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { // as the cache does not attempt a mutex lock until AFTER hook.
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { //
return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses // create links between this account and any emojis it uses
for _, i := range account.EmojiIDs { for _, i := range account.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
@ -153,19 +180,19 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (
// insert the account // insert the account
_, err := tx.NewInsert().Model(account).Exec(ctx) _, err := tx.NewInsert().Model(account).Exec(ctx)
return err return err
}); err != nil { })
return nil, a.conn.ProcessError(err) })
} }
a.cache.Put(account) func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
// Update the account's last-updated // Update the account's last-updated
account.UpdatedAt = time.Now() account.UpdatedAt = time.Now()
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { return a.cache.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses // create links between this account and any emojis it uses
// first clear out any old emoji links // first clear out any old emoji links
if _, err := tx. if _, err := tx.
@ -189,21 +216,13 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
} }
// update the account // update the account
if _, err := tx. _, err := tx.NewUpdate().
NewUpdate().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.id"), account.ID). Where("? = ?", bun.Ident("account.id"), account.ID).
Exec(ctx); err != nil { Exec(ctx)
return err return err
} })
})
return nil
}); err != nil {
return nil, a.conn.ProcessError(err)
}
a.cache.Put(account)
return account, nil
} }
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
// delete the account // delete the account
_, err := tx. _, err := tx.
NewUpdate(). NewDelete().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("? = ?", bun.Ident("account.id"), id). Where("? = ?", bun.Ident("account.id"), id).
Exec(ctx) Exec(ctx)
return err return err
}); err != nil { }); err != nil {
return a.conn.ProcessError(err) return err
} }
a.cache.Invalidate(id) a.cache.Invalidate("ID", id)
return nil return nil
} }
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account)
if domain != "" {
q = q.
Where("? = ?", bun.Ident("account.username"), domain).
Where("? = ?", bun.Ident("account.domain"), domain)
} else {
q = q.
Where("? = ?", bun.Ident("account.username"), config.GetHost()).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
if err := q.Scan(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {
createdAt := time.Time{} createdAt := time.Time{}

View file

@ -92,7 +92,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount.DisplayName = "new display name!" testAccount.DisplayName = "new display name!"
testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}
_, err := suite.db.UpdateAccount(ctx, testAccount) err := suite.db.UpdateAccount(ctx, testAccount)
suite.NoError(err) suite.NoError(err)
updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) updated, err := suite.db.GetAccountByID(ctx, testAccount.ID)
@ -127,7 +127,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
// update again to remove emoji associations // update again to remove emoji associations
testAccount.EmojiIDs = []string{} testAccount.EmojiIDs = []string{}
_, err = suite.db.UpdateAccount(ctx, testAccount) err = suite.db.UpdateAccount(ctx, testAccount)
suite.NoError(err) suite.NoError(err)
updated, err = suite.db.GetAccountByID(ctx, testAccount.ID) updated, err = suite.db.GetAccountByID(ctx, testAccount.ID)

View file

@ -29,7 +29,6 @@
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -45,8 +44,8 @@
type adminDB struct { type adminDB struct {
conn *DBConn conn *DBConn
userCache *cache.UserCache accounts *accountDB
accountCache *cache.AccountCache users *userDB
} }
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@ -140,13 +139,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
} }
// insert the new account! // insert the new account!
if _, err = a.conn. if err := a.accounts.PutAccount(ctx, acct); err != nil {
NewInsert(). return nil, err
Model(acct).
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
} }
a.accountCache.Put(acct)
} }
// we either created or already had an account by now, // we either created or already had an account by now,
@ -190,13 +185,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
} }
// insert the user! // insert the user!
if _, err = a.conn. if err := a.users.PutUser(ctx, u); err != nil {
NewInsert(). return nil, err
Model(u).
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
} }
a.userCache.Put(u)
return u, nil return u, nil
} }
@ -249,15 +240,11 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
FeaturedCollectionURI: newAccountURIs.CollectionURI, FeaturedCollectionURI: newAccountURIs.CollectionURI,
} }
insertQ := a.conn. // insert the new account!
NewInsert(). if err := a.accounts.PutAccount(ctx, acct); err != nil {
Model(acct) return err
if _, err := insertQ.Exec(ctx); err != nil {
return a.conn.ProcessError(err)
} }
a.accountCache.Put(acct)
log.Infof("instance account %s CREATED with id %s", username, acct.ID) log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil return nil
} }

View file

@ -70,6 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
} }
func (suite *AdminTestSuite) TestCreateInstanceAccount() { func (suite *AdminTestSuite) TestCreateInstanceAccount() {
// reinitialize test DB to clear caches
suite.db = testrig.NewTestDB()
// we need to take an empty db for this... // we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
// ...with tables created but no data // ...with tables created but no data

View file

@ -34,7 +34,6 @@
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib" "github.com/jackc/pgx/v4/stdlib"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations" "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations"
@ -46,7 +45,6 @@
"github.com/uptrace/bun/dialect/sqlitedialect" "github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate"
grufcache "codeberg.org/gruf/go-cache/v2"
"modernc.org/sqlite" "modernc.org/sqlite"
) )
@ -160,52 +158,45 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
return nil, fmt.Errorf("db migration error: %s", err) return nil, fmt.Errorf("db migration error: %s", err)
} }
// Prepare caches required by more than one struct
userCache := cache.NewUserCache()
accountCache := cache.NewAccountCache()
// Prepare other caches
// Prepare mentions cache
// TODO: move into internal/cache
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
mentionCache.SetTTL(time.Minute*5, false)
mentionCache.Start(time.Second * 10)
// Prepare notifications cache
// TODO: move into internal/cache
notifCache := grufcache.New[string, *gtsmodel.Notification]()
notifCache.SetTTL(time.Minute*5, false)
notifCache.Start(time.Second * 10)
// Create DB structs that require ptrs to each other // Create DB structs that require ptrs to each other
accounts := &accountDB{conn: conn, cache: accountCache} account := &accountDB{conn: conn}
status := &statusDB{conn: conn, cache: cache.NewStatusCache()} admin := &adminDB{conn: conn}
emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()} domain := &domainDB{conn: conn}
mention := &mentionDB{conn: conn}
notif := &notificationDB{conn: conn}
status := &statusDB{conn: conn}
emoji := &emojiDB{conn: conn}
timeline := &timelineDB{conn: conn} timeline := &timelineDB{conn: conn}
tombstone := &tombstoneDB{conn: conn} tombstone := &tombstoneDB{conn: conn}
user := &userDB{conn: conn}
// Setup DB cross-referencing // Setup DB cross-referencing
accounts.status = status account.status = status
status.accounts = accounts admin.users = user
status.accounts = account
timeline.status = status timeline.status = status
// Initialize db structs // Initialize db structs
account.init()
domain.init()
emoji.init()
mention.init()
notif.init()
status.init()
tombstone.init() tombstone.init()
user.init()
ps := &DBService{ ps := &DBService{
Account: accounts, Account: account,
Admin: &adminDB{ Admin: &adminDB{
conn: conn, conn: conn,
userCache: userCache, accounts: account,
accountCache: accountCache, users: user,
}, },
Basic: &basicDB{ Basic: &basicDB{
conn: conn, conn: conn,
}, },
Domain: &domainDB{ Domain: domain,
conn: conn,
cache: cache.NewDomainBlockCache(),
},
Emoji: emoji, Emoji: emoji,
Instance: &instanceDB{ Instance: &instanceDB{
conn: conn, conn: conn,
@ -213,14 +204,8 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
Media: &mediaDB{ Media: &mediaDB{
conn: conn, conn: conn,
}, },
Mention: &mentionDB{ Mention: mention,
conn: conn, Notification: notif,
cache: mentionCache,
},
Notification: &notificationDB{
conn: conn,
cache: notifCache,
},
Relationship: &relationshipDB{ Relationship: &relationshipDB{
conn: conn, conn: conn,
}, },
@ -229,10 +214,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
}, },
Status: status, Status: status,
Timeline: timeline, Timeline: timeline,
User: &userDB{ User: user,
conn: conn,
cache: userCache,
},
Tombstone: tombstone, Tombstone: tombstone,
conn: conn, conn: conn,
} }

View file

@ -20,11 +20,11 @@
import ( import (
"context" "context"
"database/sql"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -34,7 +34,22 @@
type domainDB struct { type domainDB struct {
conn *DBConn conn *DBConn
cache *cache.DomainBlockCache cache *result.Cache[*gtsmodel.DomainBlock]
}
func (d *domainDB) init() {
// Initialize domain block result cache
d.cache = result.NewSized([]result.Lookup{
{Name: "Domain"},
}, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock {
d2 := new(gtsmodel.DomainBlock)
*d2 = *d1
return d2
}, 1000)
// Set cache TTL and start sweep routine
d.cache.SetTTL(time.Minute*5, false)
d.cache.Start(time.Second * 10)
} }
// normalizeDomain converts the given domain to lowercase // normalizeDomain converts the given domain to lowercase
@ -49,76 +64,53 @@ func normalizeDomain(domain string) (out string, err error) {
} }
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {
domain, err := normalizeDomain(block.Domain) var err error
block.Domain, err = normalizeDomain(block.Domain)
if err != nil { if err != nil {
return err return err
} }
block.Domain = domain
// Attempt to insert new domain block return d.cache.Store(block, func() error {
if _, err := d.conn.NewInsert(). _, err := d.conn.NewInsert().
Model(block). Model(block).
Exec(ctx); err != nil { Exec(ctx)
return d.conn.ProcessError(err) return d.conn.ProcessError(err)
} })
// Cache this domain block
d.cache.Put(block.Domain, block)
return nil
} }
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {
var err error var err error
domain, err = normalizeDomain(domain) domain, err = normalizeDomain(domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.cache.Load("Domain", func() (*gtsmodel.DomainBlock, error) {
// Check for easy case, domain referencing *us* // Check for easy case, domain referencing *us*
if domain == "" || domain == config.GetAccountDomain() { if domain == "" || domain == config.GetAccountDomain() {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
// Check for already cached rblock var block gtsmodel.DomainBlock
if block, ok := d.cache.GetByDomain(domain); ok {
// A 'nil' return value is a sentinel value for no block
if block == nil {
return nil, db.ErrNoEntries
}
// Else, this block exists
return block, nil
}
block := &gtsmodel.DomainBlock{}
q := d.conn. q := d.conn.
NewSelect(). NewSelect().
Model(block). Model(&block).
Where("? = ?", bun.Ident("domain_block.domain"), domain). Where("? = ?", bun.Ident("domain_block.domain"), domain).
Limit(1) Limit(1)
if err := q.Scan(ctx); err != nil {
// Query database for domain block
switch err := q.Scan(ctx); err {
// No error, block found
case nil:
d.cache.Put(domain, block)
return block, nil
// No error, simply not found
case sql.ErrNoRows:
d.cache.Put(domain, nil)
return nil, db.ErrNoEntries
// Any other db error
default:
return nil, d.conn.ProcessError(err) return nil, d.conn.ProcessError(err)
} }
return &block, nil
}, domain)
} }
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
var err error var err error
domain, err = normalizeDomain(domain) domain, err = normalizeDomain(domain)
if err != nil { if err != nil {
return err return err
@ -133,7 +125,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
} }
// Clear domain from cache // Clear domain from cache
d.cache.InvalidateByDomain(domain) d.cache.Invalidate(domain)
return nil return nil
} }

View file

@ -23,7 +23,7 @@
"strings" "strings"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
@ -33,8 +33,40 @@
type emojiDB struct { type emojiDB struct {
conn *DBConn conn *DBConn
emojiCache *cache.EmojiCache emojiCache *result.Cache[*gtsmodel.Emoji]
categoryCache *cache.EmojiCategoryCache categoryCache *result.Cache[*gtsmodel.EmojiCategory]
}
func (e *emojiDB) init() {
// Initialize emoji result cache
e.emojiCache = result.NewSized([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "Shortcode.Domain"},
{Name: "ImageStaticURL"},
}, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji {
e2 := new(gtsmodel.Emoji)
*e2 = *e1
return e2
}, 1000)
// Set cache TTL and start sweep routine
e.emojiCache.SetTTL(time.Minute*5, false)
e.emojiCache.Start(time.Second * 10)
// Initialize category result cache
e.categoryCache = result.NewSized([]result.Lookup{
{Name: "ID"},
{Name: "Name"},
}, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory {
c2 := new(gtsmodel.EmojiCategory)
*c2 = *c1
return c2
}, 1000)
// Set cache TTL and start sweep routine
e.categoryCache.SetTTL(time.Minute*5, false)
e.categoryCache.Start(time.Second * 10)
} }
func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery {
@ -51,12 +83,10 @@ func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun.
} }
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error {
if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { return e.emojiCache.Store(emoji, func() error {
_, err := e.conn.NewInsert().Model(emoji).Exec(ctx)
return e.conn.ProcessError(err) return e.conn.ProcessError(err)
} })
e.emojiCache.Put(emoji)
return nil
} }
func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) {
@ -72,7 +102,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
return nil, e.conn.ProcessError(err) return nil, e.conn.ProcessError(err)
} }
e.emojiCache.Invalidate(emoji.ID) e.emojiCache.Invalidate("ID", emoji.ID)
return emoji, nil return emoji, nil
} }
@ -109,7 +139,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
return err return err
} }
e.emojiCache.Invalidate(id) e.emojiCache.Invalidate("ID", id)
return nil return nil
} }
@ -252,33 +282,29 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E
func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
func() (*gtsmodel.Emoji, bool) { "ID",
return e.emojiCache.GetByID(id)
},
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
}, },
id,
) )
} }
func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
func() (*gtsmodel.Emoji, bool) { "URI",
return e.emojiCache.GetByURI(uri)
},
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
}, },
uri,
) )
} }
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
func() (*gtsmodel.Emoji, bool) { "Shortcode.Domain",
return e.emojiCache.GetByShortcodeDomain(shortcode, domain)
},
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
q := e.newEmojiQ(emoji) q := e.newEmojiQ(emoji)
@ -292,31 +318,30 @@ func(emoji *gtsmodel.Emoji) error {
return q.Scan(ctx) return q.Scan(ctx)
}, },
shortcode,
domain,
) )
} }
func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
func() (*gtsmodel.Emoji, bool) { "ImageStaticURL",
return e.emojiCache.GetByImageStaticURL(imageStaticURL)
},
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e. return e.
newEmojiQ(emoji). newEmojiQ(emoji).
Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL). Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).
Scan(ctx) Scan(ctx)
}, },
imageStaticURL,
) )
} }
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error {
if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil { return e.categoryCache.Store(emojiCategory, func() error {
_, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx)
return e.conn.ProcessError(err) return e.conn.ProcessError(err)
} })
e.categoryCache.Put(emojiCategory)
return nil
} }
func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) {
@ -338,45 +363,36 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate
func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {
return e.getEmojiCategory( return e.getEmojiCategory(
ctx, ctx,
func() (*gtsmodel.EmojiCategory, bool) { "ID",
return e.categoryCache.GetByID(id)
},
func(emojiCategory *gtsmodel.EmojiCategory) error { func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx) return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
}, },
id,
) )
} }
func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {
return e.getEmojiCategory( return e.getEmojiCategory(
ctx, ctx,
func() (*gtsmodel.EmojiCategory, bool) { "Name",
return e.categoryCache.GetByName(name)
},
func(emojiCategory *gtsmodel.EmojiCategory) error { func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx) return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
}, },
name,
) )
} }
func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) {
// Attempt to fetch cached emoji return e.emojiCache.Load(lookup, func() (*gtsmodel.Emoji, error) {
emoji, cached := cacheGet() var emoji gtsmodel.Emoji
if !cached {
emoji = &gtsmodel.Emoji{}
// Not cached! Perform database query // Not cached! Perform database query
err := dbQuery(emoji) if err := dbQuery(&emoji); err != nil {
if err != nil {
return nil, e.conn.ProcessError(err) return nil, e.conn.ProcessError(err)
} }
// Place in the cache return &emoji, nil
e.emojiCache.Put(emoji) }, keyParts...)
}
return emoji, nil
} }
func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) {
@ -399,24 +415,17 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm
return emojis, nil return emojis, nil
} }
func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) {
// Attempt to fetch cached emoji categories return e.categoryCache.Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
emojiCategory, cached := cacheGet() var category gtsmodel.EmojiCategory
if !cached {
emojiCategory = &gtsmodel.EmojiCategory{}
// Not cached! Perform database query // Not cached! Perform database query
err := dbQuery(emojiCategory) if err := dbQuery(&category); err != nil {
if err != nil {
return nil, e.conn.ProcessError(err) return nil, e.conn.ProcessError(err)
} }
// Place in the cache return &category, nil
e.categoryCache.Put(emojiCategory) }, keyParts...)
}
return emojiCategory, nil
} }
func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) {

View file

@ -20,8 +20,9 @@
import ( import (
"context" "context"
"time"
"codeberg.org/gruf/go-cache/v2" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
@ -30,7 +31,22 @@
type mentionDB struct { type mentionDB struct {
conn *DBConn conn *DBConn
cache cache.Cache[string, *gtsmodel.Mention] cache *result.Cache[*gtsmodel.Mention]
}
func (m *mentionDB) init() {
// Initialize notification result cache
m.cache = result.NewSized([]result.Lookup{
{Name: "ID"},
}, func(m1 *gtsmodel.Mention) *gtsmodel.Mention {
m2 := new(gtsmodel.Mention)
*m2 = *m1
return m2
}, 1000)
// Set cache TTL and start sweep routine
m.cache.SetTTL(time.Minute*5, false)
m.cache.Start(time.Second * 10)
} }
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
@ -42,8 +58,9 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
Relation("TargetAccount") Relation("TargetAccount")
} }
func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
mention := gtsmodel.Mention{} return m.cache.Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.newMentionQ(&mention). q := m.newMentionQ(&mention).
Where("? = ?", bun.Ident("mention.id"), id) Where("? = ?", bun.Ident("mention.id"), id)
@ -52,17 +69,8 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
return nil, m.conn.ProcessError(err) return nil, m.conn.ProcessError(err)
} }
copy := mention
m.cache.Set(mention.ID, &copy)
return &mention, nil return &mention, nil
} }, id)
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, ok := m.cache.Get(id); ok {
return mention, nil
}
return m.getMentionDB(ctx, id)
} }
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {

View file

@ -20,8 +20,9 @@
import ( import (
"context" "context"
"time"
"codeberg.org/gruf/go-cache/v2" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
@ -30,31 +31,40 @@
type notificationDB struct { type notificationDB struct {
conn *DBConn conn *DBConn
cache cache.Cache[string, *gtsmodel.Notification] cache *result.Cache[*gtsmodel.Notification]
}
func (n *notificationDB) init() {
// Initialize notification result cache
n.cache = result.NewSized([]result.Lookup{
{Name: "ID"},
}, func(n1 *gtsmodel.Notification) *gtsmodel.Notification {
n2 := new(gtsmodel.Notification)
*n2 = *n1
return n2
}, 1000)
// Set cache TTL and start sweep routine
n.cache.SetTTL(time.Minute*5, false)
n.cache.Start(time.Second * 10)
} }
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, ok := n.cache.Get(id); ok { return n.cache.Load("ID", func() (*gtsmodel.Notification, error) {
return notification, nil var notif gtsmodel.Notification
}
dst := gtsmodel.Notification{ID: id}
q := n.conn.NewSelect(). q := n.conn.NewSelect().
Model(&dst). Model(&notif).
Relation("OriginAccount"). Relation("OriginAccount").
Relation("TargetAccount"). Relation("TargetAccount").
Relation("Status"). Relation("Status").
Where("? = ?", bun.Ident("notification.id"), id) Where("? = ?", bun.Ident("notification.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err) return nil, n.conn.ProcessError(err)
} }
copy := dst return &notif, nil
n.cache.Set(id, &copy) }, id)
return &dst, nil
} }
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {

View file

@ -25,7 +25,7 @@
"errors" "errors"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
@ -34,14 +34,27 @@
type statusDB struct { type statusDB struct {
conn *DBConn conn *DBConn
cache *cache.StatusCache cache *result.Cache[*gtsmodel.Status]
// TODO: keep method definitions in same place but instead have receiver
// all point to one single "db" type, so they can all share methods
// and caches where necessary
accounts *accountDB accounts *accountDB
} }
func (s *statusDB) init() {
// Initialize status result cache
s.cache = result.NewSized([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "URL"},
}, func(s1 *gtsmodel.Status) *gtsmodel.Status {
s2 := new(gtsmodel.Status)
*s2 = *s1
return s2
}, 1000)
// Set cache TTL and start sweep routine
s.cache.SetTTL(time.Minute*5, false)
s.cache.Start(time.Second * 10)
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn. return s.conn.
NewSelect(). NewSelect().
@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
func() (*gtsmodel.Status, bool) { "ID",
return s.cache.GetByID(id)
},
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
}, },
id,
) )
} }
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
func() (*gtsmodel.Status, bool) { "URI",
return s.cache.GetByURI(uri)
},
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
}, },
uri,
) )
} }
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
func() (*gtsmodel.Status, bool) { "URL",
return s.cache.GetByURL(url)
},
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
}, },
url,
) )
} }
func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) {
// Attempt to fetch cached status // Fetch status from database cache with loader callback
status, cached := cacheGet() status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
if !cached {
status = &gtsmodel.Status{}
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(status); err != nil { if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.conn.ProcessError(err)
} }
// If there is boosted, fetch from DB also // If there is boosted, fetch from DB also
if status.BoostOfID != "" { if status.BoostOfID != "" {
boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) status.BoostOf = &gtsmodel.Status{}
if err == nil { err := s.newStatusQ(status.BoostOf).
status.BoostOf = boostOf Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
} }
} }
// Place in the cache return &status, nil
s.cache.Put(status) }, keyParts...)
if err != nil {
// error already processed
return nil, err
} }
// Set the status author account // Set the status author account
@ -137,7 +151,11 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
} }
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { return s.cache.Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses // create links between this status and any emojis it uses
for _, i := range status.EmojiIDs { for _, i := range status.EmojiIDs {
if _, err := tx. if _, err := tx.
@ -146,7 +164,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
StatusID: status.ID, StatusID: status.ID,
EmojiID: i, EmojiID: i,
}).Exec(ctx); err != nil { }).Exec(ctx); err != nil {
err = s.conn.errProc(err) err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -161,7 +179,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
StatusID: status.ID, StatusID: status.ID,
TagID: i, TagID: i,
}).Exec(ctx); err != nil { }).Exec(ctx); err != nil {
err = s.conn.errProc(err) err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -177,7 +195,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
Model(a). Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.errProc(err) err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -185,25 +203,14 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
} }
// Finally, insert the status // Finally, insert the status
if _, err := tx. _, err := tx.NewInsert().Model(status).Exec(ctx)
NewInsert().
Model(status).
Exec(ctx); err != nil {
return err return err
}
return nil
}) })
if err != nil { })
return s.conn.ProcessError(err)
} }
s.cache.Put(status) func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
return nil if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
}
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses // create links between this status and any emojis it uses
for _, i := range status.EmojiIDs { for _, i := range status.EmojiIDs {
if _, err := tx. if _, err := tx.
@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
StatusID: status.ID, StatusID: status.ID,
EmojiID: i, EmojiID: i,
}).Exec(ctx); err != nil { }).Exec(ctx); err != nil {
err = s.conn.errProc(err) err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
StatusID: status.ID, StatusID: status.ID,
TagID: i, TagID: i,
}).Exec(ctx); err != nil { }).Exec(ctx); err != nil {
err = s.conn.errProc(err) err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
} }
} }
// change the status ID of the media attachments to this status // change the status ID of the media attachments to the new status
for _, a := range status.Attachments { for _, a := range status.Attachments {
a.StatusID = status.ID a.StatusID = status.ID
a.UpdatedAt = time.Now() a.UpdatedAt = time.Now()
@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
Model(a). Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
} }
}
// Finally, update the status itself // Finally, insert the status
if _, err := tx. _, err := tx.
NewUpdate(). NewUpdate().
Model(status). Model(status).
Where("? = ?", bun.Ident("status.id"), status.ID). Where("? = ?", bun.Ident("status.id"), status.ID).
Exec(ctx); err != nil { Exec(ctx)
return err
}); err != nil {
return err return err
} }
// Drop any old value from cache by this ID
s.cache.Invalidate("ID", status.ID)
return nil return nil
})
if err != nil {
return nil, s.conn.ProcessError(err)
}
s.cache.Put(status)
return status, nil
} }
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this status and any emojis it uses // delete links between this status and any emojis it uses
if _, err := tx. if _, err := tx.
NewDelete(). NewDelete().
@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
} }
return nil return nil
}) }); err != nil {
if err != nil { return err
return s.conn.ProcessError(err)
} }
s.cache.Invalidate(id) // Drop any old value from cache by this ID
s.cache.Invalidate("ID", id)
return nil return nil
} }
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect { if onlyDirect {
return // Only want the direct parent, no further than first level
parent, err := s.GetStatusByID(ctx, status.InReplyToID)
if err != nil {
return nil, err
}
return []*gtsmodel.Status{parent}, nil
} }
s.statusParent(ctx, parentStatus, foundStatuses, false) var parents []*gtsmodel.Status
for id := status.InReplyToID; id != ""; {
parent, err := s.GetStatusByID(ctx, id)
if err != nil {
return nil, err
}
// Append parent to slice
parents = append(parents, parent)
// Set the next parent ID
id = parent.InReplyToID
}
return parents, nil
} }
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
} }
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
childIDs := []string{} var childIDs []string
q := s.conn. q := s.conn.
NewSelect(). NewSelect().
@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.conn.ProcessError(err)
} }
return faves, nil return faves, nil
} }

View file

@ -35,44 +35,52 @@ type TimelineTestSuite struct {
} }
func (suite *TimelineTestSuite) TestGetPublicTimeline() { func (suite *TimelineTestSuite) TestGetPublicTimeline() {
s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err) suite.NoError(err)
suite.Len(s, 6) suite.Len(s, 6)
} }
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
futureStatus := getFutureStatus() ctx := context.Background()
if err := suite.db.Put(context.Background(), futureStatus); err != nil {
suite.FailNow(err.Error())
}
s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) futureStatus := getFutureStatus()
err := suite.db.PutStatus(ctx, futureStatus)
suite.NoError(err) suite.NoError(err)
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 6) suite.Len(s, 6)
} }
func (suite *TimelineTestSuite) TestGetHomeTimeline() { func (suite *TimelineTestSuite) TestGetHomeTimeline() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"] viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err) suite.NoError(err)
suite.Len(s, 16) suite.Len(s, 16)
} }
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"] viewingAccount := suite.testAccounts["local_account_1"]
futureStatus := getFutureStatus() futureStatus := getFutureStatus()
if err := suite.db.Put(context.Background(), futureStatus); err != nil { err := suite.db.PutStatus(ctx, futureStatus)
suite.FailNow(err.Error()) suite.NoError(err)
}
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err) suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 16) suite.Len(s, 16)
} }

View file

@ -43,7 +43,7 @@ func (t *tombstoneDB) init() {
t2 := new(gtsmodel.Tombstone) t2 := new(gtsmodel.Tombstone)
*t2 = *t1 *t2 = *t1
return t2 return t2
}, 1000) }, 100)
// Set cache TTL and start sweep routine // Set cache TTL and start sweep routine
t.cache.SetTTL(time.Minute*5, false) t.cache.SetTTL(time.Minute*5, false)

View file

@ -22,7 +22,7 @@
"context" "context"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -30,111 +30,121 @@
type userDB struct { type userDB struct {
conn *DBConn conn *DBConn
cache *cache.UserCache cache *result.Cache[*gtsmodel.User]
} }
func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { func (u *userDB) init() {
return u.conn. // Initialize user result cache
NewSelect(). u.cache = result.NewSized([]result.Lookup{
Model(user). {Name: "ID"},
Relation("Account") {Name: "AccountID"},
} {Name: "Email"},
{Name: "ConfirmationToken"},
}, func(u1 *gtsmodel.User) *gtsmodel.User {
u2 := new(gtsmodel.User)
*u2 = *u1
return u2
}, 1000)
func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { // Set cache TTL and start sweep routine
// Attempt to fetch cached user u.cache.SetTTL(time.Minute*5, false)
user, cached := cacheGet() u.cache.Start(time.Second * 10)
if !cached {
user = &gtsmodel.User{}
// Not cached! Perform database query
err := dbQuery(user)
if err != nil {
return nil, u.conn.ProcessError(err)
}
// Place in the cache
u.cache.Put(user)
}
return user, nil
} }
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
return u.getUser( return u.cache.Load("ID", func() (*gtsmodel.User, error) {
ctx, var user gtsmodel.User
func() (*gtsmodel.User, bool) {
return u.cache.GetByID(id) q := u.conn.
}, NewSelect().
func(user *gtsmodel.User) error { Model(&user).
return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) Relation("Account").
}, Where("? = ?", bun.Ident("user.id"), id)
)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
return &user, nil
}, id)
} }
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
return u.getUser( return u.cache.Load("AccountID", func() (*gtsmodel.User, error) {
ctx, var user gtsmodel.User
func() (*gtsmodel.User, bool) {
return u.cache.GetByAccountID(accountID) q := u.conn.
}, NewSelect().
func(user *gtsmodel.User) error { Model(&user).
return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) Relation("Account").
}, Where("? = ?", bun.Ident("user.account_id"), accountID)
)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
return &user, nil
}, accountID)
} }
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
return u.getUser( return u.cache.Load("Email", func() (*gtsmodel.User, error) {
ctx, var user gtsmodel.User
func() (*gtsmodel.User, bool) {
return u.cache.GetByEmail(emailAddress) q := u.conn.
}, NewSelect().
func(user *gtsmodel.User) error { Model(&user).
return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) Relation("Account").
}, Where("? = ?", bun.Ident("user.email"), emailAddress)
)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
return &user, nil
}, emailAddress)
} }
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
return u.getUser( return u.cache.Load("ConfirmationToken", func() (*gtsmodel.User, error) {
ctx, var user gtsmodel.User
func() (*gtsmodel.User, bool) {
return u.cache.GetByConfirmationToken(confirmationToken)
},
func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
},
)
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { q := u.conn.
if _, err := u.conn. NewSelect().
NewInsert(). Model(&user).
Model(user). Relation("Account").
Exec(ctx); err != nil { Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.conn.ProcessError(err)
} }
u.cache.Put(user) return &user, nil
return user, nil }, confirmationToken)
} }
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error {
return u.cache.Store(user, func() error {
_, err := u.conn.
NewInsert().
Model(user).
Exec(ctx)
return u.conn.ProcessError(err)
})
}
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User) db.Error {
// Update the user's last-updated // Update the user's last-updated
user.UpdatedAt = time.Now() user.UpdatedAt = time.Now()
if _, err := u.conn. return u.cache.Store(user, func() error {
_, err := u.conn.
NewUpdate(). NewUpdate().
Model(user). Model(user).
Where("? = ?", bun.Ident("user.id"), user.ID). Where("? = ?", bun.Ident("user.id"), user.ID).
Column(columns...). Exec(ctx)
Exec(ctx); err != nil { return u.conn.ProcessError(err)
return nil, u.conn.ProcessError(err) })
}
u.cache.Invalidate(user.ID)
return user, nil
} }
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
@ -146,6 +156,7 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
return u.conn.ProcessError(err) return u.conn.ProcessError(err)
} }
u.cache.Invalidate(userID) // Invalidate user from cache
u.cache.Invalidate("ID", userID)
return nil return nil
} }

View file

@ -50,21 +50,20 @@ func (suite *UserTestSuite) TestGetUserByAccountID() {
func (suite *UserTestSuite) TestUpdateUserSelectedColumns() { func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {
testUser := suite.testUsers["local_account_1"] testUser := suite.testUsers["local_account_1"]
user := &gtsmodel.User{
ID: testUser.ID,
Email: "whatever",
Locale: "es",
}
user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") updateUser := new(gtsmodel.User)
*updateUser = *testUser
updateUser.Email = "whatever"
updateUser.Locale = "es"
err := suite.db.UpdateUser(context.Background(), updateUser)
suite.NoError(err) suite.NoError(err)
suite.NotNil(user)
dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID) dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)
suite.NoError(err) suite.NoError(err)
suite.NotNil(dbUser) suite.NotNil(dbUser)
suite.Equal("whatever", dbUser.Email) suite.Equal(updateUser.Email, dbUser.Email)
suite.Equal("es", dbUser.Locale) suite.Equal(updateUser.Locale, dbUser.Locale)
suite.Equal(testUser.AccountID, dbUser.AccountID) suite.Equal(testUser.AccountID, dbUser.AccountID)
} }

View file

@ -39,7 +39,7 @@ type Status interface {
PutStatus(ctx context.Context, status *gtsmodel.Status) Error PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// UpdateStatus updates one status in the database and returns it to the caller. // UpdateStatus updates one status in the database and returns it to the caller.
UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, Error) UpdateStatus(ctx context.Context, status *gtsmodel.Status) Error
// DeleteStatusByID deletes one status from the database. // DeleteStatusByID deletes one status from the database.
DeleteStatusByID(ctx context.Context, id string) Error DeleteStatusByID(ctx context.Context, id string) Error

View file

@ -34,9 +34,10 @@ type User interface {
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
// UpdateUser updates one user by its primary key. If columns is set, only given columns // PutUser will attempt to place user in the database
// will be updated. If not set, all columns will be updated. PutUser(ctx context.Context, user *gtsmodel.User) Error
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error) // UpdateUser updates one user by its primary key.
UpdateUser(ctx context.Context, user *gtsmodel.User) Error
// DeleteUserByID deletes one user by its ID. // DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) Error DeleteUserByID(ctx context.Context, userID string) Error
} }

View file

@ -276,7 +276,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
foundAccount.LastWebfingeredAt = fingered foundAccount.LastWebfingeredAt = fingered
foundAccount.UpdatedAt = time.Now() foundAccount.UpdatedAt = time.Now()
foundAccount, err = d.db.PutAccount(ctx, foundAccount) err = d.db.PutAccount(ctx, foundAccount)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err) err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
return return
@ -338,7 +338,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
} }
if accountDomainChanged || sharedInboxChanged || fieldsChanged || fingeredChanged { if accountDomainChanged || sharedInboxChanged || fieldsChanged || fingeredChanged {
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount) err = d.db.UpdateAccount(ctx, foundAccount)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err) return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
} }

View file

@ -107,7 +107,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURLNoSharedInb
targetAccount := suite.testAccounts["local_account_2"] targetAccount := suite.testAccounts["local_account_2"]
targetAccount.SharedInboxURI = nil targetAccount.SharedInboxURI = nil
if _, err := suite.db.UpdateAccount(context.Background(), targetAccount); err != nil { if err := suite.db.UpdateAccount(context.Background(), targetAccount); err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }

View file

@ -45,8 +45,10 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status
if err := d.populateStatusFields(ctx, status, username, includeParent); err != nil { if err := d.populateStatusFields(ctx, status, username, includeParent); err != nil {
return nil, err return nil, err
} }
if err := d.db.UpdateStatus(ctx, status); err != nil {
return d.db.UpdateStatus(ctx, status) return nil, err
}
return status, nil
} }
// GetRemoteStatus completely dereferences a remote status, converts it to a GtS model status, // GetRemoteStatus completely dereferences a remote status, converts it to a GtS model status,

View file

@ -68,7 +68,7 @@ func (suite *InboxTestSuite) TestInboxesForAccountIRIWithSharedInbox() {
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]
sharedInbox := "http://some-inbox-iri/weeeeeeeeeeeee" sharedInbox := "http://some-inbox-iri/weeeeeeeeeeeee"
testAccount.SharedInboxURI = &sharedInbox testAccount.SharedInboxURI = &sharedInbox
if _, err := suite.db.UpdateAccount(ctx, testAccount); err != nil { if err := suite.db.UpdateAccount(ctx, testAccount); err != nil {
suite.FailNow("error updating account") suite.FailNow("error updating account")
} }

View file

@ -273,7 +273,7 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
account.SuspendedAt = time.Now() account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin account.SuspensionOrigin = origin
account, err := p.db.UpdateAccount(ctx, account) err := p.db.UpdateAccount(ctx, account)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }

View file

@ -164,7 +164,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form
account.EnableRSS = form.EnableRSS account.EnableRSS = form.EnableRSS
} }
updatedAccount, err := p.db.UpdateAccount(ctx, account) err := p.db.UpdateAccount(ctx, account)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err))
} }
@ -172,11 +172,11 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form
p.clientWorker.Queue(messages.FromClientAPI{ p.clientWorker.Queue(messages.FromClientAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate, APActivityType: ap.ActivityUpdate,
GTSModel: updatedAccount, GTSModel: account,
OriginAccount: updatedAccount, OriginAccount: account,
}) })
acctSensitive, err := p.tc.AccountToAPIAccountSensitive(ctx, updatedAccount) acctSensitive, err := p.tc.AccountToAPIAccountSensitive(ctx, account)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not convert account into apisensitive account: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not convert account into apisensitive account: %s", err))
} }

View file

@ -129,7 +129,7 @@ func (suite *FromClientAPITestSuite) TestProcessStatusDelete() {
suite.NoError(errWithCode) suite.NoError(errWithCode)
// delete the status from the db first, to mimic what would have already happened earlier up the flow // delete the status from the db first, to mimic what would have already happened earlier up the flow
err := suite.db.DeleteByID(ctx, deletedStatus.ID, &gtsmodel.Status{}) err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID)
suite.NoError(err) suite.NoError(err)
// process the status delete // process the status delete

View file

@ -235,7 +235,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
if updateInstanceAccount { if updateInstanceAccount {
// if either avatar or header is updated, we need // if either avatar or header is updated, we need
// to update the instance account that stores them // to update the instance account that stores them
if _, err := p.db.UpdateAccount(ctx, ia); err != nil { if err := p.db.UpdateAccount(ctx, ia); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err))
} }
} }

View file

@ -28,7 +28,7 @@
"time" "time"
"codeberg.org/gruf/go-byteutil" "codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-cache/v2" "codeberg.org/gruf/go-cache/v3"
"github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
@ -67,8 +67,8 @@ func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, clie
fedDB: federatingDB, fedDB: federatingDB,
clock: clock, clock: clock,
client: client, client: client,
trspCache: cache.New[string, *transport](), trspCache: cache.New[string, *transport](0, 100, 0),
badHosts: cache.New[string, struct{}](), badHosts: cache.New[string, struct{}](0, 1000, 0),
userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, version), userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, version),
} }
@ -110,7 +110,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Tra
} }
// Cache this transport under pubkey // Cache this transport under pubkey
if !c.trspCache.Put(pubStr, transp) { if !c.trspCache.Add(pubStr, transp) {
var cached *transport var cached *transport
cached, ok = c.trspCache.Get(pubStr) cached, ok = c.trspCache.Get(pubStr)

View file

@ -27,11 +27,11 @@
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"codeberg.org/gruf/go-cache/v2" "codeberg.org/gruf/go-cache/v3"
) )
func newETagCache() cache.Cache[string, eTagCacheEntry] { func newETagCache() cache.Cache[string, eTagCacheEntry] {
eTagCache := cache.New[string, eTagCacheEntry]() eTagCache := cache.New[string, eTagCacheEntry](0, 1000, 0)
eTagCache.SetTTL(time.Hour, false) eTagCache.SetTTL(time.Hour, false)
if !eTagCache.Start(time.Minute) { if !eTagCache.Start(time.Minute) {
log.Panic("could not start eTagCache") log.Panic("could not start eTagCache")

View file

@ -123,7 +123,7 @@ func (m *Module) rssFeedGETHandler(c *gin.Context) {
cacheEntry.lastModified = accountLastPostedPublic cacheEntry.lastModified = accountLastPostedPublic
cacheEntry.eTag = eTag cacheEntry.eTag = eTag
m.eTagCache.Put(cacheKey, cacheEntry) m.eTagCache.Set(cacheKey, cacheEntry)
} }
c.Header(eTagHeader, cacheEntry.eTag) c.Header(eTagHeader, cacheEntry.eTag)

View file

@ -22,7 +22,7 @@
"errors" "errors"
"net/http" "net/http"
"codeberg.org/gruf/go-cache/v2" "codeberg.org/gruf/go-cache/v3"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"

View file

@ -1,9 +0,0 @@
MIT License
Copyright (c) 2021 gruf
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,3 +0,0 @@
# go-cache
A TTL cache designed to be used as a base for your own customizations, or used straight out of the box

View file

@ -1,67 +0,0 @@
package cache
import "time"
// Cache represents a TTL cache with customizable callbacks, it
// exists here to abstract away the "unsafe" methods in the case that
// you do not want your own implementation atop TTLCache{}.
type Cache[Key comparable, Value any] interface {
// Start will start the cache background eviction routine with given sweep frequency.
// If already running or a freq <= 0 provided, this is a no-op. This will block until
// the eviction routine has started
Start(freq time.Duration) bool
// Stop will stop cache background eviction routine. If not running this is a no-op. This
// will block until the eviction routine has stopped
Stop() bool
// SetEvictionCallback sets the eviction callback to the provided hook
SetEvictionCallback(hook Hook[Key, Value])
// SetInvalidateCallback sets the invalidate callback to the provided hook
SetInvalidateCallback(hook Hook[Key, Value])
// SetTTL sets the cache item TTL. Update can be specified to force updates of existing items in
// the cache, this will simply add the change in TTL to their current expiry time
SetTTL(ttl time.Duration, update bool)
// Get fetches the value with key from the cache, extending its TTL
Get(key Key) (value Value, ok bool)
// Put attempts to place the value at key in the cache, doing nothing if
// a value with this key already exists. Returned bool is success state
Put(key Key, value Value) bool
// Set places the value at key in the cache. This will overwrite any
// existing value, and call the update callback so. Existing values
// will have their TTL extended upon update
Set(key Key, value Value)
// CAS will attempt to perform a CAS operation on 'key', using provided
// comparison and swap values. Returned bool is success.
CAS(key Key, cmp, swp Value) bool
// Swap will attempt to perform a swap on 'key', replacing the value there
// and returning the existing value. If no value exists for key, this will
// set the value and return the zero value for V.
Swap(key Key, swp Value) Value
// Has checks the cache for a value with key, this will not update TTL
Has(key Key) bool
// Invalidate deletes a value from the cache, calling the invalidate callback
Invalidate(key Key) bool
// Clear empties the cache, calling the invalidate callback
Clear()
// Size returns the current size of the cache
Size() int
}
// New returns a new initialized Cache.
func New[K comparable, V any]() Cache[K, V] {
c := &TTLCache[K, V]{}
c.Init()
return c
}

View file

@ -1,23 +0,0 @@
package cache
import (
"reflect"
)
type Comparable interface {
Equal(any) bool
}
// Compare returns whether 2 values are equal using the Comparable
// interface, or failing that falls back to use reflect.DeepEqual().
func Compare(i1, i2 any) bool {
c1, ok1 := i1.(Comparable)
if ok1 {
return c1.Equal(i2)
}
c2, ok2 := i2.(Comparable)
if ok2 {
return c2.Equal(i1)
}
return reflect.DeepEqual(i1, i2)
}

View file

@ -1,6 +0,0 @@
package cache
// Hook defines a function hook that can be supplied as a callback.
type Hook[Key comparable, Value any] func(key Key, value Value)
func emptyHook[K comparable, V any](K, V) {}

View file

@ -1,210 +0,0 @@
package cache
// LookupCfg is the LookupCache configuration.
type LookupCfg[OGKey, AltKey comparable, Value any] struct {
// RegisterLookups is called on init to register lookups
// within LookupCache's internal LookupMap
RegisterLookups func(*LookupMap[OGKey, AltKey])
// AddLookups is called on each addition to the cache, to
// set any required additional key lookups for supplied item
AddLookups func(*LookupMap[OGKey, AltKey], Value)
// DeleteLookups is called on each eviction/invalidation of
// an item in the cache, to remove any unused key lookups
DeleteLookups func(*LookupMap[OGKey, AltKey], Value)
}
// LookupCache is a cache built on-top of TTLCache, providing multi-key
// lookups for items in the cache by means of additional lookup maps. These
// maps simply store additional keys => original key, with hook-ins to automatically
// call user supplied functions on adding an item, or on updating/deleting an
// item to keep the LookupMap up-to-date.
type LookupCache[OGKey, AltKey comparable, Value any] interface {
Cache[OGKey, Value]
// GetBy fetches a cached value by supplied lookup identifier and key
GetBy(lookup string, key AltKey) (value Value, ok bool)
// CASBy will attempt to perform a CAS operation on supplied lookup identifier and key
CASBy(lookup string, key AltKey, cmp, swp Value) bool
// SwapBy will attempt to perform a swap operation on supplied lookup identifier and key
SwapBy(lookup string, key AltKey, swp Value) Value
// HasBy checks if a value is cached under supplied lookup identifier and key
HasBy(lookup string, key AltKey) bool
// InvalidateBy invalidates a value by supplied lookup identifier and key
InvalidateBy(lookup string, key AltKey) bool
}
type lookupTTLCache[OK, AK comparable, V any] struct {
TTLCache[OK, V]
config LookupCfg[OK, AK, V]
lookup LookupMap[OK, AK]
}
// NewLookup returns a new initialized LookupCache.
func NewLookup[OK, AK comparable, V any](cfg LookupCfg[OK, AK, V]) LookupCache[OK, AK, V] {
switch {
case cfg.RegisterLookups == nil:
panic("cache: nil lookups register function")
case cfg.AddLookups == nil:
panic("cache: nil lookups add function")
case cfg.DeleteLookups == nil:
panic("cache: nil delete lookups function")
}
c := &lookupTTLCache[OK, AK, V]{config: cfg}
c.TTLCache.Init()
c.lookup.lookup = make(map[string]map[AK]OK)
c.config.RegisterLookups(&c.lookup)
c.SetEvictionCallback(nil)
c.SetInvalidateCallback(nil)
return c
}
func (c *lookupTTLCache[OK, AK, V]) SetEvictionCallback(hook Hook[OK, V]) {
if hook == nil {
hook = emptyHook[OK, V]
}
c.TTLCache.SetEvictionCallback(func(key OK, value V) {
hook(key, value)
c.config.DeleteLookups(&c.lookup, value)
})
}
func (c *lookupTTLCache[OK, AK, V]) SetInvalidateCallback(hook Hook[OK, V]) {
if hook == nil {
hook = emptyHook[OK, V]
}
c.TTLCache.SetInvalidateCallback(func(key OK, value V) {
hook(key, value)
c.config.DeleteLookups(&c.lookup, value)
})
}
func (c *lookupTTLCache[OK, AK, V]) GetBy(lookup string, key AK) (V, bool) {
c.Lock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
c.Unlock()
var value V
return value, false
}
v, ok := c.GetUnsafe(origKey)
c.Unlock()
return v, ok
}
func (c *lookupTTLCache[OK, AK, V]) Put(key OK, value V) bool {
c.Lock()
put := c.PutUnsafe(key, value)
if put {
c.config.AddLookups(&c.lookup, value)
}
c.Unlock()
return put
}
func (c *lookupTTLCache[OK, AK, V]) Set(key OK, value V) {
c.Lock()
defer c.Unlock()
c.SetUnsafe(key, value)
c.config.AddLookups(&c.lookup, value)
}
func (c *lookupTTLCache[OK, AK, V]) CASBy(lookup string, key AK, cmp, swp V) bool {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
return false
}
return c.CASUnsafe(origKey, cmp, swp)
}
func (c *lookupTTLCache[OK, AK, V]) SwapBy(lookup string, key AK, swp V) V {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
var value V
return value
}
return c.SwapUnsafe(origKey, swp)
}
func (c *lookupTTLCache[OK, AK, V]) HasBy(lookup string, key AK) bool {
c.Lock()
has := c.lookup.Has(lookup, key)
c.Unlock()
return has
}
func (c *lookupTTLCache[OK, AK, V]) InvalidateBy(lookup string, key AK) bool {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
return false
}
c.InvalidateUnsafe(origKey)
return true
}
// LookupMap is a structure that provides lookups for
// keys to primary keys under supplied lookup identifiers.
// This is essentially a wrapper around map[string](map[K1]K2).
type LookupMap[OK comparable, AK comparable] struct {
lookup map[string](map[AK]OK)
}
// RegisterLookup registers a lookup identifier in the LookupMap,
// note this can only be doing during the cfg.RegisterLookups() hook.
func (l *LookupMap[OK, AK]) RegisterLookup(id string) {
if _, ok := l.lookup[id]; ok {
panic("cache: lookup mapping already exists for identifier")
}
l.lookup[id] = make(map[AK]OK, 100)
}
// Get fetches an entry's primary key for lookup identifier and key.
func (l *LookupMap[OK, AK]) Get(id string, key AK) (OK, bool) {
keys, ok := l.lookup[id]
if !ok {
var key OK
return key, false
}
origKey, ok := keys[key]
return origKey, ok
}
// Set adds a lookup to the LookupMap under supplied lookup identifier,
// linking supplied key to the supplied primary (original) key.
func (l *LookupMap[OK, AK]) Set(id string, key AK, origKey OK) {
keys, ok := l.lookup[id]
if !ok {
panic("cache: invalid lookup identifier")
}
keys[key] = origKey
}
// Has checks if there exists a lookup for supplied identifier and key.
func (l *LookupMap[OK, AK]) Has(id string, key AK) bool {
keys, ok := l.lookup[id]
if !ok {
return false
}
_, ok = keys[key]
return ok
}
// Delete removes a lookup from LookupMap with supplied identifier and key.
func (l *LookupMap[OK, AK]) Delete(id string, key AK) {
keys, ok := l.lookup[id]
if !ok {
return
}
delete(keys, key)
}

View file

@ -1,20 +0,0 @@
package cache
import (
"time"
"codeberg.org/gruf/go-sched"
)
// scheduler is the global cache runtime scheduler
// for handling regular cache evictions.
var scheduler sched.Scheduler
// schedule will given sweep routine to the global scheduler, and start global scheduler.
func schedule(sweep func(time.Time), freq time.Duration) func() {
if !scheduler.Running() {
// ensure running
_ = scheduler.Start()
}
return scheduler.Schedule(sched.NewJob(sweep).Every(freq))
}

View file

@ -1,310 +0,0 @@
package cache
import (
"sync"
"time"
)
// TTLCache is the underlying Cache implementation, providing both the base
// Cache interface and access to "unsafe" methods so that you may build your
// customized caches ontop of this structure.
type TTLCache[Key comparable, Value any] struct {
cache map[Key](*entry[Value])
evict Hook[Key, Value] // the evict hook is called when an item is evicted from the cache, includes manual delete
invalid Hook[Key, Value] // the invalidate hook is called when an item's data in the cache is invalidated
ttl time.Duration // ttl is the item TTL
stop func() // stop is the cancel function for the scheduled eviction routine
mu sync.Mutex // mu protects TTLCache for concurrent access
}
// Init performs Cache initialization. MUST be called.
func (c *TTLCache[K, V]) Init() {
c.cache = make(map[K](*entry[V]), 100)
c.evict = emptyHook[K, V]
c.invalid = emptyHook[K, V]
c.ttl = time.Minute * 5
}
func (c *TTLCache[K, V]) Start(freq time.Duration) (ok bool) {
// Nothing to start
if freq <= 0 {
return false
}
// Safely start
c.mu.Lock()
if ok = c.stop == nil; ok {
// Not yet running, schedule us
c.stop = schedule(c.sweep, freq)
}
// Done with lock
c.mu.Unlock()
return
}
func (c *TTLCache[K, V]) Stop() (ok bool) {
// Safely stop
c.mu.Lock()
if ok = c.stop != nil; ok {
// We're running, cancel evicts
c.stop()
c.stop = nil
}
// Done with lock
c.mu.Unlock()
return
}
// sweep attempts to evict expired items (with callback!) from cache.
func (c *TTLCache[K, V]) sweep(now time.Time) {
// Lock and defer unlock (in case of hook panic)
c.mu.Lock()
defer c.mu.Unlock()
// Sweep the cache for old items!
for key, item := range c.cache {
if now.After(item.expiry) {
c.evict(key, item.value)
delete(c.cache, key)
}
}
}
// Lock locks the cache mutex.
func (c *TTLCache[K, V]) Lock() {
c.mu.Lock()
}
// Unlock unlocks the cache mutex.
func (c *TTLCache[K, V]) Unlock() {
c.mu.Unlock()
}
func (c *TTLCache[K, V]) SetEvictionCallback(hook Hook[K, V]) {
// Ensure non-nil hook
if hook == nil {
hook = emptyHook[K, V]
}
// Safely set evict hook
c.mu.Lock()
c.evict = hook
c.mu.Unlock()
}
func (c *TTLCache[K, V]) SetInvalidateCallback(hook Hook[K, V]) {
// Ensure non-nil hook
if hook == nil {
hook = emptyHook[K, V]
}
// Safely set invalidate hook
c.mu.Lock()
c.invalid = hook
c.mu.Unlock()
}
func (c *TTLCache[K, V]) SetTTL(ttl time.Duration, update bool) {
// Safely update TTL
c.mu.Lock()
diff := ttl - c.ttl
c.ttl = ttl
if update {
// Update existing cache entries
for _, entry := range c.cache {
entry.expiry.Add(diff)
}
}
// We're done
c.mu.Unlock()
}
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
value, ok := c.GetUnsafe(key)
c.mu.Unlock()
return value, ok
}
// GetUnsafe is the mutex-unprotected logic for Cache.Get().
func (c *TTLCache[K, V]) GetUnsafe(key K) (V, bool) {
item, ok := c.cache[key]
if !ok {
var value V
return value, false
}
item.expiry = time.Now().Add(c.ttl)
return item.value, true
}
func (c *TTLCache[K, V]) Put(key K, value V) bool {
c.mu.Lock()
success := c.PutUnsafe(key, value)
c.mu.Unlock()
return success
}
// PutUnsafe is the mutex-unprotected logic for Cache.Put().
func (c *TTLCache[K, V]) PutUnsafe(key K, value V) bool {
// If already cached, return
if _, ok := c.cache[key]; ok {
return false
}
// Create new cached item
c.cache[key] = &entry[V]{
value: value,
expiry: time.Now().Add(c.ttl),
}
return true
}
func (c *TTLCache[K, V]) Set(key K, value V) {
c.mu.Lock()
defer c.mu.Unlock() // defer in case of hook panic
c.SetUnsafe(key, value)
}
// SetUnsafe is the mutex-unprotected logic for Cache.Set(), it calls externally-set functions.
func (c *TTLCache[K, V]) SetUnsafe(key K, value V) {
item, ok := c.cache[key]
if ok {
// call invalidate hook
c.invalid(key, item.value)
} else {
// alloc new item
item = &entry[V]{}
c.cache[key] = item
}
// Update the item + expiry
item.value = value
item.expiry = time.Now().Add(c.ttl)
}
func (c *TTLCache[K, V]) CAS(key K, cmp V, swp V) bool {
c.mu.Lock()
ok := c.CASUnsafe(key, cmp, swp)
c.mu.Unlock()
return ok
}
// CASUnsafe is the mutex-unprotected logic for Cache.CAS().
func (c *TTLCache[K, V]) CASUnsafe(key K, cmp V, swp V) bool {
// Check for item
item, ok := c.cache[key]
if !ok || !Compare(item.value, cmp) {
return false
}
// Invalidate item
c.invalid(key, item.value)
// Update item + expiry
item.value = swp
item.expiry = time.Now().Add(c.ttl)
return ok
}
func (c *TTLCache[K, V]) Swap(key K, swp V) V {
c.mu.Lock()
old := c.SwapUnsafe(key, swp)
c.mu.Unlock()
return old
}
// SwapUnsafe is the mutex-unprotected logic for Cache.Swap().
func (c *TTLCache[K, V]) SwapUnsafe(key K, swp V) V {
// Check for item
item, ok := c.cache[key]
if !ok {
var value V
return value
}
// invalidate old item
c.invalid(key, item.value)
old := item.value
// update item + expiry
item.value = swp
item.expiry = time.Now().Add(c.ttl)
return old
}
func (c *TTLCache[K, V]) Has(key K) bool {
c.mu.Lock()
ok := c.HasUnsafe(key)
c.mu.Unlock()
return ok
}
// HasUnsafe is the mutex-unprotected logic for Cache.Has().
func (c *TTLCache[K, V]) HasUnsafe(key K) bool {
_, ok := c.cache[key]
return ok
}
func (c *TTLCache[K, V]) Invalidate(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.InvalidateUnsafe(key)
}
// InvalidateUnsafe is mutex-unprotected logic for Cache.Invalidate().
func (c *TTLCache[K, V]) InvalidateUnsafe(key K) bool {
// Check if we have item with key
item, ok := c.cache[key]
if !ok {
return false
}
// Call hook, remove from cache
c.invalid(key, item.value)
delete(c.cache, key)
return true
}
func (c *TTLCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.ClearUnsafe()
}
// ClearUnsafe is mutex-unprotected logic for Cache.Clean().
func (c *TTLCache[K, V]) ClearUnsafe() {
for key, item := range c.cache {
c.invalid(key, item.value)
delete(c.cache, key)
}
}
func (c *TTLCache[K, V]) Size() int {
c.mu.Lock()
sz := c.SizeUnsafe()
c.mu.Unlock()
return sz
}
// SizeUnsafe is mutex unprotected logic for Cache.Size().
func (c *TTLCache[K, V]) SizeUnsafe() int {
return len(c.cache)
}
// entry represents an item in the cache, with
// it's currently calculated expiry time.
type entry[Value any] struct {
value Value
expiry time.Time
}

17
vendor/codeberg.org/gruf/go-cache/v3/README.md generated vendored Normal file
View file

@ -0,0 +1,17 @@
# go-cache
Provides access to a simple yet flexible, performant TTL cache via the `Cache{}` interface and `cache.New()`. Under the hood this is returning a `ttl.Cache{}`.
## ttl
A TTL cache implementation with much of the inner workings exposed, designed to be used as a base for your own customizations, or used as-is. Access via the base package `cache.New()` is recommended in the latter case, to prevent accidental use of unsafe methods.
## lookup
`lookup.Cache` is an example of a more complex cache implementation using `ttl.Cache{}` as its underpinning. It provides caching of items under multiple keys.
## result
`result.Cache` is an example of a more complex cache implementation using `ttl.Cache{}` as its underpinning.
It provides caching specifically of loadable struct types, with automatic keying by multiple different field members and caching of negative (error) values. All useful when wrapping, for example, a database.

60
vendor/codeberg.org/gruf/go-cache/v3/cache.go generated vendored Normal file
View file

@ -0,0 +1,60 @@
package cache
import (
"time"
ttlcache "codeberg.org/gruf/go-cache/v3/ttl"
)
// Cache represents a TTL cache with customizable callbacks, it exists here to abstract away the "unsafe" methods in the case that you do not want your own implementation atop ttl.Cache{}.
type Cache[Key comparable, Value any] interface {
// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started.
Start(freq time.Duration) bool
// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped.
Stop() bool
// SetEvictionCallback sets the eviction callback to the provided hook.
SetEvictionCallback(hook func(*ttlcache.Entry[Key, Value]))
// SetInvalidateCallback sets the invalidate callback to the provided hook.
SetInvalidateCallback(hook func(*ttlcache.Entry[Key, Value]))
// SetTTL sets the cache item TTL. Update can be specified to force updates of existing items in the cache, this will simply add the change in TTL to their current expiry time.
SetTTL(ttl time.Duration, update bool)
// Get fetches the value with key from the cache, extending its TTL.
Get(key Key) (value Value, ok bool)
// Add attempts to place the value at key in the cache, doing nothing if a value with this key already exists. Returned bool is success state.
Add(key Key, value Value) bool
// Set places the value at key in the cache. This will overwrite any existing value, and call the update callback so. Existing values will have their TTL extended upon update.
Set(key Key, value Value)
// CAS will attempt to perform a CAS operation on 'key', using provided old and new values, and comparator function. Returned bool is success.
CAS(key Key, old, new Value, cmp func(Value, Value) bool) bool
// Swap will attempt to perform a swap on 'key', replacing the value there and returning the existing value. If no value exists for key, this will set the value and return the zero value for V.
Swap(key Key, swp Value) Value
// Has checks the cache for a value with key, this will not update TTL.
Has(key Key) bool
// Invalidate deletes a value from the cache, calling the invalidate callback.
Invalidate(key Key) bool
// Clear empties the cache, calling the invalidate callback.
Clear()
// Len returns the current length of the cache.
Len() int
// Cap returns the maximum capacity of the cache.
Cap() int
}
// New returns a new initialized Cache with given initial length, maximum capacity and item TTL.
func New[K comparable, V any](len, cap int, ttl time.Duration) Cache[K, V] {
return ttlcache.New[K, V](len, cap, ttl)
}

4
vendor/modules.txt vendored
View file

@ -13,11 +13,9 @@ codeberg.org/gruf/go-bytesize
# codeberg.org/gruf/go-byteutil v1.0.2 # codeberg.org/gruf/go-byteutil v1.0.2
## explicit; go 1.16 ## explicit; go 1.16
codeberg.org/gruf/go-byteutil codeberg.org/gruf/go-byteutil
# codeberg.org/gruf/go-cache/v2 v2.1.4
## explicit; go 1.19
codeberg.org/gruf/go-cache/v2
# codeberg.org/gruf/go-cache/v3 v3.1.8 # codeberg.org/gruf/go-cache/v3 v3.1.8
## explicit; go 1.19 ## explicit; go 1.19
codeberg.org/gruf/go-cache/v3
codeberg.org/gruf/go-cache/v3/result codeberg.org/gruf/go-cache/v3/result
codeberg.org/gruf/go-cache/v3/ttl codeberg.org/gruf/go-cache/v3/ttl
# codeberg.org/gruf/go-debug v1.2.0 # codeberg.org/gruf/go-debug v1.2.0