From ed2477ebea4c3ceec5949821f4950db9669a4a15 Mon Sep 17 00:00:00 2001
From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>
Date: Mon, 31 Jul 2023 11:25:29 +0100
Subject: [PATCH] [performance] cache follow, follow request and block ID lists
(#2027)
---
go.mod | 2 +-
go.sum | 4 +-
internal/api/client/blocks/blocks.go | 2 +
internal/api/client/blocks/blocksget.go | 42 ++--
internal/cache/cache.go | 80 +++++-
internal/cache/gts.go | 150 ++++++++++--
internal/cache/slice.go | 76 ++++++
internal/cache/util.go | 23 +-
internal/config/config.go | 12 +
internal/config/defaults.go | 12 +
internal/config/helpers.gen.go | 229 ++++++++++++++++++
internal/db/account.go | 2 -
internal/db/bundb/account.go | 40 ---
internal/db/bundb/emoji.go | 16 +-
internal/db/bundb/list.go | 9 +-
internal/db/bundb/media.go | 44 +---
internal/db/bundb/relationship.go | 223 ++++++++++++-----
internal/db/bundb/relationship_block.go | 37 ++-
internal/db/bundb/relationship_follow.go | 22 +-
internal/db/bundb/relationship_follow_req.go | 25 +-
internal/db/bundb/status.go | 5 +-
internal/db/relationship.go | 4 +
internal/paging/paging.go | 227 +++++++++++++++++
internal/paging/paging_test.go | 171 +++++++++++++
internal/processing/account/delete.go | 39 ++-
internal/processing/blocks.go | 100 ++++----
test/envparsing.sh | 9 +
.../codeberg.org/gruf/go-cache/v3/ttl/ttl.go | 11 +-
vendor/modules.txt | 2 +-
29 files changed, 1283 insertions(+), 335 deletions(-)
create mode 100644 internal/cache/slice.go
create mode 100644 internal/paging/paging.go
create mode 100644 internal/paging/paging_test.go
diff --git a/go.mod b/go.mod
index 1a4d14b97..98abc64ee 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.20
require (
codeberg.org/gruf/go-bytesize v1.0.2
codeberg.org/gruf/go-byteutil v1.1.2
- codeberg.org/gruf/go-cache/v3 v3.4.3
+ codeberg.org/gruf/go-cache/v3 v3.4.4
codeberg.org/gruf/go-debug v1.3.0
codeberg.org/gruf/go-errors/v2 v2.2.0
codeberg.org/gruf/go-fastcopy v1.1.2
diff --git a/go.sum b/go.sum
index e700364a5..19964f9f1 100644
--- a/go.sum
+++ b/go.sum
@@ -48,8 +48,8 @@ codeberg.org/gruf/go-bytesize v1.0.2/go.mod h1:n/GU8HzL9f3UNp/mUKyr1qVmTlj7+xacp
codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-byteutil v1.1.2 h1:TQLZtTxTNca9xEfDIndmo7nBYxeS94nrv/9DS3Nk5Tw=
codeberg.org/gruf/go-byteutil v1.1.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
-codeberg.org/gruf/go-cache/v3 v3.4.3 h1:GTNq01M17jUJ3B3ehrVTbElpvCqOKgz1x+VB9GEIxXA=
-codeberg.org/gruf/go-cache/v3 v3.4.3/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc=
+codeberg.org/gruf/go-cache/v3 v3.4.4 h1:V0A3EzjhzhULOydD16pwa2DRDwF67OuuP4ORnm//7p8=
+codeberg.org/gruf/go-cache/v3 v3.4.4/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc=
codeberg.org/gruf/go-debug v1.3.0 h1:PIRxQiWUFKtGOGZFdZ3Y0pqyfI0Xr87j224IYe2snZs=
codeberg.org/gruf/go-debug v1.3.0/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg=
codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4=
diff --git a/internal/api/client/blocks/blocks.go b/internal/api/client/blocks/blocks.go
index bff9a068e..0eeee2bf1 100644
--- a/internal/api/client/blocks/blocks.go
+++ b/internal/api/client/blocks/blocks.go
@@ -30,8 +30,10 @@
// MaxIDKey is the url query for setting a max ID to return
MaxIDKey = "max_id"
+
// SinceIDKey is the url query for returning results newer than the given ID
SinceIDKey = "since_id"
+
// LimitKey is for specifying maximum number of results to return.
LimitKey = "limit"
)
diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go
index 7aec8b334..505c33db8 100644
--- a/internal/api/client/blocks/blocksget.go
+++ b/internal/api/client/blocks/blocksget.go
@@ -18,14 +18,13 @@
package blocks
import (
- "fmt"
"net/http"
- "strconv"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// BlocksGETHandler swagger:operation GET /api/v1/blocks blocksGet
@@ -104,31 +103,21 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
return
}
- maxID := ""
- maxIDString := c.Query(MaxIDKey)
- if maxIDString != "" {
- maxID = maxIDString
+ limit, errWithCode := apiutil.ParseLimit(c.Query(LimitKey), 20, 100, 2)
+ if err != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
}
- sinceID := ""
- sinceIDString := c.Query(SinceIDKey)
- if sinceIDString != "" {
- sinceID = sinceIDString
- }
-
- limit := 20
- limitString := c.Query(LimitKey)
- if limitString != "" {
- i, err := strconv.ParseInt(limitString, 10, 32)
- if err != nil {
- err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
- apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
- return
- }
- limit = int(i)
- }
-
- resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit)
+ resp, errWithCode := m.processor.BlocksGet(
+ c.Request.Context(),
+ authed.Account,
+ paging.Pager{
+ SinceID: c.Query(SinceIDKey),
+ MaxID: c.Query(MaxIDKey),
+ Limit: limit,
+ },
+ )
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
@@ -137,5 +126,6 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
- c.JSON(http.StatusOK, resp.Accounts)
+
+ c.JSON(http.StatusOK, resp.Items)
}
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 63564935e..e97dce6f9 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -80,6 +80,27 @@ func (c *Caches) setuphooks() {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
+
+ // Invalidate this account's
+ // following / follower lists.
+ // (see FollowIDs() comment for details).
+ c.GTS.FollowIDs().InvalidateAll(
+ ">"+account.ID,
+ "l>"+account.ID,
+ "<"+account.ID,
+ "l<"+account.ID,
+ )
+
+ // Invalidate this account's
+ // follow requesting / request lists.
+ // (see FollowRequestIDs() comment for details).
+ c.GTS.FollowRequestIDs().InvalidateAll(
+ ">"+account.ID,
+ "<"+account.ID,
+ )
+
+ // Invalidate this account's block lists.
+ c.GTS.BlockIDs().Invalidate(account.ID)
})
c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) {
@@ -90,6 +111,9 @@ func (c *Caches) setuphooks() {
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
+
+ // Invalidate source account's block lists.
+ c.GTS.BlockIDs().Invalidate(block.AccountID)
})
c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) {
@@ -98,6 +122,9 @@ func (c *Caches) setuphooks() {
})
c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) {
+ // Invalidate follow request with this same ID.
+ c.GTS.FollowRequest().Invalidate("ID", follow.ID)
+
// Invalidate any related list entries.
c.GTS.ListEntry().Invalidate("FollowID", follow.ID)
@@ -108,19 +135,35 @@ func (c *Caches) setuphooks() {
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+
+ // Invalidate source account's following
+ // lists, and destination's follwer lists.
+ // (see FollowIDs() comment for details).
+ c.GTS.FollowIDs().InvalidateAll(
+ ">"+follow.AccountID,
+ "l>"+follow.AccountID,
+ "<"+follow.AccountID,
+ "l<"+follow.AccountID,
+ "<"+follow.TargetAccountID,
+ "l<"+follow.TargetAccountID,
+ ">"+follow.TargetAccountID,
+ "l>"+follow.TargetAccountID,
+ )
})
c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) {
- // Invalidate follow request origin account ID cached visibility.
- c.Visibility.Invalidate("ItemID", followReq.AccountID)
- c.Visibility.Invalidate("RequesterID", followReq.AccountID)
-
- // Invalidate follow request target account ID cached visibility.
- c.Visibility.Invalidate("ItemID", followReq.TargetAccountID)
- c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID)
-
- // Invalidate any cached follow with same ID.
+ // Invalidate follow with this same ID.
c.GTS.Follow().Invalidate("ID", followReq.ID)
+
+ // Invalidate source account's followreq
+ // lists, and destinations follow req lists.
+ // (see FollowRequestIDs() comment for details).
+ c.GTS.FollowRequestIDs().InvalidateAll(
+ ">"+followReq.AccountID,
+ "<"+followReq.AccountID,
+ ">"+followReq.TargetAccountID,
+ "<"+followReq.TargetAccountID,
+ )
})
c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) {
@@ -128,12 +171,29 @@ func (c *Caches) setuphooks() {
c.GTS.ListEntry().Invalidate("ListID", list.ID)
})
+ c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) {
+ if *media.Avatar || *media.Header {
+ // Invalidate cache of attaching account.
+ c.GTS.Account().Invalidate("ID", media.AccountID)
+ }
+
+ if media.StatusID != "" {
+ // Invalidate cache of attaching status.
+ c.GTS.Status().Invalidate("ID", media.StatusID)
+ }
+ })
+
c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
for _, id := range status.AttachmentIDs {
- // Invalidate cache for attached media IDs,
+ // Invalidate each media by the IDs we're aware of.
+ // This must be done as the status table is aware of
+ // the media IDs in use before the media table is
+ // aware of the status ID they are linked to.
+ //
+ // c.GTS.Media().Invalidate("StatusID") will not work.
c.GTS.Media().Invalidate("ID", id)
}
})
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
index dd43154ef..fefd02fff 100644
--- a/internal/cache/gts.go
+++ b/internal/cache/gts.go
@@ -26,29 +26,31 @@
)
type GTSCaches struct {
- account *result.Cache[*gtsmodel.Account]
- accountNote *result.Cache[*gtsmodel.AccountNote]
- block *result.Cache[*gtsmodel.Block]
- // TODO: maybe should be moved out of here since it's
- // not actually doing anything with gtsmodel.DomainBlock.
- domainBlock *domain.BlockCache
- emoji *result.Cache[*gtsmodel.Emoji]
- emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
- follow *result.Cache[*gtsmodel.Follow]
- followRequest *result.Cache[*gtsmodel.FollowRequest]
- instance *result.Cache[*gtsmodel.Instance]
- list *result.Cache[*gtsmodel.List]
- listEntry *result.Cache[*gtsmodel.ListEntry]
- marker *result.Cache[*gtsmodel.Marker]
- media *result.Cache[*gtsmodel.MediaAttachment]
- mention *result.Cache[*gtsmodel.Mention]
- notification *result.Cache[*gtsmodel.Notification]
- report *result.Cache[*gtsmodel.Report]
- status *result.Cache[*gtsmodel.Status]
- statusFave *result.Cache[*gtsmodel.StatusFave]
- tombstone *result.Cache[*gtsmodel.Tombstone]
- user *result.Cache[*gtsmodel.User]
- // TODO: move out of GTS caches since not using database models.
+ account *result.Cache[*gtsmodel.Account]
+ accountNote *result.Cache[*gtsmodel.AccountNote]
+ block *result.Cache[*gtsmodel.Block]
+ blockIDs *SliceCache[string]
+ domainBlock *domain.BlockCache
+ emoji *result.Cache[*gtsmodel.Emoji]
+ emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
+ follow *result.Cache[*gtsmodel.Follow]
+ followIDs *SliceCache[string]
+ followRequest *result.Cache[*gtsmodel.FollowRequest]
+ followRequestIDs *SliceCache[string]
+ instance *result.Cache[*gtsmodel.Instance]
+ list *result.Cache[*gtsmodel.List]
+ listEntry *result.Cache[*gtsmodel.ListEntry]
+ marker *result.Cache[*gtsmodel.Marker]
+ media *result.Cache[*gtsmodel.MediaAttachment]
+ mention *result.Cache[*gtsmodel.Mention]
+ notification *result.Cache[*gtsmodel.Notification]
+ report *result.Cache[*gtsmodel.Report]
+ status *result.Cache[*gtsmodel.Status]
+ statusFave *result.Cache[*gtsmodel.StatusFave]
+ tombstone *result.Cache[*gtsmodel.Tombstone]
+ user *result.Cache[*gtsmodel.User]
+
+ // TODO: move out of GTS caches since unrelated to DB.
webfinger *ttl.Cache[string, string]
}
@@ -58,11 +60,14 @@ func (c *GTSCaches) Init() {
c.initAccount()
c.initAccountNote()
c.initBlock()
+ c.initBlockIDs()
c.initDomainBlock()
c.initEmoji()
c.initEmojiCategory()
c.initFollow()
+ c.initFollowIDs()
c.initFollowRequest()
+ c.initFollowRequestIDs()
c.initInstance()
c.initList()
c.initListEntry()
@@ -83,10 +88,28 @@ func (c *GTSCaches) Start() {
tryStart(c.account, config.GetCacheGTSAccountSweepFreq())
tryStart(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq())
tryStart(c.block, config.GetCacheGTSBlockSweepFreq())
+ tryUntil("starting block IDs cache", 5, func() bool {
+ if sweep := config.GetCacheGTSBlockIDsSweepFreq(); sweep > 0 {
+ return c.blockIDs.Start(sweep)
+ }
+ return true
+ })
tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
+ tryUntil("starting follow IDs cache", 5, func() bool {
+ if sweep := config.GetCacheGTSFollowIDsSweepFreq(); sweep > 0 {
+ return c.followIDs.Start(sweep)
+ }
+ return true
+ })
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
+ tryUntil("starting follow request IDs cache", 5, func() bool {
+ if sweep := config.GetCacheGTSFollowRequestIDsSweepFreq(); sweep > 0 {
+ return c.followRequestIDs.Start(sweep)
+ }
+ return true
+ })
tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq())
tryStart(c.list, config.GetCacheGTSListSweepFreq())
tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
@@ -112,10 +135,28 @@ func (c *GTSCaches) Stop() {
tryStop(c.account, config.GetCacheGTSAccountSweepFreq())
tryStop(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq())
tryStop(c.block, config.GetCacheGTSBlockSweepFreq())
+ tryUntil("stopping block IDs cache", 5, func() bool {
+ if config.GetCacheGTSBlockIDsSweepFreq() > 0 {
+ return c.blockIDs.Stop()
+ }
+ return true
+ })
tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
+ tryUntil("stopping follow IDs cache", 5, func() bool {
+ if config.GetCacheGTSFollowIDsSweepFreq() > 0 {
+ return c.followIDs.Stop()
+ }
+ return true
+ })
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
+ tryUntil("stopping follow request IDs cache", 5, func() bool {
+ if config.GetCacheGTSFollowRequestIDsSweepFreq() > 0 {
+ return c.followRequestIDs.Stop()
+ }
+ return true
+ })
tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq())
tryStop(c.list, config.GetCacheGTSListSweepFreq())
tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
@@ -128,7 +169,12 @@ func (c *GTSCaches) Stop() {
tryStop(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())
tryStop(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())
tryStop(c.user, config.GetCacheGTSUserSweepFreq())
- tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
+ tryUntil("stopping *gtsmodel.Webfinger cache", 5, func() bool {
+ if config.GetCacheGTSWebfingerSweepFreq() > 0 {
+ return c.webfinger.Stop()
+ }
+ return true
+ })
}
// Account provides access to the gtsmodel Account database cache.
@@ -146,6 +192,11 @@ func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block
}
+// FollowIDs provides access to the block IDs database cache.
+func (c *GTSCaches) BlockIDs() *SliceCache[string] {
+ return c.blockIDs
+}
+
// DomainBlock provides access to the domain block database cache.
func (c *GTSCaches) DomainBlock() *domain.BlockCache {
return c.domainBlock
@@ -166,11 +217,29 @@ func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] {
return c.follow
}
+// FollowIDs provides access to the follower / following IDs database cache.
+// THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS:
+// - '>' for following IDs
+// - 'l>' for local following IDs
+// - '<' for follower IDs
+// - 'l<' for local follower IDs
+func (c *GTSCaches) FollowIDs() *SliceCache[string] {
+ return c.followIDs
+}
+
// FollowRequest provides access to the gtsmodel FollowRequest database cache.
func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
return c.followRequest
}
+// FollowRequestIDs provides access to the follow requester / requesting IDs database
+// cache. THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS:
+// - '>' for following IDs
+// - '<' for follower IDs
+func (c *GTSCaches) FollowRequestIDs() *SliceCache[string] {
+ return c.followRequestIDs
+}
+
// Instance provides access to the gtsmodel Instance database cache.
func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] {
return c.instance
@@ -274,6 +343,8 @@ func (c *GTSCaches) initBlock() {
{Name: "ID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
+ {Name: "AccountID", Multi: true},
+ {Name: "TargetAccountID", Multi: true},
}, func(b1 *gtsmodel.Block) *gtsmodel.Block {
b2 := new(gtsmodel.Block)
*b2 = *b1
@@ -283,6 +354,14 @@ func (c *GTSCaches) initBlock() {
c.block.IgnoreErrors(ignoreErrors)
}
+func (c *GTSCaches) initBlockIDs() {
+ c.blockIDs = &SliceCache[string]{Cache: ttl.New[string, []string](
+ 0,
+ config.GetCacheGTSBlockIDsMaxSize(),
+ config.GetCacheGTSBlockIDsTTL(),
+ )}
+}
+
func (c *GTSCaches) initDomainBlock() {
c.domainBlock = new(domain.BlockCache)
}
@@ -321,6 +400,8 @@ func (c *GTSCaches) initFollow() {
{Name: "ID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
+ {Name: "AccountID", Multi: true},
+ {Name: "TargetAccountID", Multi: true},
}, func(f1 *gtsmodel.Follow) *gtsmodel.Follow {
f2 := new(gtsmodel.Follow)
*f2 = *f1
@@ -329,11 +410,21 @@ func (c *GTSCaches) initFollow() {
c.follow.SetTTL(config.GetCacheGTSFollowTTL(), true)
}
+func (c *GTSCaches) initFollowIDs() {
+ c.followIDs = &SliceCache[string]{Cache: ttl.New[string, []string](
+ 0,
+ config.GetCacheGTSFollowIDsMaxSize(),
+ config.GetCacheGTSFollowIDsTTL(),
+ )}
+}
+
func (c *GTSCaches) initFollowRequest() {
c.followRequest = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
+ {Name: "AccountID", Multi: true},
+ {Name: "TargetAccountID", Multi: true},
}, func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest {
f2 := new(gtsmodel.FollowRequest)
*f2 = *f1
@@ -342,6 +433,14 @@ func (c *GTSCaches) initFollowRequest() {
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
}
+func (c *GTSCaches) initFollowRequestIDs() {
+ c.followRequestIDs = &SliceCache[string]{Cache: ttl.New[string, []string](
+ 0,
+ config.GetCacheGTSFollowRequestIDsMaxSize(),
+ config.GetCacheGTSFollowRequestIDsTTL(),
+ )}
+}
+
func (c *GTSCaches) initInstance() {
c.instance = result.New([]result.Lookup{
{Name: "ID"},
@@ -502,5 +601,6 @@ func (c *GTSCaches) initWebfinger() {
c.webfinger = ttl.New[string, string](
0,
config.GetCacheGTSWebfingerMaxSize(),
- config.GetCacheGTSWebfingerTTL())
+ config.GetCacheGTSWebfingerTTL(),
+ )
}
diff --git a/internal/cache/slice.go b/internal/cache/slice.go
new file mode 100644
index 000000000..194f20d4b
--- /dev/null
+++ b/internal/cache/slice.go
@@ -0,0 +1,76 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "codeberg.org/gruf/go-cache/v3/ttl"
+ "golang.org/x/exp/slices"
+)
+
+// SliceCache wraps a ttl.Cache to provide simple loader-callback
+// functions for fetching + caching slices of objects (e.g. IDs).
+type SliceCache[T any] struct {
+ *ttl.Cache[string, []T]
+}
+
+// Load will attempt to load an existing slice from the cache for the given key, else calling the provided load function and caching the result.
+func (c *SliceCache[T]) Load(key string, load func() ([]T, error)) ([]T, error) {
+ // Look for follow IDs list in cache under this key.
+ data, ok := c.Get(key)
+
+ if !ok {
+ var err error
+
+ // Not cached, load!
+ data, err = load()
+ if err != nil {
+ return nil, err
+ }
+
+ // Store the data.
+ c.Set(key, data)
+ }
+
+ // Return data clone for safety.
+ return slices.Clone(data), nil
+}
+
+// LoadRange is functionally the same as .Load(), but will pass the result through provided reslice function before returning a cloned result.
+func (c *SliceCache[T]) LoadRange(key string, load func() ([]T, error), reslice func([]T) []T) ([]T, error) {
+ // Look for follow IDs list in cache under this key.
+ data, ok := c.Get(key)
+
+ if !ok {
+ var err error
+
+ // Not cached, load!
+ data, err = load()
+ if err != nil {
+ return nil, err
+ }
+
+ // Store the data.
+ c.Set(key, data)
+ }
+
+ // Reslice to range.
+ slice := reslice(data)
+
+ // Return range clone for safety.
+ return slices.Clone(slice), nil
+}
diff --git a/internal/cache/util.go b/internal/cache/util.go
index a0adfd366..f2357c904 100644
--- a/internal/cache/util.go
+++ b/internal/cache/util.go
@@ -18,28 +18,33 @@
package cache
import (
- "context"
+ "database/sql"
"errors"
"fmt"
"time"
"codeberg.org/gruf/go-cache/v3/result"
errorsv2 "codeberg.org/gruf/go-errors/v2"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
-// SentinelError is returned to indicate a non-permanent error return,
-// i.e. a situation in which we do not want a cache a negative result.
+// SentinelError is an error that can be returned and checked against to indicate a non-permanent
+// error return from a cache loader callback, e.g. a temporary situation that will soon be fixed.
var SentinelError = errors.New("BUG: error should not be returned") //nolint:revive
-// ignoreErrors is an error ignoring function capable of being passed to
-// caches, which specifically catches and ignores our sentinel error type.
+// ignoreErrors is an error matching function used to signal which errors
+// the result caches should NOT hold onto. these amount to anything non-permanent.
func ignoreErrors(err error) bool {
- return errorsv2.Comparable(
+ return !errorsv2.Comparable(
err,
- SentinelError,
- context.DeadlineExceeded,
- context.Canceled,
+
+ // the only cacheable errs,
+ // i.e anything permanent
+ // (until invalidation).
+ db.ErrNoEntries,
+ db.ErrAlreadyExists,
+ sql.ErrNoRows,
)
}
diff --git a/internal/config/config.go b/internal/config/config.go
index bd9fc468c..99b07358e 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -194,6 +194,10 @@ type GTSCacheConfiguration struct {
BlockTTL time.Duration `name:"block-ttl"`
BlockSweepFreq time.Duration `name:"block-sweep-freq"`
+ BlockIDsMaxSize int `name:"block-ids-max-size"`
+ BlockIDsTTL time.Duration `name:"block-ids-ttl"`
+ BlockIDsSweepFreq time.Duration `name:"block-ids-sweep-freq"`
+
DomainBlockMaxSize int `name:"domain-block-max-size"`
DomainBlockTTL time.Duration `name:"domain-block-ttl"`
DomainBlockSweepFreq time.Duration `name:"domain-block-sweep-freq"`
@@ -210,10 +214,18 @@ type GTSCacheConfiguration struct {
FollowTTL time.Duration `name:"follow-ttl"`
FollowSweepFreq time.Duration `name:"follow-sweep-freq"`
+ FollowIDsMaxSize int `name:"follow-ids-max-size"`
+ FollowIDsTTL time.Duration `name:"follow-ids-ttl"`
+ FollowIDsSweepFreq time.Duration `name:"follow-ids-sweep-freq"`
+
FollowRequestMaxSize int `name:"follow-request-max-size"`
FollowRequestTTL time.Duration `name:"follow-request-ttl"`
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
+ FollowRequestIDsMaxSize int `name:"follow-request-ids-max-size"`
+ FollowRequestIDsTTL time.Duration `name:"follow-request-ids-ttl"`
+ FollowRequestIDsSweepFreq time.Duration `name:"follow-request-ids-sweep-freq"`
+
InstanceMaxSize int `name:"instance-max-size"`
InstanceTTL time.Duration `name:"instance-ttl"`
InstanceSweepFreq time.Duration `name:"instance-sweep-freq"`
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index ee20fb6a7..cb37838c1 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -139,6 +139,10 @@
BlockTTL: time.Minute * 30,
BlockSweepFreq: time.Minute,
+ BlockIDsMaxSize: 500,
+ BlockIDsTTL: time.Minute * 30,
+ BlockIDsSweepFreq: time.Minute,
+
DomainBlockMaxSize: 2000,
DomainBlockTTL: time.Hour * 24,
DomainBlockSweepFreq: time.Minute,
@@ -155,10 +159,18 @@
FollowTTL: time.Minute * 30,
FollowSweepFreq: time.Minute,
+ FollowIDsMaxSize: 500,
+ FollowIDsTTL: time.Minute * 30,
+ FollowIDsSweepFreq: time.Minute,
+
FollowRequestMaxSize: 2000,
FollowRequestTTL: time.Minute * 30,
FollowRequestSweepFreq: time.Minute,
+ FollowRequestIDsMaxSize: 500,
+ FollowRequestIDsTTL: time.Minute * 30,
+ FollowRequestIDsSweepFreq: time.Minute,
+
InstanceMaxSize: 2000,
InstanceTTL: time.Minute * 30,
InstanceSweepFreq: time.Minute,
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index 5eed1b468..1bf8ec2bc 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -2624,6 +2624,81 @@ func GetCacheGTSBlockSweepFreq() time.Duration { return global.GetCacheGTSBlockS
// SetCacheGTSBlockSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockSweepFreq' field
func SetCacheGTSBlockSweepFreq(v time.Duration) { global.SetCacheGTSBlockSweepFreq(v) }
+// GetCacheGTSBlockIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field
+func (st *ConfigState) GetCacheGTSBlockIDsMaxSize() (v int) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.BlockIDsMaxSize
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSBlockIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field
+func (st *ConfigState) SetCacheGTSBlockIDsMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.BlockIDsMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSBlockIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.BlockIDsMaxSize' field
+func CacheGTSBlockIDsMaxSizeFlag() string { return "cache-gts-block-ids-max-size" }
+
+// GetCacheGTSBlockIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field
+func GetCacheGTSBlockIDsMaxSize() int { return global.GetCacheGTSBlockIDsMaxSize() }
+
+// SetCacheGTSBlockIDsMaxSize safely sets the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field
+func SetCacheGTSBlockIDsMaxSize(v int) { global.SetCacheGTSBlockIDsMaxSize(v) }
+
+// GetCacheGTSBlockIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field
+func (st *ConfigState) GetCacheGTSBlockIDsTTL() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.BlockIDsTTL
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSBlockIDsTTL safely sets the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field
+func (st *ConfigState) SetCacheGTSBlockIDsTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.BlockIDsTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSBlockIDsTTLFlag returns the flag name for the 'Cache.GTS.BlockIDsTTL' field
+func CacheGTSBlockIDsTTLFlag() string { return "cache-gts-block-ids-ttl" }
+
+// GetCacheGTSBlockIDsTTL safely fetches the value for global configuration 'Cache.GTS.BlockIDsTTL' field
+func GetCacheGTSBlockIDsTTL() time.Duration { return global.GetCacheGTSBlockIDsTTL() }
+
+// SetCacheGTSBlockIDsTTL safely sets the value for global configuration 'Cache.GTS.BlockIDsTTL' field
+func SetCacheGTSBlockIDsTTL(v time.Duration) { global.SetCacheGTSBlockIDsTTL(v) }
+
+// GetCacheGTSBlockIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field
+func (st *ConfigState) GetCacheGTSBlockIDsSweepFreq() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.BlockIDsSweepFreq
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSBlockIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field
+func (st *ConfigState) SetCacheGTSBlockIDsSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.BlockIDsSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSBlockIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.BlockIDsSweepFreq' field
+func CacheGTSBlockIDsSweepFreqFlag() string { return "cache-gts-block-ids-sweep-freq" }
+
+// GetCacheGTSBlockIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field
+func GetCacheGTSBlockIDsSweepFreq() time.Duration { return global.GetCacheGTSBlockIDsSweepFreq() }
+
+// SetCacheGTSBlockIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field
+func SetCacheGTSBlockIDsSweepFreq(v time.Duration) { global.SetCacheGTSBlockIDsSweepFreq(v) }
+
// GetCacheGTSDomainBlockMaxSize safely fetches the Configuration value for state's 'Cache.GTS.DomainBlockMaxSize' field
func (st *ConfigState) GetCacheGTSDomainBlockMaxSize() (v int) {
st.mutex.RLock()
@@ -2926,6 +3001,81 @@ func GetCacheGTSFollowSweepFreq() time.Duration { return global.GetCacheGTSFollo
// SetCacheGTSFollowSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowSweepFreq' field
func SetCacheGTSFollowSweepFreq(v time.Duration) { global.SetCacheGTSFollowSweepFreq(v) }
+// GetCacheGTSFollowIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field
+func (st *ConfigState) GetCacheGTSFollowIDsMaxSize() (v int) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowIDsMaxSize
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field
+func (st *ConfigState) SetCacheGTSFollowIDsMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowIDsMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowIDsMaxSize' field
+func CacheGTSFollowIDsMaxSizeFlag() string { return "cache-gts-follow-ids-max-size" }
+
+// GetCacheGTSFollowIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field
+func GetCacheGTSFollowIDsMaxSize() int { return global.GetCacheGTSFollowIDsMaxSize() }
+
+// SetCacheGTSFollowIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field
+func SetCacheGTSFollowIDsMaxSize(v int) { global.SetCacheGTSFollowIDsMaxSize(v) }
+
+// GetCacheGTSFollowIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field
+func (st *ConfigState) GetCacheGTSFollowIDsTTL() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowIDsTTL
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field
+func (st *ConfigState) SetCacheGTSFollowIDsTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowIDsTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowIDsTTL' field
+func CacheGTSFollowIDsTTLFlag() string { return "cache-gts-follow-ids-ttl" }
+
+// GetCacheGTSFollowIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowIDsTTL' field
+func GetCacheGTSFollowIDsTTL() time.Duration { return global.GetCacheGTSFollowIDsTTL() }
+
+// SetCacheGTSFollowIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowIDsTTL' field
+func SetCacheGTSFollowIDsTTL(v time.Duration) { global.SetCacheGTSFollowIDsTTL(v) }
+
+// GetCacheGTSFollowIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field
+func (st *ConfigState) GetCacheGTSFollowIDsSweepFreq() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowIDsSweepFreq
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field
+func (st *ConfigState) SetCacheGTSFollowIDsSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowIDsSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowIDsSweepFreq' field
+func CacheGTSFollowIDsSweepFreqFlag() string { return "cache-gts-follow-ids-sweep-freq" }
+
+// GetCacheGTSFollowIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field
+func GetCacheGTSFollowIDsSweepFreq() time.Duration { return global.GetCacheGTSFollowIDsSweepFreq() }
+
+// SetCacheGTSFollowIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field
+func SetCacheGTSFollowIDsSweepFreq(v time.Duration) { global.SetCacheGTSFollowIDsSweepFreq(v) }
+
// GetCacheGTSFollowRequestMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field
func (st *ConfigState) GetCacheGTSFollowRequestMaxSize() (v int) {
st.mutex.RLock()
@@ -3003,6 +3153,85 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
+// GetCacheGTSFollowRequestIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field
+func (st *ConfigState) GetCacheGTSFollowRequestIDsMaxSize() (v int) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowRequestIDsMaxSize
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowRequestIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field
+func (st *ConfigState) SetCacheGTSFollowRequestIDsMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestIDsMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsMaxSize' field
+func CacheGTSFollowRequestIDsMaxSizeFlag() string { return "cache-gts-follow-request-ids-max-size" }
+
+// GetCacheGTSFollowRequestIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field
+func GetCacheGTSFollowRequestIDsMaxSize() int { return global.GetCacheGTSFollowRequestIDsMaxSize() }
+
+// SetCacheGTSFollowRequestIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field
+func SetCacheGTSFollowRequestIDsMaxSize(v int) { global.SetCacheGTSFollowRequestIDsMaxSize(v) }
+
+// GetCacheGTSFollowRequestIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field
+func (st *ConfigState) GetCacheGTSFollowRequestIDsTTL() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowRequestIDsTTL
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowRequestIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field
+func (st *ConfigState) SetCacheGTSFollowRequestIDsTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestIDsTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsTTL' field
+func CacheGTSFollowRequestIDsTTLFlag() string { return "cache-gts-follow-request-ids-ttl" }
+
+// GetCacheGTSFollowRequestIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field
+func GetCacheGTSFollowRequestIDsTTL() time.Duration { return global.GetCacheGTSFollowRequestIDsTTL() }
+
+// SetCacheGTSFollowRequestIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field
+func SetCacheGTSFollowRequestIDsTTL(v time.Duration) { global.SetCacheGTSFollowRequestIDsTTL(v) }
+
+// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field
+func (st *ConfigState) GetCacheGTSFollowRequestIDsSweepFreq() (v time.Duration) {
+ st.mutex.RLock()
+ v = st.config.Cache.GTS.FollowRequestIDsSweepFreq
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheGTSFollowRequestIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field
+func (st *ConfigState) SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestIDsSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsSweepFreq' field
+func CacheGTSFollowRequestIDsSweepFreqFlag() string { return "cache-gts-follow-request-ids-sweep-freq" }
+
+// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field
+func GetCacheGTSFollowRequestIDsSweepFreq() time.Duration {
+ return global.GetCacheGTSFollowRequestIDsSweepFreq()
+}
+
+// SetCacheGTSFollowRequestIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field
+func SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) {
+ global.SetCacheGTSFollowRequestIDsSweepFreq(v)
+}
+
// GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field
func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) {
st.mutex.RLock()
diff --git a/internal/db/account.go b/internal/db/account.go
index 21b8d6a1f..505ca4004 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -104,8 +104,6 @@ type Account interface {
// In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error)
- GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
-
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned.
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 2ef1618db..e57c01a82 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -694,46 +694,6 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
return a.statusesFromIDs(ctx, statusIDs)
}
-func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
- blocks := []*gtsmodel.Block{}
-
- fq := a.db.
- NewSelect().
- Model(&blocks).
- Where("? = ?", bun.Ident("block.account_id"), accountID).
- Relation("TargetAccount").
- Order("block.id DESC")
-
- if maxID != "" {
- fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
- }
-
- if sinceID != "" {
- fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
- }
-
- if limit > 0 {
- fq = fq.Limit(limit)
- }
-
- if err := fq.Scan(ctx); err != nil {
- return nil, "", "", a.db.ProcessError(err)
- }
-
- if len(blocks) == 0 {
- return nil, "", "", db.ErrNoEntries
- }
-
- accounts := []*gtsmodel.Account{}
- for _, b := range blocks {
- accounts = append(accounts, b.TargetAccount)
- }
-
- nextMaxID := blocks[len(blocks)-1].ID
- prevMinID := blocks[0].ID
- return accounts, nextMaxID, prevMinID, nil
-}
-
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early
if len(statusIDs) == 0 {
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 90bcd134d..04f22b6e9 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -126,16 +126,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
return err
}
- // Prepare SELECT accounts query.
- aq := tx.NewSelect().
- Table("accounts").
- Column("id")
-
- // Append a WHERE LIKE clause to the query
+ // Prepare a SELECT query with a WHERE LIKE
// that checks the `emoji` column for any
// text containing this specific emoji ID.
//
// (see GetStatusesUsingEmoji() for details.)
+ aq := tx.NewSelect().Table("accounts").Column("id")
aq = whereLike(aq, "emojis", id)
// Select all accounts using this emoji into accountIDss.
@@ -170,16 +166,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
}
}
- // Prepare SELECT statuses query.
- sq := tx.NewSelect().
- Table("statuses").
- Column("id")
-
- // Append a WHERE LIKE clause to the query
+ // Prepare a SELECT query with a WHERE LIKE
// that checks the `emoji` column for any
// text containing this specific emoji ID.
//
// (see GetStatusesUsingEmoji() for details.)
+ sq := tx.NewSelect().Table("statuses").Column("id")
sq = whereLike(sq, "emojis", id)
// Select all statuses using this emoji into statusIDs.
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go
index 25bb3a65d..70faf837a 100644
--- a/internal/db/bundb/list.go
+++ b/internal/db/bundb/list.go
@@ -189,11 +189,10 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
gtscontext.SetBarebones(ctx),
id,
)
- if err != nil {
- if errors.Is(err, db.ErrNoEntries) {
- // Already gone.
- return nil
- }
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // NOTE: even if db.ErrNoEntries is returned, we
+ // still run the below transaction to ensure related
+ // objects are appropriately deleted.
return err
}
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index 3b885af61..b8120b87a 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -106,8 +106,6 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
}
func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
- defer m.state.Caches.GTS.Media().Invalidate("ID", id)
-
// Load media into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -120,10 +118,8 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return err
}
- var (
- invalidateAccount bool
- invalidateStatus bool
- )
+ // On return, ensure that media with ID is invalidated.
+ defer m.state.Caches.GTS.Media().Invalidate("ID", id)
// Delete media attachment in new transaction.
err = m.db.RunInTx(ctx, func(tx bun.Tx) error {
@@ -161,9 +157,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
if _, err := set(q).Exec(ctx); err != nil {
return gtserror.Newf("error updating account: %w", err)
}
-
- // Mark as needing invalidate.
- invalidateAccount = true
}
}
@@ -178,33 +171,18 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return gtserror.Newf("error selecting status: %w", err)
}
- // Get length of attachments beforehand.
- before := len(status.AttachmentIDs)
-
- for i := 0; i < len(status.AttachmentIDs); {
- if status.AttachmentIDs[i] == id {
- // Remove this reference to deleted attachment ID.
- copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:])
- status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1]
- continue
- }
- i++
- }
-
- if before != len(status.AttachmentIDs) {
- // Note: this accounts for status not found.
+ if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse
+ len(updatedIDs) != len(status.AttachmentIDs) {
+ // Note: this handles not found.
//
// Attachments changed, update the status.
if _, err := tx.NewUpdate().
Table("statuses").
Where("? = ?", bun.Ident("id"), status.ID).
- Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs).
+ Set("? = ?", bun.Ident("attachment_ids"), updatedIDs).
Exec(ctx); err != nil {
return gtserror.Newf("error updating status: %w", err)
}
-
- // Mark as needing invalidate.
- invalidateStatus = true
}
}
@@ -219,16 +197,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return nil
})
- if invalidateAccount {
- // The account for given ID will have been updated in transaction.
- m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID)
- }
-
- if invalidateStatus {
- // The status for given ID will have been updated in transaction.
- m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID)
- }
-
return m.db.ProcessError(err)
}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index eddd73b49..e7b563f2e 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -20,11 +20,12 @@
import (
"context"
"errors"
- "fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@@ -45,7 +46,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
targetAccount,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
+ return nil, gtserror.Newf("error fetching follow: %w", err)
}
if follow != nil {
@@ -61,7 +62,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
requestingAccount,
)
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
+ return nil, gtserror.Newf("error checking followedBy: %w", err)
}
// check if requesting has follow requested target
@@ -70,19 +71,19 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
targetAccount,
)
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
+ return nil, gtserror.Newf("error checking requested: %w", err)
}
// check if the requesting account is blocking the target account
rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
+ return nil, gtserror.Newf("error checking blocking: %w", err)
}
// check if the requesting account is blocked by the target account
rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
+ return nil, gtserror.Newf("error checking blockedBy: %w", err)
}
// retrieve a note by the requesting account on the target account, if there is one
@@ -92,7 +93,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
targetAccount,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error fetching note: %w", err)
+ return nil, gtserror.Newf("error fetching note: %w", err)
}
if note != nil {
rel.Note = note.Comment
@@ -102,87 +103,186 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- var followIDs []string
- if err := newSelectFollows(r.db, accountID).
- Scan(ctx, &followIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followIDs, err := r.getAccountFollowIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- var followIDs []string
- if err := newSelectLocalFollows(r.db, accountID).
- Scan(ctx, &followIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- var followIDs []string
- if err := newSelectFollowers(r.db, accountID).
- Scan(ctx, &followIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followerIDs, err := r.getAccountFollowerIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
- return r.GetFollowsByIDs(ctx, followIDs)
+ return r.GetFollowsByIDs(ctx, followerIDs)
}
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- var followIDs []string
- if err := newSelectLocalFollowers(r.db, accountID).
- Scan(ctx, &followIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
- return r.GetFollowsByIDs(ctx, followIDs)
-}
-
-func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectFollows(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
-}
-
-func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectLocalFollows(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
-}
-
-func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectFollowers(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
-}
-
-func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
+ return r.GetFollowsByIDs(ctx, followerIDs)
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
- var followReqIDs []string
- if err := newSelectFollowRequests(r.db, accountID).
- Scan(ctx, &followReqIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
- var followReqIDs []string
- if err := newSelectFollowRequesting(r.db, accountID).
- Scan(ctx, &followReqIDs); err != nil {
- return nil, r.db.ProcessError(err)
+ followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID)
+ if err != nil {
+ return nil, err
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
+func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Pager) ([]*gtsmodel.Block, error) {
+ // Load block IDs from cache with database loader callback.
+ blockIDs, err := r.state.Caches.GTS.BlockIDs().LoadRange(accountID, func() ([]string, error) {
+ var blockIDs []string
+
+ // Block IDs not in cache, perform DB query!
+ q := newSelectBlocks(r.db, accountID)
+ if _, err := q.Exec(ctx, &blockIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return blockIDs, nil
+ }, page.PageDesc)
+ if err != nil {
+ return nil, err
+ }
+
+ // Convert these IDs to full block objects.
+ return r.GetBlocksByIDs(ctx, blockIDs)
+}
+
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
+ followIDs, err := r.getAccountFollowIDs(ctx, accountID)
+ return len(followIDs), err
+}
+
+func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
+ followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID)
+ return len(followIDs), err
+}
+
+func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
+ followerIDs, err := r.getAccountFollowerIDs(ctx, accountID)
+ return len(followerIDs), err
+}
+
+func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
+ followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID)
+ return len(followerIDs), err
+}
+
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectFollowRequests(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
+ followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID)
+ return len(followReqIDs), err
}
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
- n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx)
- return n, r.db.ProcessError(err)
+ followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID)
+ return len(followReqIDs), err
+}
+
+func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) {
+ var followIDs []string
+
+ // Follow IDs not in cache, perform DB query!
+ q := newSelectFollows(r.db, accountID)
+ if _, err := q.Exec(ctx, &followIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followIDs, nil
+ })
+}
+
+func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) {
+ var followIDs []string
+
+ // Follow IDs not in cache, perform DB query!
+ q := newSelectLocalFollows(r.db, accountID)
+ if _, err := q.Exec(ctx, &followIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followIDs, nil
+ })
+}
+
+func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) {
+ var followIDs []string
+
+ // Follow IDs not in cache, perform DB query!
+ q := newSelectFollowers(r.db, accountID)
+ if _, err := q.Exec(ctx, &followIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followIDs, nil
+ })
+}
+
+func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) {
+ var followIDs []string
+
+ // Follow IDs not in cache, perform DB query!
+ q := newSelectLocalFollowers(r.db, accountID)
+ if _, err := q.Exec(ctx, &followIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followIDs, nil
+ })
+}
+
+func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) {
+ var followReqIDs []string
+
+ // Follow request IDs not in cache, perform DB query!
+ q := newSelectFollowRequests(r.db, accountID)
+ if _, err := q.Exec(ctx, &followReqIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followReqIDs, nil
+ })
+}
+
+func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) {
+ return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) {
+ var followReqIDs []string
+
+ // Follow request IDs not in cache, perform DB query!
+ q := newSelectFollowRequesting(r.db, accountID)
+ if _, err := q.Exec(ctx, &followReqIDs); err != nil {
+ return nil, r.db.ProcessError(err)
+ }
+
+ return followReqIDs, nil
+ })
}
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
@@ -256,3 +356,12 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {
).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
+
+// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.
+func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery {
+ return db.NewSelect().
+ TableExpr("?", bun.Ident("blocks")).
+ ColumnExpr("?", bun.Ident("?")).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go
index 948e82fcb..2a042bed4 100644
--- a/internal/db/bundb/relationship_block.go
+++ b/internal/db/bundb/relationship_block.go
@@ -25,6 +25,7 @@
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@@ -97,6 +98,25 @@ func(block *gtsmodel.Block) error {
)
}
+func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) {
+ // Preallocate slice of expected length.
+ blocks := make([]*gtsmodel.Block, 0, len(ids))
+
+ for _, id := range ids {
+ // Fetch block model for this ID.
+ block, err := r.GetBlockByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting block %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ blocks = append(blocks, block)
+ }
+
+ return blocks, nil
+}
+
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
@@ -148,8 +168,6 @@ func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) er
}
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
- defer r.state.Caches.GTS.Block().Invalidate("ID", id)
-
// Load block into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -162,6 +180,9 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
return err
}
+ // Drop this now-cached block on return after delete.
+ defer r.state.Caches.GTS.Block().Invalidate("ID", id)
+
// Finally delete block from DB.
_, err = r.db.NewDelete().
Table("blocks").
@@ -171,8 +192,6 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
- defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
-
// Load block into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -185,6 +204,9 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
return err
}
+ // Drop this now-cached block on return after delete.
+ defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
+
// Finally delete block from DB.
_, err = r.db.NewDelete().
Table("blocks").
@@ -211,10 +233,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
}
defer func() {
- // Invalidate all IDs on return.
- for _, id := range blockIDs {
- r.state.Caches.GTS.Block().Invalidate("ID", id)
- }
+ // Invalidate all account's incoming / outoing blocks on return.
+ r.state.Caches.GTS.Block().Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)
}()
// Load all blocks into cache, this *really* isn't great
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
index 84501b0be..3b0597612 100644
--- a/internal/db/bundb/relationship_follow.go
+++ b/internal/db/bundb/relationship_follow.go
@@ -233,8 +233,6 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
}
func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID string, targetAccountID string) error {
- defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
-
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -251,13 +249,14 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin
return err
}
+ // Drop this now-cached follow on return after delete.
+ defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
+
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
- defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
-
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -270,13 +269,14 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error
return err
}
+ // Drop this now-cached follow on return after delete.
+ defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
+
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
- defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
-
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -289,6 +289,9 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
return err
}
+ // Drop this now-cached follow on return after delete.
+ defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
+
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
}
@@ -312,10 +315,9 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
}
defer func() {
- // Invalidate all IDs on return.
- for _, id := range followIDs {
- r.state.Caches.GTS.Follow().Invalidate("ID", id)
- }
+ // Invalidate all account's incoming / outoing follows on return.
+ r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)
}()
// Load all follows into cache, this *really* isn't great
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go
index a6e913953..dc5e760e6 100644
--- a/internal/db/bundb/relationship_follow_req.go
+++ b/internal/db/bundb/relationship_follow_req.go
@@ -208,9 +208,6 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
return nil, err
}
- // Invalidate follow request from cache lookups on return.
- defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
-
// Delete original follow request.
if _, err := r.db.
NewDelete().
@@ -243,8 +240,6 @@ func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountI
}
func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error {
- defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
-
// Load followreq into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -261,6 +256,9 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
return err
}
+ // Drop this now-cached follow request on return after delete.
+ defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
+
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
Table("follow_requests").
@@ -270,8 +268,6 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
}
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
- defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
-
// Load followreq into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -284,6 +280,9 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
return err
}
+ // Drop this now-cached follow request on return after delete.
+ defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
Table("follow_requests").
@@ -293,8 +292,6 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
}
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
- defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
-
// Load followreq into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -307,6 +304,9 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
return err
}
+ // Drop this now-cached follow request on return after delete.
+ defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
+
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
Table("follow_requests").
@@ -334,10 +334,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
}
defer func() {
- // Invalidate all IDs on return.
- for _, id := range followReqIDs {
- r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
- }
+ // Invalidate all account's incoming / outoing follow requests on return.
+ r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID)
}()
// Load all followreqs into cache, this *really* isn't
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index a019216d0..4dc7d8468 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -381,8 +381,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
- defer s.state.Caches.GTS.Status().Invalidate("ID", id)
-
// Load status into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
@@ -397,6 +395,9 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
return err
}
+ // On return ensure status invalidated from cache.
+ defer s.state.Caches.GTS.Status().Invalidate("ID", id)
+
return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this status and any emojis it uses
if _, err := tx.
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index e19aee646..6ba9fdf8c 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -21,6 +21,7 @@
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// Relationship contains functions for getting or modifying the relationship between two accounts.
@@ -166,6 +167,9 @@ type Relationship interface {
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
+ // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
+ GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Pager) ([]*gtsmodel.Block, error)
+
// GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)
diff --git a/internal/paging/paging.go b/internal/paging/paging.go
new file mode 100644
index 000000000..0323f40bc
--- /dev/null
+++ b/internal/paging/paging.go
@@ -0,0 +1,227 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package paging
+
+import "golang.org/x/exp/slices"
+
+// Pager provides a means of paging serialized IDs,
+// using the terminology of our API endpoint queries.
+type Pager struct {
+ // SinceID will limit the returned
+ // page of IDs to contain newer than
+ // since ID (excluding it). Result
+ // will be returned DESCENDING.
+ SinceID string
+
+ // MinID will limit the returned
+ // page of IDs to contain newer than
+ // min ID (excluding it). Result
+ // will be returned ASCENDING.
+ MinID string
+
+ // MaxID will limit the returned
+ // page of IDs to contain older
+ // than (excluding) this max ID.
+ MaxID string
+
+ // Limit will limit the returned
+ // page of IDs to at most 'limit'.
+ Limit int
+}
+
+// Page will page the given slice of GoToSocial IDs according
+// to the receiving Pager's SinceID, MinID, MaxID and Limits.
+// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER
+// (I.E. OLDEST ITEMS AT LOWEST INDICES, NEWER AT HIGHER).
+func (p *Pager) PageAsc(ids []string) []string {
+ if p == nil {
+ // no paging.
+ return ids
+ }
+
+ var asc bool
+
+ if p.SinceID != "" {
+ // If a sinceID is given, we
+ // page down i.e. descending.
+ asc = false
+
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.SinceID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "from here"
+ ids = ids[i+1:]
+ break
+ }
+ }
+ } else if p.MinID != "" {
+ // We only support minID if
+ // no sinceID is provided.
+ //
+ // If a minID is given, we
+ // page up, i.e. ascending.
+ asc = true
+
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.MinID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "from here"
+ ids = ids[i+1:]
+ break
+ }
+ }
+ }
+
+ if p.MaxID != "" {
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.MaxID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "up to here"
+ ids = ids[:i]
+ break
+ }
+ }
+ }
+
+ if !asc && len(ids) > 1 {
+ var (
+ // Start at front.
+ i = 0
+
+ // Start at back.
+ j = len(ids) - 1
+ )
+
+ // Clone input IDs before
+ // we perform modifications.
+ ids = slices.Clone(ids)
+
+ for i < j {
+ // Swap i,j index values in slice.
+ ids[i], ids[j] = ids[j], ids[i]
+
+ // incr + decr,
+ // looping until
+ // they meet in
+ // the middle.
+ i++
+ j--
+ }
+ }
+
+ if p.Limit > 0 && p.Limit < len(ids) {
+ // Reslice IDs to given limit.
+ ids = ids[:p.Limit]
+ }
+
+ return ids
+}
+
+// Page will page the given slice of GoToSocial IDs according
+// to the receiving Pager's SinceID, MinID, MaxID and Limits.
+// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER.
+// (I.E. NEWEST ITEMS AT LOWEST INDICES, OLDER AT HIGHER).
+func (p *Pager) PageDesc(ids []string) []string {
+ if p == nil {
+ // no paging.
+ return ids
+ }
+
+ var asc bool
+
+ if p.MaxID != "" {
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.MaxID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "from here"
+ ids = ids[i+1:]
+ break
+ }
+ }
+ }
+
+ if p.SinceID != "" {
+ // If a sinceID is given, we
+ // page down i.e. descending.
+ asc = false
+
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.SinceID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "up to here"
+ ids = ids[:i]
+ break
+ }
+ }
+ } else if p.MinID != "" {
+ // We only support minID if
+ // no sinceID is provided.
+ //
+ // If a minID is given, we
+ // page up, i.e. ascending.
+ asc = true
+
+ for i := 0; i < len(ids); i++ {
+ if ids[i] == p.MinID {
+ // Hit the boundary.
+ // Reslice to be:
+ // "up to here"
+ ids = ids[:i]
+ break
+ }
+ }
+ }
+
+ if asc && len(ids) > 1 {
+ var (
+ // Start at front.
+ i = 0
+
+ // Start at back.
+ j = len(ids) - 1
+ )
+
+ // Clone input IDs before
+ // we perform modifications.
+ ids = slices.Clone(ids)
+
+ for i < j {
+ // Swap i,j index values in slice.
+ ids[i], ids[j] = ids[j], ids[i]
+
+ // incr + decr,
+ // looping until
+ // they meet in
+ // the middle.
+ i++
+ j--
+ }
+ }
+
+ if p.Limit > 0 && p.Limit < len(ids) {
+ // Reslice IDs to given limit.
+ ids = ids[:p.Limit]
+ }
+
+ return ids
+}
diff --git a/internal/paging/paging_test.go b/internal/paging/paging_test.go
new file mode 100644
index 000000000..71c3be0c9
--- /dev/null
+++ b/internal/paging/paging_test.go
@@ -0,0 +1,171 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package paging_test
+
+import (
+ "testing"
+
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+ "golang.org/x/exp/slices"
+)
+
+type Case struct {
+ // Name is the test case name.
+ Name string
+
+ // Input contains test case input ID slice.
+ Input []string
+
+ // Expect contains expected test case output.
+ Expect []string
+
+ // Page contains the paging function to use.
+ Page func([]string) []string
+}
+
+var cases = []Case{
+ {
+ Name: "min_id and max_id set",
+ Input: []string{
+ "064Q5D7VG6TPPQ46T09MHJ96FW",
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ "064Q5D7VKG5EQ43TYP71B4K6K0",
+ },
+ Expect: []string{
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ },
+ Page: (&paging.Pager{
+ MinID: "064Q5D7VG6TPPQ46T09MHJ96FW",
+ MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0",
+ }).PageAsc,
+ },
+ {
+ Name: "min_id, max_id and limit set",
+ Input: []string{
+ "064Q5D7VG6TPPQ46T09MHJ96FW",
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ "064Q5D7VKG5EQ43TYP71B4K6K0",
+ },
+ Expect: []string{
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ },
+ Page: (&paging.Pager{
+ MinID: "064Q5D7VG6TPPQ46T09MHJ96FW",
+ MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0",
+ Limit: 5,
+ }).PageAsc,
+ },
+ {
+ Name: "min_id, max_id and too-large limit set",
+ Input: []string{
+ "064Q5D7VG6TPPQ46T09MHJ96FW",
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ "064Q5D7VKG5EQ43TYP71B4K6K0",
+ },
+ Expect: []string{
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ },
+ Page: (&paging.Pager{
+ MinID: "064Q5D7VG6TPPQ46T09MHJ96FW",
+ MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0",
+ Limit: 100,
+ }).PageAsc,
+ },
+ {
+ Name: "since_id and max_id set",
+ Input: []string{
+ "064Q5D7VG6TPPQ46T09MHJ96FW",
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ "064Q5D7VKG5EQ43TYP71B4K6K0",
+ },
+ Expect: []string{
+ "064Q5D7VK8H7WMJS399SHEPCB0",
+ "064Q5D7VJYFBYSAH86KDBKZ6AC",
+ "064Q5D7VJMWXZD3S1KT7RD51N8",
+ "064Q5D7VJADJTPA3GW8WAX10TW",
+ "064Q5D7VJ073XG9ZTWHA2KHN10",
+ "064Q5D7VHMSW9DF3GCS088VAZC",
+ "064Q5D7VH5F0JXG6W5NCQ3JCWW",
+ "064Q5D7VGPTC4NK5T070VYSSF8",
+ },
+ Page: (&paging.Pager{
+ SinceID: "064Q5D7VG6TPPQ46T09MHJ96FW",
+ MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0",
+ }).PageAsc,
+ },
+}
+
+func TestPage(t *testing.T) {
+ for _, c := range cases {
+ t.Run(c.Name, func(t *testing.T) {
+ // Page the input slice.
+ out := c.Page(c.Input)
+
+ // Check paged output is as expected.
+ if !slices.Equal(out, c.Expect) {
+ t.Errorf("\nreceived=%v\nexpect%v\n", out, c.Expect)
+ }
+ })
+ }
+}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index 2a20ec96e..a613ba485 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -20,7 +20,6 @@
import (
"context"
"errors"
- "fmt"
"net"
"time"
@@ -114,38 +113,38 @@ func (p *Processor) DeleteSelf(ctx context.Context, account *gtsmodel.Account) g
func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *gtsmodel.Account) error {
user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error getting user: %w", err)
+ return gtserror.Newf("db error getting user: %w", err)
}
tokens := []*gtsmodel.Token{}
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error getting tokens: %w", err)
+ return gtserror.Newf("db error getting tokens: %w", err)
}
for _, t := range tokens {
// Delete any OAuth clients associated with this token.
if err := p.state.DB.DeleteByID(ctx, t.ClientID, &[]*gtsmodel.Client{}); err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting client: %w", err)
+ return gtserror.Newf("db error deleting client: %w", err)
}
// Delete any OAuth applications associated with this token.
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting application: %w", err)
+ return gtserror.Newf("db error deleting application: %w", err)
}
// Delete the token itself.
if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting token: %w", err)
+ return gtserror.Newf("db error deleting token: %w", err)
}
}
columns, err := stubbifyUser(user)
if err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: error stubbifying user: %w", err)
+ return gtserror.Newf("error stubbifying user: %w", err)
}
if err := p.state.DB.UpdateUser(ctx, user, columns...); err != nil {
- return fmt.Errorf("deleteUserAndTokensForAccount: db error updating user: %w", err)
+ return gtserror.Newf("db error updating user: %w", err)
}
return nil
@@ -160,24 +159,24 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// Delete follows targeting this account.
followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return fmt.Errorf("deleteAccountFollows: db error getting follows targeting account %s: %w", account.ID, err)
+ return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err)
}
for _, follow := range followedBy {
if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
- return fmt.Errorf("deleteAccountFollows: db error unfollowing account followedBy: %w", err)
+ return gtserror.Newf("db error unfollowing account followedBy: %w", err)
}
}
// Delete follow requests targeting this account.
followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return fmt.Errorf("deleteAccountFollows: db error getting follow requests targeting account %s: %w", account.ID, err)
+ return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err)
}
for _, followRequest := range followRequestedBy {
if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
- return fmt.Errorf("deleteAccountFollows: db error unfollowing account followRequestedBy: %w", err)
+ return gtserror.Newf("db error unfollowing account followRequestedBy: %w", err)
}
}
@@ -193,14 +192,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// Delete follows originating from this account.
following, err := p.state.DB.GetAccountFollows(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return fmt.Errorf("deleteAccountFollows: db error getting follows owned by account %s: %w", account.ID, err)
+ return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err)
}
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, follow := range following {
if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
- return fmt.Errorf("deleteAccountFollows: db error unfollowing account: %w", err)
+ return gtserror.Newf("db error unfollowing account: %w", err)
}
if msg := unfollowSideEffects(ctx, account, follow); msg != nil {
// There was a side effect to process.
@@ -211,14 +210,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// Delete follow requests originating from this account.
followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return fmt.Errorf("deleteAccountFollows: db error getting follow requests owned by account %s: %w", account.ID, err)
+ return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err)
}
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, followRequest := range followRequesting {
if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
- return fmt.Errorf("deleteAccountFollows: db error unfollowingRequesting account: %w", err)
+ return gtserror.Newf("db error unfollowingRequesting account: %w", err)
}
// Dummy out a follow so our side effects func
@@ -279,7 +278,7 @@ func (p *Processor) unfollowSideEffectsFunc(deletedAccount *gtsmodel.Account) fu
func (p *Processor) deleteAccountBlocks(ctx context.Context, account *gtsmodel.Account) error {
if err := p.state.DB.DeleteAccountBlocks(ctx, account.ID); err != nil {
- return fmt.Errorf("deleteAccountBlocks: db error deleting account blocks for %s: %w", account.ID, err)
+ return gtserror.Newf("db error deleting account blocks for %s: %w", account.ID, err)
}
return nil
}
@@ -333,7 +332,7 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel
// Look for any boosts of this status in DB.
boosts, err := p.state.DB.GetStatusReblogs(ctx, status)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return fmt.Errorf("deleteAccountStatuses: error fetching status reblogs for %s: %w", status.ID, err)
+ return gtserror.Newf("error fetching status reblogs for %s: %w", status.ID, err)
}
for _, boost := range boosts {
@@ -347,7 +346,7 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel
log.WithContext(ctx).WithField("boost", boost).Warnf("no account found with id %s for boost %s", boost.AccountID, boost.ID)
continue
}
- return fmt.Errorf("deleteAccountStatuses: error fetching boosted status account for %s: %w", boost.AccountID, err)
+ return gtserror.Newf("error fetching boosted status account for %s: %w", boost.AccountID, err)
}
// Set account model
@@ -505,7 +504,7 @@ func stubbifyUser(user *gtsmodel.User) ([]string, error) {
return nil, err
}
- var never = time.Time{}
+ never := time.Time{}
user.EncryptedPassword = string(dummyPassword)
user.SignUpIP = net.IPv4zero
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
index 644f28ca9..8996dff92 100644
--- a/internal/processing/blocks.go
+++ b/internal/processing/blocks.go
@@ -19,69 +19,71 @@
import (
"context"
- "fmt"
- "net/url"
+ "errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
- accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
- if err != nil {
- if err == db.ErrNoEntries {
- // there are just no entries
- return &apimodel.BlocksResponse{
- Accounts: []*apimodel.Account{},
- }, nil
- }
- // there's an actual error
+// BlocksGet ...
+func (p *Processor) BlocksGet(
+ ctx context.Context,
+ requestingAccount *gtsmodel.Account,
+ page paging.Pager,
+) (*apimodel.PageableResponse, gtserror.WithCode) {
+ blocks, err := p.state.DB.GetAccountBlocks(ctx,
+ requestingAccount.ID,
+ &page,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
- apiAccounts := []*apimodel.Account{}
- for _, a := range accounts {
- apiAccount, err := p.tc.AccountToAPIAccountBlocked(ctx, a)
- if err != nil {
+ // Check for zero length.
+ count := len(blocks)
+ if len(blocks) == 0 {
+ return util.EmptyPageableResponse(), nil
+ }
+
+ var (
+ items = make([]interface{}, 0, count)
+
+ // Set next + prev values before API converting
+ // so the caller can still page even on error.
+ nextMaxIDValue = blocks[count-1].ID
+ prevMinIDValue = blocks[0].ID
+ )
+
+ for _, block := range blocks {
+ if block.TargetAccount == nil {
+ // All models should be populated at this point.
+ log.Warnf(ctx, "block target account was nil: %v", err)
continue
}
- apiAccounts = append(apiAccounts, apiAccount)
- }
- return p.packageBlocksResponse(apiAccounts, "/api/v1/blocks", nextMaxID, prevMinID, limit)
-}
-
-func (p *Processor) packageBlocksResponse(accounts []*apimodel.Account, path string, nextMaxID string, prevMinID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
- resp := &apimodel.BlocksResponse{
- Accounts: []*apimodel.Account{},
- }
- resp.Accounts = accounts
-
- // prepare the next and previous links
- if len(accounts) != 0 {
- protocol := config.GetProtocol()
- host := config.GetHost()
-
- nextLink := &url.URL{
- Scheme: protocol,
- Host: host,
- Path: path,
- RawQuery: fmt.Sprintf("limit=%d&max_id=%s", limit, nextMaxID),
+ // Convert target account to frontend API model.
+ account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount)
+ if err != nil {
+ log.Errorf(ctx, "error converting account to public api account: %v", err)
+ continue
}
- next := fmt.Sprintf("<%s>; rel=\"next\"", nextLink.String())
- prevLink := &url.URL{
- Scheme: protocol,
- Host: host,
- Path: path,
- RawQuery: fmt.Sprintf("limit=%d&min_id=%s", limit, prevMinID),
- }
- prev := fmt.Sprintf("<%s>; rel=\"prev\"", prevLink.String())
- resp.LinkHeader = fmt.Sprintf("%s, %s", next, prev)
+ // Append target to return items.
+ items = append(items, account)
}
- return resp, nil
+ return util.PackagePageableResponse(util.PageableResponseParams{
+ Items: items,
+ Path: "/api/v1/blocks",
+ NextMaxIDKey: "max_id",
+ PrevMinIDKey: "since_id",
+ NextMaxIDValue: nextMaxIDValue,
+ PrevMinIDValue: prevMinIDValue,
+ Limit: page.Limit,
+ })
}
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 8f4372906..b9017d0be 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -25,6 +25,9 @@ EXPECT=$(cat <<"EOF"
"account-note-ttl": 1800000000000,
"account-sweep-freq": 1000000000,
"account-ttl": 10800000000000,
+ "block-ids-max-size": 500,
+ "block-ids-sweep-freq": 60000000000,
+ "block-ids-ttl": 1800000000000,
"block-max-size": 1000,
"block-sweep-freq": 60000000000,
"block-ttl": 1800000000000,
@@ -37,7 +40,13 @@ EXPECT=$(cat <<"EOF"
"emoji-max-size": 2000,
"emoji-sweep-freq": 60000000000,
"emoji-ttl": 1800000000000,
+ "follow-ids-max-size": 500,
+ "follow-ids-sweep-freq": 60000000000,
+ "follow-ids-ttl": 1800000000000,
"follow-max-size": 2000,
+ "follow-request-ids-max-size": 500,
+ "follow-request-ids-sweep-freq": 60000000000,
+ "follow-request-ids-ttl": 1800000000000,
"follow-request-max-size": 2000,
"follow-request-sweep-freq": 60000000000,
"follow-request-ttl": 1800000000000,
diff --git a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go
index 623a19910..af108e336 100644
--- a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go
+++ b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go
@@ -479,23 +479,23 @@ func (c *Cache[K, V]) InvalidateAll(keys ...K) (ok bool) {
kvs = make([]kv[K, V], 0, len(keys))
c.locked(func() {
- for _, key := range keys {
+ for x := range keys {
var item *Entry[K, V]
// Check for item in cache
- item, ok = c.Cache.Get(key)
+ item, ok = c.Cache.Get(keys[x])
if !ok {
- return
+ continue
}
// Append this old value to slice
kvs = append(kvs, kv[K, V]{
- K: key,
+ K: keys[x],
V: item.Value,
})
// Remove from cache map
- _ = c.Cache.Delete(key)
+ _ = c.Cache.Delete(keys[x])
// Free entry
c.free(item)
@@ -553,6 +553,7 @@ func (c *Cache[K, V]) Cap() (l int) {
return
}
+// locked performs given function within mutex lock (NOTE: UNLOCK IS NOT DEFERRED).
func (c *Cache[K, V]) locked(fn func()) {
c.Lock()
fn()
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 64a310838..006cc3e5d 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -13,7 +13,7 @@ codeberg.org/gruf/go-bytesize
# codeberg.org/gruf/go-byteutil v1.1.2
## explicit; go 1.16
codeberg.org/gruf/go-byteutil
-# codeberg.org/gruf/go-cache/v3 v3.4.3
+# codeberg.org/gruf/go-cache/v3 v3.4.4
## explicit; go 1.19
codeberg.org/gruf/go-cache/v3
codeberg.org/gruf/go-cache/v3/result