From ed2477ebea4c3ceec5949821f4950db9669a4a15 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Mon, 31 Jul 2023 11:25:29 +0100 Subject: [PATCH] [performance] cache follow, follow request and block ID lists (#2027) --- go.mod | 2 +- go.sum | 4 +- internal/api/client/blocks/blocks.go | 2 + internal/api/client/blocks/blocksget.go | 42 ++-- internal/cache/cache.go | 80 +++++- internal/cache/gts.go | 150 ++++++++++-- internal/cache/slice.go | 76 ++++++ internal/cache/util.go | 23 +- internal/config/config.go | 12 + internal/config/defaults.go | 12 + internal/config/helpers.gen.go | 229 ++++++++++++++++++ internal/db/account.go | 2 - internal/db/bundb/account.go | 40 --- internal/db/bundb/emoji.go | 16 +- internal/db/bundb/list.go | 9 +- internal/db/bundb/media.go | 44 +--- internal/db/bundb/relationship.go | 223 ++++++++++++----- internal/db/bundb/relationship_block.go | 37 ++- internal/db/bundb/relationship_follow.go | 22 +- internal/db/bundb/relationship_follow_req.go | 25 +- internal/db/bundb/status.go | 5 +- internal/db/relationship.go | 4 + internal/paging/paging.go | 227 +++++++++++++++++ internal/paging/paging_test.go | 171 +++++++++++++ internal/processing/account/delete.go | 39 ++- internal/processing/blocks.go | 100 ++++---- test/envparsing.sh | 9 + .../codeberg.org/gruf/go-cache/v3/ttl/ttl.go | 11 +- vendor/modules.txt | 2 +- 29 files changed, 1283 insertions(+), 335 deletions(-) create mode 100644 internal/cache/slice.go create mode 100644 internal/paging/paging.go create mode 100644 internal/paging/paging_test.go diff --git a/go.mod b/go.mod index 1a4d14b97..98abc64ee 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( codeberg.org/gruf/go-bytesize v1.0.2 codeberg.org/gruf/go-byteutil v1.1.2 - codeberg.org/gruf/go-cache/v3 v3.4.3 + codeberg.org/gruf/go-cache/v3 v3.4.4 codeberg.org/gruf/go-debug v1.3.0 codeberg.org/gruf/go-errors/v2 v2.2.0 codeberg.org/gruf/go-fastcopy v1.1.2 diff --git a/go.sum b/go.sum index e700364a5..19964f9f1 100644 --- a/go.sum +++ b/go.sum @@ -48,8 +48,8 @@ codeberg.org/gruf/go-bytesize v1.0.2/go.mod h1:n/GU8HzL9f3UNp/mUKyr1qVmTlj7+xacp codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= codeberg.org/gruf/go-byteutil v1.1.2 h1:TQLZtTxTNca9xEfDIndmo7nBYxeS94nrv/9DS3Nk5Tw= codeberg.org/gruf/go-byteutil v1.1.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= -codeberg.org/gruf/go-cache/v3 v3.4.3 h1:GTNq01M17jUJ3B3ehrVTbElpvCqOKgz1x+VB9GEIxXA= -codeberg.org/gruf/go-cache/v3 v3.4.3/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc= +codeberg.org/gruf/go-cache/v3 v3.4.4 h1:V0A3EzjhzhULOydD16pwa2DRDwF67OuuP4ORnm//7p8= +codeberg.org/gruf/go-cache/v3 v3.4.4/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc= codeberg.org/gruf/go-debug v1.3.0 h1:PIRxQiWUFKtGOGZFdZ3Y0pqyfI0Xr87j224IYe2snZs= codeberg.org/gruf/go-debug v1.3.0/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg= codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4= diff --git a/internal/api/client/blocks/blocks.go b/internal/api/client/blocks/blocks.go index bff9a068e..0eeee2bf1 100644 --- a/internal/api/client/blocks/blocks.go +++ b/internal/api/client/blocks/blocks.go @@ -30,8 +30,10 @@ // MaxIDKey is the url query for setting a max ID to return MaxIDKey = "max_id" + // SinceIDKey is the url query for returning results newer than the given ID SinceIDKey = "since_id" + // LimitKey is for specifying maximum number of results to return. LimitKey = "limit" ) diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go index 7aec8b334..505c33db8 100644 --- a/internal/api/client/blocks/blocksget.go +++ b/internal/api/client/blocks/blocksget.go @@ -18,14 +18,13 @@ package blocks import ( - "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // BlocksGETHandler swagger:operation GET /api/v1/blocks blocksGet @@ -104,31 +103,21 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { return } - maxID := "" - maxIDString := c.Query(MaxIDKey) - if maxIDString != "" { - maxID = maxIDString + limit, errWithCode := apiutil.ParseLimit(c.Query(LimitKey), 20, 100, 2) + if err != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - sinceID := "" - sinceIDString := c.Query(SinceIDKey) - if sinceIDString != "" { - sinceID = sinceIDString - } - - limit := 20 - limitString := c.Query(LimitKey) - if limitString != "" { - i, err := strconv.ParseInt(limitString, 10, 32) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", LimitKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - limit = int(i) - } - - resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit) + resp, errWithCode := m.processor.BlocksGet( + c.Request.Context(), + authed.Account, + paging.Pager{ + SinceID: c.Query(SinceIDKey), + MaxID: c.Query(MaxIDKey), + Limit: limit, + }, + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return @@ -137,5 +126,6 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { if resp.LinkHeader != "" { c.Header("Link", resp.LinkHeader) } - c.JSON(http.StatusOK, resp.Accounts) + + c.JSON(http.StatusOK, resp.Items) } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 63564935e..e97dce6f9 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -80,6 +80,27 @@ func (c *Caches) setuphooks() { // Invalidate account ID cached visibility. c.Visibility.Invalidate("ItemID", account.ID) c.Visibility.Invalidate("RequesterID", account.ID) + + // Invalidate this account's + // following / follower lists. + // (see FollowIDs() comment for details). + c.GTS.FollowIDs().InvalidateAll( + ">"+account.ID, + "l>"+account.ID, + "<"+account.ID, + "l<"+account.ID, + ) + + // Invalidate this account's + // follow requesting / request lists. + // (see FollowRequestIDs() comment for details). + c.GTS.FollowRequestIDs().InvalidateAll( + ">"+account.ID, + "<"+account.ID, + ) + + // Invalidate this account's block lists. + c.GTS.BlockIDs().Invalidate(account.ID) }) c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) { @@ -90,6 +111,9 @@ func (c *Caches) setuphooks() { // Invalidate block target account ID cached visibility. c.Visibility.Invalidate("ItemID", block.TargetAccountID) c.Visibility.Invalidate("RequesterID", block.TargetAccountID) + + // Invalidate source account's block lists. + c.GTS.BlockIDs().Invalidate(block.AccountID) }) c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) { @@ -98,6 +122,9 @@ func (c *Caches) setuphooks() { }) c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) { + // Invalidate follow request with this same ID. + c.GTS.FollowRequest().Invalidate("ID", follow.ID) + // Invalidate any related list entries. c.GTS.ListEntry().Invalidate("FollowID", follow.ID) @@ -108,19 +135,35 @@ func (c *Caches) setuphooks() { // Invalidate follow target account ID cached visibility. c.Visibility.Invalidate("ItemID", follow.TargetAccountID) c.Visibility.Invalidate("RequesterID", follow.TargetAccountID) + + // Invalidate source account's following + // lists, and destination's follwer lists. + // (see FollowIDs() comment for details). + c.GTS.FollowIDs().InvalidateAll( + ">"+follow.AccountID, + "l>"+follow.AccountID, + "<"+follow.AccountID, + "l<"+follow.AccountID, + "<"+follow.TargetAccountID, + "l<"+follow.TargetAccountID, + ">"+follow.TargetAccountID, + "l>"+follow.TargetAccountID, + ) }) c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) { - // Invalidate follow request origin account ID cached visibility. - c.Visibility.Invalidate("ItemID", followReq.AccountID) - c.Visibility.Invalidate("RequesterID", followReq.AccountID) - - // Invalidate follow request target account ID cached visibility. - c.Visibility.Invalidate("ItemID", followReq.TargetAccountID) - c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID) - - // Invalidate any cached follow with same ID. + // Invalidate follow with this same ID. c.GTS.Follow().Invalidate("ID", followReq.ID) + + // Invalidate source account's followreq + // lists, and destinations follow req lists. + // (see FollowRequestIDs() comment for details). + c.GTS.FollowRequestIDs().InvalidateAll( + ">"+followReq.AccountID, + "<"+followReq.AccountID, + ">"+followReq.TargetAccountID, + "<"+followReq.TargetAccountID, + ) }) c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) { @@ -128,12 +171,29 @@ func (c *Caches) setuphooks() { c.GTS.ListEntry().Invalidate("ListID", list.ID) }) + c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) { + if *media.Avatar || *media.Header { + // Invalidate cache of attaching account. + c.GTS.Account().Invalidate("ID", media.AccountID) + } + + if media.StatusID != "" { + // Invalidate cache of attaching status. + c.GTS.Status().Invalidate("ID", media.StatusID) + } + }) + c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) { // Invalidate status ID cached visibility. c.Visibility.Invalidate("ItemID", status.ID) for _, id := range status.AttachmentIDs { - // Invalidate cache for attached media IDs, + // Invalidate each media by the IDs we're aware of. + // This must be done as the status table is aware of + // the media IDs in use before the media table is + // aware of the status ID they are linked to. + // + // c.GTS.Media().Invalidate("StatusID") will not work. c.GTS.Media().Invalidate("ID", id) } }) diff --git a/internal/cache/gts.go b/internal/cache/gts.go index dd43154ef..fefd02fff 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -26,29 +26,31 @@ ) type GTSCaches struct { - account *result.Cache[*gtsmodel.Account] - accountNote *result.Cache[*gtsmodel.AccountNote] - block *result.Cache[*gtsmodel.Block] - // TODO: maybe should be moved out of here since it's - // not actually doing anything with gtsmodel.DomainBlock. - domainBlock *domain.BlockCache - emoji *result.Cache[*gtsmodel.Emoji] - emojiCategory *result.Cache[*gtsmodel.EmojiCategory] - follow *result.Cache[*gtsmodel.Follow] - followRequest *result.Cache[*gtsmodel.FollowRequest] - instance *result.Cache[*gtsmodel.Instance] - list *result.Cache[*gtsmodel.List] - listEntry *result.Cache[*gtsmodel.ListEntry] - marker *result.Cache[*gtsmodel.Marker] - media *result.Cache[*gtsmodel.MediaAttachment] - mention *result.Cache[*gtsmodel.Mention] - notification *result.Cache[*gtsmodel.Notification] - report *result.Cache[*gtsmodel.Report] - status *result.Cache[*gtsmodel.Status] - statusFave *result.Cache[*gtsmodel.StatusFave] - tombstone *result.Cache[*gtsmodel.Tombstone] - user *result.Cache[*gtsmodel.User] - // TODO: move out of GTS caches since not using database models. + account *result.Cache[*gtsmodel.Account] + accountNote *result.Cache[*gtsmodel.AccountNote] + block *result.Cache[*gtsmodel.Block] + blockIDs *SliceCache[string] + domainBlock *domain.BlockCache + emoji *result.Cache[*gtsmodel.Emoji] + emojiCategory *result.Cache[*gtsmodel.EmojiCategory] + follow *result.Cache[*gtsmodel.Follow] + followIDs *SliceCache[string] + followRequest *result.Cache[*gtsmodel.FollowRequest] + followRequestIDs *SliceCache[string] + instance *result.Cache[*gtsmodel.Instance] + list *result.Cache[*gtsmodel.List] + listEntry *result.Cache[*gtsmodel.ListEntry] + marker *result.Cache[*gtsmodel.Marker] + media *result.Cache[*gtsmodel.MediaAttachment] + mention *result.Cache[*gtsmodel.Mention] + notification *result.Cache[*gtsmodel.Notification] + report *result.Cache[*gtsmodel.Report] + status *result.Cache[*gtsmodel.Status] + statusFave *result.Cache[*gtsmodel.StatusFave] + tombstone *result.Cache[*gtsmodel.Tombstone] + user *result.Cache[*gtsmodel.User] + + // TODO: move out of GTS caches since unrelated to DB. webfinger *ttl.Cache[string, string] } @@ -58,11 +60,14 @@ func (c *GTSCaches) Init() { c.initAccount() c.initAccountNote() c.initBlock() + c.initBlockIDs() c.initDomainBlock() c.initEmoji() c.initEmojiCategory() c.initFollow() + c.initFollowIDs() c.initFollowRequest() + c.initFollowRequestIDs() c.initInstance() c.initList() c.initListEntry() @@ -83,10 +88,28 @@ func (c *GTSCaches) Start() { tryStart(c.account, config.GetCacheGTSAccountSweepFreq()) tryStart(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq()) tryStart(c.block, config.GetCacheGTSBlockSweepFreq()) + tryUntil("starting block IDs cache", 5, func() bool { + if sweep := config.GetCacheGTSBlockIDsSweepFreq(); sweep > 0 { + return c.blockIDs.Start(sweep) + } + return true + }) tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq()) tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) + tryUntil("starting follow IDs cache", 5, func() bool { + if sweep := config.GetCacheGTSFollowIDsSweepFreq(); sweep > 0 { + return c.followIDs.Start(sweep) + } + return true + }) tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryUntil("starting follow request IDs cache", 5, func() bool { + if sweep := config.GetCacheGTSFollowRequestIDsSweepFreq(); sweep > 0 { + return c.followRequestIDs.Start(sweep) + } + return true + }) tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq()) tryStart(c.list, config.GetCacheGTSListSweepFreq()) tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) @@ -112,10 +135,28 @@ func (c *GTSCaches) Stop() { tryStop(c.account, config.GetCacheGTSAccountSweepFreq()) tryStop(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq()) tryStop(c.block, config.GetCacheGTSBlockSweepFreq()) + tryUntil("stopping block IDs cache", 5, func() bool { + if config.GetCacheGTSBlockIDsSweepFreq() > 0 { + return c.blockIDs.Stop() + } + return true + }) tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq()) tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) + tryUntil("stopping follow IDs cache", 5, func() bool { + if config.GetCacheGTSFollowIDsSweepFreq() > 0 { + return c.followIDs.Stop() + } + return true + }) tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryUntil("stopping follow request IDs cache", 5, func() bool { + if config.GetCacheGTSFollowRequestIDsSweepFreq() > 0 { + return c.followRequestIDs.Stop() + } + return true + }) tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq()) tryStop(c.list, config.GetCacheGTSListSweepFreq()) tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) @@ -128,7 +169,12 @@ func (c *GTSCaches) Stop() { tryStop(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq()) tryStop(c.tombstone, config.GetCacheGTSTombstoneSweepFreq()) tryStop(c.user, config.GetCacheGTSUserSweepFreq()) - tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop) + tryUntil("stopping *gtsmodel.Webfinger cache", 5, func() bool { + if config.GetCacheGTSWebfingerSweepFreq() > 0 { + return c.webfinger.Stop() + } + return true + }) } // Account provides access to the gtsmodel Account database cache. @@ -146,6 +192,11 @@ func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] { return c.block } +// FollowIDs provides access to the block IDs database cache. +func (c *GTSCaches) BlockIDs() *SliceCache[string] { + return c.blockIDs +} + // DomainBlock provides access to the domain block database cache. func (c *GTSCaches) DomainBlock() *domain.BlockCache { return c.domainBlock @@ -166,11 +217,29 @@ func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] { return c.follow } +// FollowIDs provides access to the follower / following IDs database cache. +// THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: +// - '>' for following IDs +// - 'l>' for local following IDs +// - '<' for follower IDs +// - 'l<' for local follower IDs +func (c *GTSCaches) FollowIDs() *SliceCache[string] { + return c.followIDs +} + // FollowRequest provides access to the gtsmodel FollowRequest database cache. func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] { return c.followRequest } +// FollowRequestIDs provides access to the follow requester / requesting IDs database +// cache. THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: +// - '>' for following IDs +// - '<' for follower IDs +func (c *GTSCaches) FollowRequestIDs() *SliceCache[string] { + return c.followRequestIDs +} + // Instance provides access to the gtsmodel Instance database cache. func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] { return c.instance @@ -274,6 +343,8 @@ func (c *GTSCaches) initBlock() { {Name: "ID"}, {Name: "URI"}, {Name: "AccountID.TargetAccountID"}, + {Name: "AccountID", Multi: true}, + {Name: "TargetAccountID", Multi: true}, }, func(b1 *gtsmodel.Block) *gtsmodel.Block { b2 := new(gtsmodel.Block) *b2 = *b1 @@ -283,6 +354,14 @@ func (c *GTSCaches) initBlock() { c.block.IgnoreErrors(ignoreErrors) } +func (c *GTSCaches) initBlockIDs() { + c.blockIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( + 0, + config.GetCacheGTSBlockIDsMaxSize(), + config.GetCacheGTSBlockIDsTTL(), + )} +} + func (c *GTSCaches) initDomainBlock() { c.domainBlock = new(domain.BlockCache) } @@ -321,6 +400,8 @@ func (c *GTSCaches) initFollow() { {Name: "ID"}, {Name: "URI"}, {Name: "AccountID.TargetAccountID"}, + {Name: "AccountID", Multi: true}, + {Name: "TargetAccountID", Multi: true}, }, func(f1 *gtsmodel.Follow) *gtsmodel.Follow { f2 := new(gtsmodel.Follow) *f2 = *f1 @@ -329,11 +410,21 @@ func (c *GTSCaches) initFollow() { c.follow.SetTTL(config.GetCacheGTSFollowTTL(), true) } +func (c *GTSCaches) initFollowIDs() { + c.followIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( + 0, + config.GetCacheGTSFollowIDsMaxSize(), + config.GetCacheGTSFollowIDsTTL(), + )} +} + func (c *GTSCaches) initFollowRequest() { c.followRequest = result.New([]result.Lookup{ {Name: "ID"}, {Name: "URI"}, {Name: "AccountID.TargetAccountID"}, + {Name: "AccountID", Multi: true}, + {Name: "TargetAccountID", Multi: true}, }, func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest { f2 := new(gtsmodel.FollowRequest) *f2 = *f1 @@ -342,6 +433,14 @@ func (c *GTSCaches) initFollowRequest() { c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true) } +func (c *GTSCaches) initFollowRequestIDs() { + c.followRequestIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( + 0, + config.GetCacheGTSFollowRequestIDsMaxSize(), + config.GetCacheGTSFollowRequestIDsTTL(), + )} +} + func (c *GTSCaches) initInstance() { c.instance = result.New([]result.Lookup{ {Name: "ID"}, @@ -502,5 +601,6 @@ func (c *GTSCaches) initWebfinger() { c.webfinger = ttl.New[string, string]( 0, config.GetCacheGTSWebfingerMaxSize(), - config.GetCacheGTSWebfingerTTL()) + config.GetCacheGTSWebfingerTTL(), + ) } diff --git a/internal/cache/slice.go b/internal/cache/slice.go new file mode 100644 index 000000000..194f20d4b --- /dev/null +++ b/internal/cache/slice.go @@ -0,0 +1,76 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "codeberg.org/gruf/go-cache/v3/ttl" + "golang.org/x/exp/slices" +) + +// SliceCache wraps a ttl.Cache to provide simple loader-callback +// functions for fetching + caching slices of objects (e.g. IDs). +type SliceCache[T any] struct { + *ttl.Cache[string, []T] +} + +// Load will attempt to load an existing slice from the cache for the given key, else calling the provided load function and caching the result. +func (c *SliceCache[T]) Load(key string, load func() ([]T, error)) ([]T, error) { + // Look for follow IDs list in cache under this key. + data, ok := c.Get(key) + + if !ok { + var err error + + // Not cached, load! + data, err = load() + if err != nil { + return nil, err + } + + // Store the data. + c.Set(key, data) + } + + // Return data clone for safety. + return slices.Clone(data), nil +} + +// LoadRange is functionally the same as .Load(), but will pass the result through provided reslice function before returning a cloned result. +func (c *SliceCache[T]) LoadRange(key string, load func() ([]T, error), reslice func([]T) []T) ([]T, error) { + // Look for follow IDs list in cache under this key. + data, ok := c.Get(key) + + if !ok { + var err error + + // Not cached, load! + data, err = load() + if err != nil { + return nil, err + } + + // Store the data. + c.Set(key, data) + } + + // Reslice to range. + slice := reslice(data) + + // Return range clone for safety. + return slices.Clone(slice), nil +} diff --git a/internal/cache/util.go b/internal/cache/util.go index a0adfd366..f2357c904 100644 --- a/internal/cache/util.go +++ b/internal/cache/util.go @@ -18,28 +18,33 @@ package cache import ( - "context" + "database/sql" "errors" "fmt" "time" "codeberg.org/gruf/go-cache/v3/result" errorsv2 "codeberg.org/gruf/go-errors/v2" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/log" ) -// SentinelError is returned to indicate a non-permanent error return, -// i.e. a situation in which we do not want a cache a negative result. +// SentinelError is an error that can be returned and checked against to indicate a non-permanent +// error return from a cache loader callback, e.g. a temporary situation that will soon be fixed. var SentinelError = errors.New("BUG: error should not be returned") //nolint:revive -// ignoreErrors is an error ignoring function capable of being passed to -// caches, which specifically catches and ignores our sentinel error type. +// ignoreErrors is an error matching function used to signal which errors +// the result caches should NOT hold onto. these amount to anything non-permanent. func ignoreErrors(err error) bool { - return errorsv2.Comparable( + return !errorsv2.Comparable( err, - SentinelError, - context.DeadlineExceeded, - context.Canceled, + + // the only cacheable errs, + // i.e anything permanent + // (until invalidation). + db.ErrNoEntries, + db.ErrAlreadyExists, + sql.ErrNoRows, ) } diff --git a/internal/config/config.go b/internal/config/config.go index bd9fc468c..99b07358e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -194,6 +194,10 @@ type GTSCacheConfiguration struct { BlockTTL time.Duration `name:"block-ttl"` BlockSweepFreq time.Duration `name:"block-sweep-freq"` + BlockIDsMaxSize int `name:"block-ids-max-size"` + BlockIDsTTL time.Duration `name:"block-ids-ttl"` + BlockIDsSweepFreq time.Duration `name:"block-ids-sweep-freq"` + DomainBlockMaxSize int `name:"domain-block-max-size"` DomainBlockTTL time.Duration `name:"domain-block-ttl"` DomainBlockSweepFreq time.Duration `name:"domain-block-sweep-freq"` @@ -210,10 +214,18 @@ type GTSCacheConfiguration struct { FollowTTL time.Duration `name:"follow-ttl"` FollowSweepFreq time.Duration `name:"follow-sweep-freq"` + FollowIDsMaxSize int `name:"follow-ids-max-size"` + FollowIDsTTL time.Duration `name:"follow-ids-ttl"` + FollowIDsSweepFreq time.Duration `name:"follow-ids-sweep-freq"` + FollowRequestMaxSize int `name:"follow-request-max-size"` FollowRequestTTL time.Duration `name:"follow-request-ttl"` FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"` + FollowRequestIDsMaxSize int `name:"follow-request-ids-max-size"` + FollowRequestIDsTTL time.Duration `name:"follow-request-ids-ttl"` + FollowRequestIDsSweepFreq time.Duration `name:"follow-request-ids-sweep-freq"` + InstanceMaxSize int `name:"instance-max-size"` InstanceTTL time.Duration `name:"instance-ttl"` InstanceSweepFreq time.Duration `name:"instance-sweep-freq"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index ee20fb6a7..cb37838c1 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -139,6 +139,10 @@ BlockTTL: time.Minute * 30, BlockSweepFreq: time.Minute, + BlockIDsMaxSize: 500, + BlockIDsTTL: time.Minute * 30, + BlockIDsSweepFreq: time.Minute, + DomainBlockMaxSize: 2000, DomainBlockTTL: time.Hour * 24, DomainBlockSweepFreq: time.Minute, @@ -155,10 +159,18 @@ FollowTTL: time.Minute * 30, FollowSweepFreq: time.Minute, + FollowIDsMaxSize: 500, + FollowIDsTTL: time.Minute * 30, + FollowIDsSweepFreq: time.Minute, + FollowRequestMaxSize: 2000, FollowRequestTTL: time.Minute * 30, FollowRequestSweepFreq: time.Minute, + FollowRequestIDsMaxSize: 500, + FollowRequestIDsTTL: time.Minute * 30, + FollowRequestIDsSweepFreq: time.Minute, + InstanceMaxSize: 2000, InstanceTTL: time.Minute * 30, InstanceSweepFreq: time.Minute, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 5eed1b468..1bf8ec2bc 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2624,6 +2624,81 @@ func GetCacheGTSBlockSweepFreq() time.Duration { return global.GetCacheGTSBlockS // SetCacheGTSBlockSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockSweepFreq' field func SetCacheGTSBlockSweepFreq(v time.Duration) { global.SetCacheGTSBlockSweepFreq(v) } +// GetCacheGTSBlockIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field +func (st *ConfigState) GetCacheGTSBlockIDsMaxSize() (v int) { + st.mutex.RLock() + v = st.config.Cache.GTS.BlockIDsMaxSize + st.mutex.RUnlock() + return +} + +// SetCacheGTSBlockIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field +func (st *ConfigState) SetCacheGTSBlockIDsMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.BlockIDsMaxSize = v + st.reloadToViper() +} + +// CacheGTSBlockIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.BlockIDsMaxSize' field +func CacheGTSBlockIDsMaxSizeFlag() string { return "cache-gts-block-ids-max-size" } + +// GetCacheGTSBlockIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field +func GetCacheGTSBlockIDsMaxSize() int { return global.GetCacheGTSBlockIDsMaxSize() } + +// SetCacheGTSBlockIDsMaxSize safely sets the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field +func SetCacheGTSBlockIDsMaxSize(v int) { global.SetCacheGTSBlockIDsMaxSize(v) } + +// GetCacheGTSBlockIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field +func (st *ConfigState) GetCacheGTSBlockIDsTTL() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.BlockIDsTTL + st.mutex.RUnlock() + return +} + +// SetCacheGTSBlockIDsTTL safely sets the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field +func (st *ConfigState) SetCacheGTSBlockIDsTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.BlockIDsTTL = v + st.reloadToViper() +} + +// CacheGTSBlockIDsTTLFlag returns the flag name for the 'Cache.GTS.BlockIDsTTL' field +func CacheGTSBlockIDsTTLFlag() string { return "cache-gts-block-ids-ttl" } + +// GetCacheGTSBlockIDsTTL safely fetches the value for global configuration 'Cache.GTS.BlockIDsTTL' field +func GetCacheGTSBlockIDsTTL() time.Duration { return global.GetCacheGTSBlockIDsTTL() } + +// SetCacheGTSBlockIDsTTL safely sets the value for global configuration 'Cache.GTS.BlockIDsTTL' field +func SetCacheGTSBlockIDsTTL(v time.Duration) { global.SetCacheGTSBlockIDsTTL(v) } + +// GetCacheGTSBlockIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSBlockIDsSweepFreq() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.BlockIDsSweepFreq + st.mutex.RUnlock() + return +} + +// SetCacheGTSBlockIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSBlockIDsSweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.BlockIDsSweepFreq = v + st.reloadToViper() +} + +// CacheGTSBlockIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.BlockIDsSweepFreq' field +func CacheGTSBlockIDsSweepFreqFlag() string { return "cache-gts-block-ids-sweep-freq" } + +// GetCacheGTSBlockIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field +func GetCacheGTSBlockIDsSweepFreq() time.Duration { return global.GetCacheGTSBlockIDsSweepFreq() } + +// SetCacheGTSBlockIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field +func SetCacheGTSBlockIDsSweepFreq(v time.Duration) { global.SetCacheGTSBlockIDsSweepFreq(v) } + // GetCacheGTSDomainBlockMaxSize safely fetches the Configuration value for state's 'Cache.GTS.DomainBlockMaxSize' field func (st *ConfigState) GetCacheGTSDomainBlockMaxSize() (v int) { st.mutex.RLock() @@ -2926,6 +3001,81 @@ func GetCacheGTSFollowSweepFreq() time.Duration { return global.GetCacheGTSFollo // SetCacheGTSFollowSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowSweepFreq' field func SetCacheGTSFollowSweepFreq(v time.Duration) { global.SetCacheGTSFollowSweepFreq(v) } +// GetCacheGTSFollowIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field +func (st *ConfigState) GetCacheGTSFollowIDsMaxSize() (v int) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowIDsMaxSize + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field +func (st *ConfigState) SetCacheGTSFollowIDsMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowIDsMaxSize = v + st.reloadToViper() +} + +// CacheGTSFollowIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowIDsMaxSize' field +func CacheGTSFollowIDsMaxSizeFlag() string { return "cache-gts-follow-ids-max-size" } + +// GetCacheGTSFollowIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field +func GetCacheGTSFollowIDsMaxSize() int { return global.GetCacheGTSFollowIDsMaxSize() } + +// SetCacheGTSFollowIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field +func SetCacheGTSFollowIDsMaxSize(v int) { global.SetCacheGTSFollowIDsMaxSize(v) } + +// GetCacheGTSFollowIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field +func (st *ConfigState) GetCacheGTSFollowIDsTTL() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowIDsTTL + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field +func (st *ConfigState) SetCacheGTSFollowIDsTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowIDsTTL = v + st.reloadToViper() +} + +// CacheGTSFollowIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowIDsTTL' field +func CacheGTSFollowIDsTTLFlag() string { return "cache-gts-follow-ids-ttl" } + +// GetCacheGTSFollowIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowIDsTTL' field +func GetCacheGTSFollowIDsTTL() time.Duration { return global.GetCacheGTSFollowIDsTTL() } + +// SetCacheGTSFollowIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowIDsTTL' field +func SetCacheGTSFollowIDsTTL(v time.Duration) { global.SetCacheGTSFollowIDsTTL(v) } + +// GetCacheGTSFollowIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSFollowIDsSweepFreq() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowIDsSweepFreq + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSFollowIDsSweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowIDsSweepFreq = v + st.reloadToViper() +} + +// CacheGTSFollowIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowIDsSweepFreq' field +func CacheGTSFollowIDsSweepFreqFlag() string { return "cache-gts-follow-ids-sweep-freq" } + +// GetCacheGTSFollowIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field +func GetCacheGTSFollowIDsSweepFreq() time.Duration { return global.GetCacheGTSFollowIDsSweepFreq() } + +// SetCacheGTSFollowIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field +func SetCacheGTSFollowIDsSweepFreq(v time.Duration) { global.SetCacheGTSFollowIDsSweepFreq(v) } + // GetCacheGTSFollowRequestMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field func (st *ConfigState) GetCacheGTSFollowRequestMaxSize() (v int) { st.mutex.RLock() @@ -3003,6 +3153,85 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration { // SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) } +// GetCacheGTSFollowRequestIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsMaxSize() (v int) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowRequestIDsMaxSize + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowRequestIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowRequestIDsMaxSize = v + st.reloadToViper() +} + +// CacheGTSFollowRequestIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsMaxSize' field +func CacheGTSFollowRequestIDsMaxSizeFlag() string { return "cache-gts-follow-request-ids-max-size" } + +// GetCacheGTSFollowRequestIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field +func GetCacheGTSFollowRequestIDsMaxSize() int { return global.GetCacheGTSFollowRequestIDsMaxSize() } + +// SetCacheGTSFollowRequestIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field +func SetCacheGTSFollowRequestIDsMaxSize(v int) { global.SetCacheGTSFollowRequestIDsMaxSize(v) } + +// GetCacheGTSFollowRequestIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsTTL() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowRequestIDsTTL + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowRequestIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowRequestIDsTTL = v + st.reloadToViper() +} + +// CacheGTSFollowRequestIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsTTL' field +func CacheGTSFollowRequestIDsTTLFlag() string { return "cache-gts-follow-request-ids-ttl" } + +// GetCacheGTSFollowRequestIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field +func GetCacheGTSFollowRequestIDsTTL() time.Duration { return global.GetCacheGTSFollowRequestIDsTTL() } + +// SetCacheGTSFollowRequestIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field +func SetCacheGTSFollowRequestIDsTTL(v time.Duration) { global.SetCacheGTSFollowRequestIDsTTL(v) } + +// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsSweepFreq() (v time.Duration) { + st.mutex.RLock() + v = st.config.Cache.GTS.FollowRequestIDsSweepFreq + st.mutex.RUnlock() + return +} + +// SetCacheGTSFollowRequestIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.FollowRequestIDsSweepFreq = v + st.reloadToViper() +} + +// CacheGTSFollowRequestIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsSweepFreq' field +func CacheGTSFollowRequestIDsSweepFreqFlag() string { return "cache-gts-follow-request-ids-sweep-freq" } + +// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field +func GetCacheGTSFollowRequestIDsSweepFreq() time.Duration { + return global.GetCacheGTSFollowRequestIDsSweepFreq() +} + +// SetCacheGTSFollowRequestIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field +func SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) { + global.SetCacheGTSFollowRequestIDsSweepFreq(v) +} + // GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) { st.mutex.RLock() diff --git a/internal/db/account.go b/internal/db/account.go index 21b8d6a1f..505ca4004 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -104,8 +104,6 @@ type Account interface { // In the case of no statuses, this function will return db.ErrNoEntries. GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) - GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) - // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2ef1618db..e57c01a82 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -694,46 +694,6 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { - blocks := []*gtsmodel.Block{} - - fq := a.db. - NewSelect(). - Model(&blocks). - Where("? = ?", bun.Ident("block.account_id"), accountID). - Relation("TargetAccount"). - Order("block.id DESC") - - if maxID != "" { - fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) - } - - if sinceID != "" { - fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) - } - - if limit > 0 { - fq = fq.Limit(limit) - } - - if err := fq.Scan(ctx); err != nil { - return nil, "", "", a.db.ProcessError(err) - } - - if len(blocks) == 0 { - return nil, "", "", db.ErrNoEntries - } - - accounts := []*gtsmodel.Account{} - for _, b := range blocks { - accounts = append(accounts, b.TargetAccount) - } - - nextMaxID := blocks[len(blocks)-1].ID - prevMinID := blocks[0].ID - return accounts, nextMaxID, prevMinID, nil -} - func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) { // Catch case of no statuses early if len(statusIDs) == 0 { diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 90bcd134d..04f22b6e9 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -126,16 +126,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { return err } - // Prepare SELECT accounts query. - aq := tx.NewSelect(). - Table("accounts"). - Column("id") - - // Append a WHERE LIKE clause to the query + // Prepare a SELECT query with a WHERE LIKE // that checks the `emoji` column for any // text containing this specific emoji ID. // // (see GetStatusesUsingEmoji() for details.) + aq := tx.NewSelect().Table("accounts").Column("id") aq = whereLike(aq, "emojis", id) // Select all accounts using this emoji into accountIDss. @@ -170,16 +166,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { } } - // Prepare SELECT statuses query. - sq := tx.NewSelect(). - Table("statuses"). - Column("id") - - // Append a WHERE LIKE clause to the query + // Prepare a SELECT query with a WHERE LIKE // that checks the `emoji` column for any // text containing this specific emoji ID. // // (see GetStatusesUsingEmoji() for details.) + sq := tx.NewSelect().Table("statuses").Column("id") sq = whereLike(sq, "emojis", id) // Select all statuses using this emoji into statusIDs. diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 25bb3a65d..70faf837a 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -189,11 +189,10 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { gtscontext.SetBarebones(ctx), id, ) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // Already gone. - return nil - } + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // NOTE: even if db.ErrNoEntries is returned, we + // still run the below transaction to ensure related + // objects are appropriately deleted. return err } diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 3b885af61..b8120b87a 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -106,8 +106,6 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt } func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { - defer m.state.Caches.GTS.Media().Invalidate("ID", id) - // Load media into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -120,10 +118,8 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return err } - var ( - invalidateAccount bool - invalidateStatus bool - ) + // On return, ensure that media with ID is invalidated. + defer m.state.Caches.GTS.Media().Invalidate("ID", id) // Delete media attachment in new transaction. err = m.db.RunInTx(ctx, func(tx bun.Tx) error { @@ -161,9 +157,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { if _, err := set(q).Exec(ctx); err != nil { return gtserror.Newf("error updating account: %w", err) } - - // Mark as needing invalidate. - invalidateAccount = true } } @@ -178,33 +171,18 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return gtserror.Newf("error selecting status: %w", err) } - // Get length of attachments beforehand. - before := len(status.AttachmentIDs) - - for i := 0; i < len(status.AttachmentIDs); { - if status.AttachmentIDs[i] == id { - // Remove this reference to deleted attachment ID. - copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:]) - status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1] - continue - } - i++ - } - - if before != len(status.AttachmentIDs) { - // Note: this accounts for status not found. + if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse + len(updatedIDs) != len(status.AttachmentIDs) { + // Note: this handles not found. // // Attachments changed, update the status. if _, err := tx.NewUpdate(). Table("statuses"). Where("? = ?", bun.Ident("id"), status.ID). - Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs). + Set("? = ?", bun.Ident("attachment_ids"), updatedIDs). Exec(ctx); err != nil { return gtserror.Newf("error updating status: %w", err) } - - // Mark as needing invalidate. - invalidateStatus = true } } @@ -219,16 +197,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return nil }) - if invalidateAccount { - // The account for given ID will have been updated in transaction. - m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID) - } - - if invalidateStatus { - // The status for given ID will have been updated in transaction. - m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) - } - return m.db.ProcessError(err) } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index eddd73b49..e7b563f2e 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -20,11 +20,12 @@ import ( "context" "errors" - "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/uptrace/bun" ) @@ -45,7 +46,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) + return nil, gtserror.Newf("error fetching follow: %w", err) } if follow != nil { @@ -61,7 +62,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount requestingAccount, ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) + return nil, gtserror.Newf("error checking followedBy: %w", err) } // check if requesting has follow requested target @@ -70,19 +71,19 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) + return nil, gtserror.Newf("error checking requested: %w", err) } // check if the requesting account is blocking the target account rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) + return nil, gtserror.Newf("error checking blocking: %w", err) } // check if the requesting account is blocked by the target account rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) + return nil, gtserror.Newf("error checking blockedBy: %w", err) } // retrieve a note by the requesting account on the target account, if there is one @@ -92,7 +93,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error fetching note: %w", err) + return nil, gtserror.Newf("error fetching note: %w", err) } if note != nil { rel.Note = note.Comment @@ -102,87 +103,186 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectFollows(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followIDs, err := r.getAccountFollowIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectLocalFollows(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectFollowers(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) + if err != nil { + return nil, err } - return r.GetFollowsByIDs(ctx, followIDs) + return r.GetFollowsByIDs(ctx, followerIDs) } func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectLocalFollowers(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + if err != nil { + return nil, err } - return r.GetFollowsByIDs(ctx, followIDs) -} - -func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollows(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) -} - -func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollows(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) -} - -func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowers(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) -} - -func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + return r.GetFollowsByIDs(ctx, followerIDs) } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - var followReqIDs []string - if err := newSelectFollowRequests(r.db, accountID). - Scan(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowRequestsByIDs(ctx, followReqIDs) } func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - var followReqIDs []string - if err := newSelectFollowRequesting(r.db, accountID). - Scan(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowRequestsByIDs(ctx, followReqIDs) } +func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Pager) ([]*gtsmodel.Block, error) { + // Load block IDs from cache with database loader callback. + blockIDs, err := r.state.Caches.GTS.BlockIDs().LoadRange(accountID, func() ([]string, error) { + var blockIDs []string + + // Block IDs not in cache, perform DB query! + q := newSelectBlocks(r.db, accountID) + if _, err := q.Exec(ctx, &blockIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return blockIDs, nil + }, page.PageDesc) + if err != nil { + return nil, err + } + + // Convert these IDs to full block objects. + return r.GetBlocksByIDs(ctx, blockIDs) +} + +func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { + followIDs, err := r.getAccountFollowIDs(ctx, accountID) + return len(followIDs), err +} + +func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { + followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + return len(followIDs), err +} + +func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) + return len(followerIDs), err +} + +func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { + followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + return len(followerIDs), err +} + func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequests(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) + return len(followReqIDs), err } func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) + return len(followReqIDs), err +} + +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectFollows(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectLocalFollows(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectFollowers(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectLocalFollowers(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { + var followReqIDs []string + + // Follow request IDs not in cache, perform DB query! + q := newSelectFollowRequests(r.db, accountID) + if _, err := q.Exec(ctx, &followReqIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followReqIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { + var followReqIDs []string + + // Follow request IDs not in cache, perform DB query! + q := newSelectFollowRequesting(r.db, accountID) + if _, err := q.Exec(ctx, &followReqIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followReqIDs, nil + }) } // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. @@ -256,3 +356,12 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { ). OrderExpr("? DESC", bun.Ident("updated_at")) } + +// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. +func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). + TableExpr("?", bun.Ident("blocks")). + ColumnExpr("?", bun.Ident("?")). + Where("? = ?", bun.Ident("account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) +} diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index 948e82fcb..2a042bed4 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -25,6 +25,7 @@ "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) @@ -97,6 +98,25 @@ func(block *gtsmodel.Block) error { ) } +func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { + // Preallocate slice of expected length. + blocks := make([]*gtsmodel.Block, 0, len(ids)) + + for _, id := range ids { + // Fetch block model for this ID. + block, err := r.GetBlockByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting block %q: %v", id, err) + continue + } + + // Append to return slice. + blocks = append(blocks, block) + } + + return blocks, nil +} + func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { // Fetch block from cache with loader callback block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { @@ -148,8 +168,6 @@ func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) er } func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { - defer r.state.Caches.GTS.Block().Invalidate("ID", id) - // Load block into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -162,6 +180,9 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { return err } + // Drop this now-cached block on return after delete. + defer r.state.Caches.GTS.Block().Invalidate("ID", id) + // Finally delete block from DB. _, err = r.db.NewDelete(). Table("blocks"). @@ -171,8 +192,6 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { } func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { - defer r.state.Caches.GTS.Block().Invalidate("URI", uri) - // Load block into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -185,6 +204,9 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error return err } + // Drop this now-cached block on return after delete. + defer r.state.Caches.GTS.Block().Invalidate("URI", uri) + // Finally delete block from DB. _, err = r.db.NewDelete(). Table("blocks"). @@ -211,10 +233,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri } defer func() { - // Invalidate all IDs on return. - for _, id := range blockIDs { - r.state.Caches.GTS.Block().Invalidate("ID", id) - } + // Invalidate all account's incoming / outoing blocks on return. + r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) + r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID) }() // Load all blocks into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 84501b0be..3b0597612 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -233,8 +233,6 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { } func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID string, targetAccountID string) error { - defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) - // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -251,13 +249,14 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin return err } + // Drop this now-cached follow on return after delete. + defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { - defer r.state.Caches.GTS.Follow().Invalidate("ID", id) - // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -270,13 +269,14 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error return err } + // Drop this now-cached follow on return after delete. + defer r.state.Caches.GTS.Follow().Invalidate("ID", id) + // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { - defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) - // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -289,6 +289,9 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro return err } + // Drop this now-cached follow on return after delete. + defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) + // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) } @@ -312,10 +315,9 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str } defer func() { - // Invalidate all IDs on return. - for _, id := range followIDs { - r.state.Caches.GTS.Follow().Invalidate("ID", id) - } + // Invalidate all account's incoming / outoing follows on return. + r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) + r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID) }() // Load all follows into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index a6e913953..dc5e760e6 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -208,9 +208,6 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI return nil, err } - // Invalidate follow request from cache lookups on return. - defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) - // Delete original follow request. if _, err := r.db. NewDelete(). @@ -243,8 +240,6 @@ func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountI } func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error { - defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) - // Load followreq into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -261,6 +256,9 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI return err } + // Drop this now-cached follow request on return after delete. + defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + // Finally delete followreq from DB. _, err = r.db.NewDelete(). Table("follow_requests"). @@ -270,8 +268,6 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI } func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { - defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) - // Load followreq into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -284,6 +280,9 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) return err } + // Drop this now-cached follow request on return after delete. + defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) + // Finally delete followreq from DB. _, err = r.db.NewDelete(). Table("follow_requests"). @@ -293,8 +292,6 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) } func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { - defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) - // Load followreq into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -307,6 +304,9 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin return err } + // Drop this now-cached follow request on return after delete. + defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) + // Finally delete followreq from DB. _, err = r.db.NewDelete(). Table("follow_requests"). @@ -334,10 +334,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun } defer func() { - // Invalidate all IDs on return. - for _, id := range followReqIDs { - r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) - } + // Invalidate all account's incoming / outoing follow requests on return. + r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) + r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID) }() // Load all followreqs into cache, this *really* isn't diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index a019216d0..4dc7d8468 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -381,8 +381,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co } func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { - defer s.state.Caches.GTS.Status().Invalidate("ID", id) - // Load status into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. @@ -397,6 +395,9 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { return err } + // On return ensure status invalidated from cache. + defer s.state.Caches.GTS.Status().Invalidate("ID", id) + return s.db.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. diff --git a/internal/db/relationship.go b/internal/db/relationship.go index e19aee646..6ba9fdf8c 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -21,6 +21,7 @@ "context" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // Relationship contains functions for getting or modifying the relationship between two accounts. @@ -166,6 +167,9 @@ type Relationship interface { // CountAccountFollowerRequests returns number of follow requests originating from the given account. CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) + // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. + GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Pager) ([]*gtsmodel.Block, error) + // GetNote gets a private note from a source account on a target account, if it exists. GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) diff --git a/internal/paging/paging.go b/internal/paging/paging.go new file mode 100644 index 000000000..0323f40bc --- /dev/null +++ b/internal/paging/paging.go @@ -0,0 +1,227 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package paging + +import "golang.org/x/exp/slices" + +// Pager provides a means of paging serialized IDs, +// using the terminology of our API endpoint queries. +type Pager struct { + // SinceID will limit the returned + // page of IDs to contain newer than + // since ID (excluding it). Result + // will be returned DESCENDING. + SinceID string + + // MinID will limit the returned + // page of IDs to contain newer than + // min ID (excluding it). Result + // will be returned ASCENDING. + MinID string + + // MaxID will limit the returned + // page of IDs to contain older + // than (excluding) this max ID. + MaxID string + + // Limit will limit the returned + // page of IDs to at most 'limit'. + Limit int +} + +// Page will page the given slice of GoToSocial IDs according +// to the receiving Pager's SinceID, MinID, MaxID and Limits. +// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER +// (I.E. OLDEST ITEMS AT LOWEST INDICES, NEWER AT HIGHER). +func (p *Pager) PageAsc(ids []string) []string { + if p == nil { + // no paging. + return ids + } + + var asc bool + + if p.SinceID != "" { + // If a sinceID is given, we + // page down i.e. descending. + asc = false + + for i := 0; i < len(ids); i++ { + if ids[i] == p.SinceID { + // Hit the boundary. + // Reslice to be: + // "from here" + ids = ids[i+1:] + break + } + } + } else if p.MinID != "" { + // We only support minID if + // no sinceID is provided. + // + // If a minID is given, we + // page up, i.e. ascending. + asc = true + + for i := 0; i < len(ids); i++ { + if ids[i] == p.MinID { + // Hit the boundary. + // Reslice to be: + // "from here" + ids = ids[i+1:] + break + } + } + } + + if p.MaxID != "" { + for i := 0; i < len(ids); i++ { + if ids[i] == p.MaxID { + // Hit the boundary. + // Reslice to be: + // "up to here" + ids = ids[:i] + break + } + } + } + + if !asc && len(ids) > 1 { + var ( + // Start at front. + i = 0 + + // Start at back. + j = len(ids) - 1 + ) + + // Clone input IDs before + // we perform modifications. + ids = slices.Clone(ids) + + for i < j { + // Swap i,j index values in slice. + ids[i], ids[j] = ids[j], ids[i] + + // incr + decr, + // looping until + // they meet in + // the middle. + i++ + j-- + } + } + + if p.Limit > 0 && p.Limit < len(ids) { + // Reslice IDs to given limit. + ids = ids[:p.Limit] + } + + return ids +} + +// Page will page the given slice of GoToSocial IDs according +// to the receiving Pager's SinceID, MinID, MaxID and Limits. +// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER. +// (I.E. NEWEST ITEMS AT LOWEST INDICES, OLDER AT HIGHER). +func (p *Pager) PageDesc(ids []string) []string { + if p == nil { + // no paging. + return ids + } + + var asc bool + + if p.MaxID != "" { + for i := 0; i < len(ids); i++ { + if ids[i] == p.MaxID { + // Hit the boundary. + // Reslice to be: + // "from here" + ids = ids[i+1:] + break + } + } + } + + if p.SinceID != "" { + // If a sinceID is given, we + // page down i.e. descending. + asc = false + + for i := 0; i < len(ids); i++ { + if ids[i] == p.SinceID { + // Hit the boundary. + // Reslice to be: + // "up to here" + ids = ids[:i] + break + } + } + } else if p.MinID != "" { + // We only support minID if + // no sinceID is provided. + // + // If a minID is given, we + // page up, i.e. ascending. + asc = true + + for i := 0; i < len(ids); i++ { + if ids[i] == p.MinID { + // Hit the boundary. + // Reslice to be: + // "up to here" + ids = ids[:i] + break + } + } + } + + if asc && len(ids) > 1 { + var ( + // Start at front. + i = 0 + + // Start at back. + j = len(ids) - 1 + ) + + // Clone input IDs before + // we perform modifications. + ids = slices.Clone(ids) + + for i < j { + // Swap i,j index values in slice. + ids[i], ids[j] = ids[j], ids[i] + + // incr + decr, + // looping until + // they meet in + // the middle. + i++ + j-- + } + } + + if p.Limit > 0 && p.Limit < len(ids) { + // Reslice IDs to given limit. + ids = ids[:p.Limit] + } + + return ids +} diff --git a/internal/paging/paging_test.go b/internal/paging/paging_test.go new file mode 100644 index 000000000..71c3be0c9 --- /dev/null +++ b/internal/paging/paging_test.go @@ -0,0 +1,171 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package paging_test + +import ( + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/paging" + "golang.org/x/exp/slices" +) + +type Case struct { + // Name is the test case name. + Name string + + // Input contains test case input ID slice. + Input []string + + // Expect contains expected test case output. + Expect []string + + // Page contains the paging function to use. + Page func([]string) []string +} + +var cases = []Case{ + { + Name: "min_id and max_id set", + Input: []string{ + "064Q5D7VG6TPPQ46T09MHJ96FW", + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + "064Q5D7VKG5EQ43TYP71B4K6K0", + }, + Expect: []string{ + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + }, + Page: (&paging.Pager{ + MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", + MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", + }).PageAsc, + }, + { + Name: "min_id, max_id and limit set", + Input: []string{ + "064Q5D7VG6TPPQ46T09MHJ96FW", + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + "064Q5D7VKG5EQ43TYP71B4K6K0", + }, + Expect: []string{ + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + }, + Page: (&paging.Pager{ + MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", + MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", + Limit: 5, + }).PageAsc, + }, + { + Name: "min_id, max_id and too-large limit set", + Input: []string{ + "064Q5D7VG6TPPQ46T09MHJ96FW", + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + "064Q5D7VKG5EQ43TYP71B4K6K0", + }, + Expect: []string{ + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + }, + Page: (&paging.Pager{ + MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", + MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", + Limit: 100, + }).PageAsc, + }, + { + Name: "since_id and max_id set", + Input: []string{ + "064Q5D7VG6TPPQ46T09MHJ96FW", + "064Q5D7VGPTC4NK5T070VYSSF8", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VK8H7WMJS399SHEPCB0", + "064Q5D7VKG5EQ43TYP71B4K6K0", + }, + Expect: []string{ + "064Q5D7VK8H7WMJS399SHEPCB0", + "064Q5D7VJYFBYSAH86KDBKZ6AC", + "064Q5D7VJMWXZD3S1KT7RD51N8", + "064Q5D7VJADJTPA3GW8WAX10TW", + "064Q5D7VJ073XG9ZTWHA2KHN10", + "064Q5D7VHMSW9DF3GCS088VAZC", + "064Q5D7VH5F0JXG6W5NCQ3JCWW", + "064Q5D7VGPTC4NK5T070VYSSF8", + }, + Page: (&paging.Pager{ + SinceID: "064Q5D7VG6TPPQ46T09MHJ96FW", + MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", + }).PageAsc, + }, +} + +func TestPage(t *testing.T) { + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Page the input slice. + out := c.Page(c.Input) + + // Check paged output is as expected. + if !slices.Equal(out, c.Expect) { + t.Errorf("\nreceived=%v\nexpect%v\n", out, c.Expect) + } + }) + } +} diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 2a20ec96e..a613ba485 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -20,7 +20,6 @@ import ( "context" "errors" - "fmt" "net" "time" @@ -114,38 +113,38 @@ func (p *Processor) DeleteSelf(ctx context.Context, account *gtsmodel.Account) g func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *gtsmodel.Account) error { user, err := p.state.DB.GetUserByAccountID(ctx, account.ID) if err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error getting user: %w", err) + return gtserror.Newf("db error getting user: %w", err) } tokens := []*gtsmodel.Token{} if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error getting tokens: %w", err) + return gtserror.Newf("db error getting tokens: %w", err) } for _, t := range tokens { // Delete any OAuth clients associated with this token. if err := p.state.DB.DeleteByID(ctx, t.ClientID, &[]*gtsmodel.Client{}); err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting client: %w", err) + return gtserror.Newf("db error deleting client: %w", err) } // Delete any OAuth applications associated with this token. if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting application: %w", err) + return gtserror.Newf("db error deleting application: %w", err) } // Delete the token itself. if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting token: %w", err) + return gtserror.Newf("db error deleting token: %w", err) } } columns, err := stubbifyUser(user) if err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: error stubbifying user: %w", err) + return gtserror.Newf("error stubbifying user: %w", err) } if err := p.state.DB.UpdateUser(ctx, user, columns...); err != nil { - return fmt.Errorf("deleteUserAndTokensForAccount: db error updating user: %w", err) + return gtserror.Newf("db error updating user: %w", err) } return nil @@ -160,24 +159,24 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. // Delete follows targeting this account. followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return fmt.Errorf("deleteAccountFollows: db error getting follows targeting account %s: %w", account.ID, err) + return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err) } for _, follow := range followedBy { if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil { - return fmt.Errorf("deleteAccountFollows: db error unfollowing account followedBy: %w", err) + return gtserror.Newf("db error unfollowing account followedBy: %w", err) } } // Delete follow requests targeting this account. followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return fmt.Errorf("deleteAccountFollows: db error getting follow requests targeting account %s: %w", account.ID, err) + return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err) } for _, followRequest := range followRequestedBy { if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil { - return fmt.Errorf("deleteAccountFollows: db error unfollowing account followRequestedBy: %w", err) + return gtserror.Newf("db error unfollowing account followRequestedBy: %w", err) } } @@ -193,14 +192,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. // Delete follows originating from this account. following, err := p.state.DB.GetAccountFollows(ctx, account.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return fmt.Errorf("deleteAccountFollows: db error getting follows owned by account %s: %w", account.ID, err) + return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err) } // For each follow owned by this account, unfollow // and process side effects (noop if remote account). for _, follow := range following { if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil { - return fmt.Errorf("deleteAccountFollows: db error unfollowing account: %w", err) + return gtserror.Newf("db error unfollowing account: %w", err) } if msg := unfollowSideEffects(ctx, account, follow); msg != nil { // There was a side effect to process. @@ -211,14 +210,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. // Delete follow requests originating from this account. followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return fmt.Errorf("deleteAccountFollows: db error getting follow requests owned by account %s: %w", account.ID, err) + return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err) } // For each follow owned by this account, unfollow // and process side effects (noop if remote account). for _, followRequest := range followRequesting { if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil { - return fmt.Errorf("deleteAccountFollows: db error unfollowingRequesting account: %w", err) + return gtserror.Newf("db error unfollowingRequesting account: %w", err) } // Dummy out a follow so our side effects func @@ -279,7 +278,7 @@ func (p *Processor) unfollowSideEffectsFunc(deletedAccount *gtsmodel.Account) fu func (p *Processor) deleteAccountBlocks(ctx context.Context, account *gtsmodel.Account) error { if err := p.state.DB.DeleteAccountBlocks(ctx, account.ID); err != nil { - return fmt.Errorf("deleteAccountBlocks: db error deleting account blocks for %s: %w", account.ID, err) + return gtserror.Newf("db error deleting account blocks for %s: %w", account.ID, err) } return nil } @@ -333,7 +332,7 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel // Look for any boosts of this status in DB. boosts, err := p.state.DB.GetStatusReblogs(ctx, status) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return fmt.Errorf("deleteAccountStatuses: error fetching status reblogs for %s: %w", status.ID, err) + return gtserror.Newf("error fetching status reblogs for %s: %w", status.ID, err) } for _, boost := range boosts { @@ -347,7 +346,7 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel log.WithContext(ctx).WithField("boost", boost).Warnf("no account found with id %s for boost %s", boost.AccountID, boost.ID) continue } - return fmt.Errorf("deleteAccountStatuses: error fetching boosted status account for %s: %w", boost.AccountID, err) + return gtserror.Newf("error fetching boosted status account for %s: %w", boost.AccountID, err) } // Set account model @@ -505,7 +504,7 @@ func stubbifyUser(user *gtsmodel.User) ([]string, error) { return nil, err } - var never = time.Time{} + never := time.Time{} user.EncryptedPassword = string(dummyPassword) user.SignUpIP = net.IPv4zero diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 644f28ca9..8996dff92 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -19,69 +19,71 @@ import ( "context" - "fmt" - "net/url" + "errors" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" + "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { - accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) - if err != nil { - if err == db.ErrNoEntries { - // there are just no entries - return &apimodel.BlocksResponse{ - Accounts: []*apimodel.Account{}, - }, nil - } - // there's an actual error +// BlocksGet ... +func (p *Processor) BlocksGet( + ctx context.Context, + requestingAccount *gtsmodel.Account, + page paging.Pager, +) (*apimodel.PageableResponse, gtserror.WithCode) { + blocks, err := p.state.DB.GetAccountBlocks(ctx, + requestingAccount.ID, + &page, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) } - apiAccounts := []*apimodel.Account{} - for _, a := range accounts { - apiAccount, err := p.tc.AccountToAPIAccountBlocked(ctx, a) - if err != nil { + // Check for zero length. + count := len(blocks) + if len(blocks) == 0 { + return util.EmptyPageableResponse(), nil + } + + var ( + items = make([]interface{}, 0, count) + + // Set next + prev values before API converting + // so the caller can still page even on error. + nextMaxIDValue = blocks[count-1].ID + prevMinIDValue = blocks[0].ID + ) + + for _, block := range blocks { + if block.TargetAccount == nil { + // All models should be populated at this point. + log.Warnf(ctx, "block target account was nil: %v", err) continue } - apiAccounts = append(apiAccounts, apiAccount) - } - return p.packageBlocksResponse(apiAccounts, "/api/v1/blocks", nextMaxID, prevMinID, limit) -} - -func (p *Processor) packageBlocksResponse(accounts []*apimodel.Account, path string, nextMaxID string, prevMinID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { - resp := &apimodel.BlocksResponse{ - Accounts: []*apimodel.Account{}, - } - resp.Accounts = accounts - - // prepare the next and previous links - if len(accounts) != 0 { - protocol := config.GetProtocol() - host := config.GetHost() - - nextLink := &url.URL{ - Scheme: protocol, - Host: host, - Path: path, - RawQuery: fmt.Sprintf("limit=%d&max_id=%s", limit, nextMaxID), + // Convert target account to frontend API model. + account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) + if err != nil { + log.Errorf(ctx, "error converting account to public api account: %v", err) + continue } - next := fmt.Sprintf("<%s>; rel=\"next\"", nextLink.String()) - prevLink := &url.URL{ - Scheme: protocol, - Host: host, - Path: path, - RawQuery: fmt.Sprintf("limit=%d&min_id=%s", limit, prevMinID), - } - prev := fmt.Sprintf("<%s>; rel=\"prev\"", prevLink.String()) - resp.LinkHeader = fmt.Sprintf("%s, %s", next, prev) + // Append target to return items. + items = append(items, account) } - return resp, nil + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "/api/v1/blocks", + NextMaxIDKey: "max_id", + PrevMinIDKey: "since_id", + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: page.Limit, + }) } diff --git a/test/envparsing.sh b/test/envparsing.sh index 8f4372906..b9017d0be 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -25,6 +25,9 @@ EXPECT=$(cat <<"EOF" "account-note-ttl": 1800000000000, "account-sweep-freq": 1000000000, "account-ttl": 10800000000000, + "block-ids-max-size": 500, + "block-ids-sweep-freq": 60000000000, + "block-ids-ttl": 1800000000000, "block-max-size": 1000, "block-sweep-freq": 60000000000, "block-ttl": 1800000000000, @@ -37,7 +40,13 @@ EXPECT=$(cat <<"EOF" "emoji-max-size": 2000, "emoji-sweep-freq": 60000000000, "emoji-ttl": 1800000000000, + "follow-ids-max-size": 500, + "follow-ids-sweep-freq": 60000000000, + "follow-ids-ttl": 1800000000000, "follow-max-size": 2000, + "follow-request-ids-max-size": 500, + "follow-request-ids-sweep-freq": 60000000000, + "follow-request-ids-ttl": 1800000000000, "follow-request-max-size": 2000, "follow-request-sweep-freq": 60000000000, "follow-request-ttl": 1800000000000, diff --git a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go index 623a19910..af108e336 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go @@ -479,23 +479,23 @@ func (c *Cache[K, V]) InvalidateAll(keys ...K) (ok bool) { kvs = make([]kv[K, V], 0, len(keys)) c.locked(func() { - for _, key := range keys { + for x := range keys { var item *Entry[K, V] // Check for item in cache - item, ok = c.Cache.Get(key) + item, ok = c.Cache.Get(keys[x]) if !ok { - return + continue } // Append this old value to slice kvs = append(kvs, kv[K, V]{ - K: key, + K: keys[x], V: item.Value, }) // Remove from cache map - _ = c.Cache.Delete(key) + _ = c.Cache.Delete(keys[x]) // Free entry c.free(item) @@ -553,6 +553,7 @@ func (c *Cache[K, V]) Cap() (l int) { return } +// locked performs given function within mutex lock (NOTE: UNLOCK IS NOT DEFERRED). func (c *Cache[K, V]) locked(fn func()) { c.Lock() fn() diff --git a/vendor/modules.txt b/vendor/modules.txt index 64a310838..006cc3e5d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -13,7 +13,7 @@ codeberg.org/gruf/go-bytesize # codeberg.org/gruf/go-byteutil v1.1.2 ## explicit; go 1.16 codeberg.org/gruf/go-byteutil -# codeberg.org/gruf/go-cache/v3 v3.4.3 +# codeberg.org/gruf/go-cache/v3 v3.4.4 ## explicit; go 1.19 codeberg.org/gruf/go-cache/v3 codeberg.org/gruf/go-cache/v3/result