diff --git a/go.mod b/go.mod index 0b8021710..81c7874c9 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( codeberg.org/gruf/go-bytesize v1.0.2 codeberg.org/gruf/go-byteutil v1.1.2 - codeberg.org/gruf/go-cache/v3 v3.5.3 + codeberg.org/gruf/go-cache/v3 v3.5.5 codeberg.org/gruf/go-debug v1.3.0 codeberg.org/gruf/go-errors/v2 v2.2.0 codeberg.org/gruf/go-fastcopy v1.1.2 diff --git a/go.sum b/go.sum index 6a5bce0ee..4b0d31864 100644 --- a/go.sum +++ b/go.sum @@ -48,8 +48,8 @@ codeberg.org/gruf/go-bytesize v1.0.2/go.mod h1:n/GU8HzL9f3UNp/mUKyr1qVmTlj7+xacp codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= codeberg.org/gruf/go-byteutil v1.1.2 h1:TQLZtTxTNca9xEfDIndmo7nBYxeS94nrv/9DS3Nk5Tw= codeberg.org/gruf/go-byteutil v1.1.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= -codeberg.org/gruf/go-cache/v3 v3.5.3 h1:CRO2syVQxT/JbqDnUxzjeJkLInihEmTlJOkrOgkTmqI= -codeberg.org/gruf/go-cache/v3 v3.5.3/go.mod h1:NbsGQUgEdNFd631WSasvCHIVAaY9ovuiSeoBwtsIeDc= +codeberg.org/gruf/go-cache/v3 v3.5.5 h1:Ce7odyvr8oF6h49LSjPL7AZs2QGyKMN9BPkgKcfR0BA= +codeberg.org/gruf/go-cache/v3 v3.5.5/go.mod h1:NbsGQUgEdNFd631WSasvCHIVAaY9ovuiSeoBwtsIeDc= codeberg.org/gruf/go-debug v1.3.0 h1:PIRxQiWUFKtGOGZFdZ3Y0pqyfI0Xr87j224IYe2snZs= codeberg.org/gruf/go-debug v1.3.0/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg= codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index cb5503a84..ec0ec3faa 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -196,6 +196,21 @@ func (c *Caches) setuphooks() { // c.GTS.Media().Invalidate("StatusID") will not work. c.GTS.Media().Invalidate("ID", id) } + + if status.BoostOfID != "" { + // Invalidate boost ID list of the original status. + c.GTS.BoostOfIDs().Invalidate(status.BoostOfID) + } + + if status.InReplyToID != "" { + // Invalidate in reply to ID list of original status. + c.GTS.InReplyToIDs().Invalidate(status.InReplyToID) + } + }) + + c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) { + // Invalidate status fave ID list for this status. + c.GTS.StatusFaveIDs().Invalidate(fave.StatusID) }) c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) { diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 3f54d5c52..f120bcf4e 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -34,6 +34,7 @@ type GTSCaches struct { accountNote *result.Cache[*gtsmodel.AccountNote] block *result.Cache[*gtsmodel.Block] blockIDs *SliceCache[string] + boostOfIDs *SliceCache[string] domainBlock *domain.BlockCache emoji *result.Cache[*gtsmodel.Emoji] emojiCategory *result.Cache[*gtsmodel.EmojiCategory] @@ -42,6 +43,7 @@ type GTSCaches struct { followRequest *result.Cache[*gtsmodel.FollowRequest] followRequestIDs *SliceCache[string] instance *result.Cache[*gtsmodel.Instance] + inReplyToIDs *SliceCache[string] list *result.Cache[*gtsmodel.List] listEntry *result.Cache[*gtsmodel.ListEntry] marker *result.Cache[*gtsmodel.Marker] @@ -51,6 +53,7 @@ type GTSCaches struct { report *result.Cache[*gtsmodel.Report] status *result.Cache[*gtsmodel.Status] statusFave *result.Cache[*gtsmodel.StatusFave] + statusFaveIDs *SliceCache[string] tag *result.Cache[*gtsmodel.Tag] tombstone *result.Cache[*gtsmodel.Tombstone] user *result.Cache[*gtsmodel.User] @@ -66,6 +69,7 @@ func (c *GTSCaches) Init() { c.initAccountNote() c.initBlock() c.initBlockIDs() + c.initBoostOfIDs() c.initDomainBlock() c.initEmoji() c.initEmojiCategory() @@ -73,6 +77,7 @@ func (c *GTSCaches) Init() { c.initFollowIDs() c.initFollowRequest() c.initFollowRequestIDs() + c.initInReplyToIDs() c.initInstance() c.initList() c.initListEntry() @@ -84,6 +89,7 @@ func (c *GTSCaches) Init() { c.initStatus() c.initStatusFave() c.initTag() + c.initStatusFaveIDs() c.initTombstone() c.initUser() c.initWebfinger() @@ -121,6 +127,11 @@ func (c *GTSCaches) BlockIDs() *SliceCache[string] { return c.blockIDs } +// BoostOfIDs provides access to the boost of IDs list database cache. +func (c *GTSCaches) BoostOfIDs() *SliceCache[string] { + return c.boostOfIDs +} + // DomainBlock provides access to the domain block database cache. func (c *GTSCaches) DomainBlock() *domain.BlockCache { return c.domainBlock @@ -169,6 +180,11 @@ func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] { return c.instance } +// InReplyToIDs provides access to the status in reply to IDs list database cache. +func (c *GTSCaches) InReplyToIDs() *SliceCache[string] { + return c.inReplyToIDs +} + // List provides access to the gtsmodel List database cache. func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] { return c.list @@ -219,6 +235,11 @@ func (c *GTSCaches) Tag() *result.Cache[*gtsmodel.Tag] { return c.tag } +// StatusFaveIDs provides access to the status fave IDs list database cache. +func (c *GTSCaches) StatusFaveIDs() *SliceCache[string] { + return c.statusFaveIDs +} + // Tombstone provides access to the gtsmodel Tombstone database cache. func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] { return c.tombstone @@ -247,7 +268,7 @@ func (c *GTSCaches) initAccount() { {Name: "ID"}, {Name: "URI"}, {Name: "URL"}, - {Name: "Username.Domain"}, + {Name: "Username.Domain", AllowZero: true /* domain can be zero i.e. "" */}, {Name: "PublicKeyURI"}, {Name: "InboxURI"}, {Name: "OutboxURI"}, @@ -320,6 +341,20 @@ func (c *GTSCaches) initBlockIDs() { )} } +func (c *GTSCaches) initBoostOfIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheBoostOfIDsMemRatio(), + ) + + log.Infof(nil, "BoostofIDs cache size = %d", cap) + + c.boostOfIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + func (c *GTSCaches) initDomainBlock() { c.domainBlock = new(domain.BlockCache) } @@ -336,7 +371,7 @@ func (c *GTSCaches) initEmoji() { c.emoji = result.New([]result.Lookup{ {Name: "ID"}, {Name: "URI"}, - {Name: "Shortcode.Domain"}, + {Name: "Shortcode.Domain", AllowZero: true /* domain can be zero i.e. "" */}, {Name: "ImageStaticURL"}, {Name: "CategoryID", Multi: true}, }, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { @@ -445,6 +480,20 @@ func (c *GTSCaches) initFollowRequestIDs() { )} } +func (c *GTSCaches) initInReplyToIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheInReplyToIDsMemRatio(), + ) + + log.Infof(nil, "InReplyTo IDs cache size = %d", cap) + + c.inReplyToIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + func (c *GTSCaches) initInstance() { // Calculate maximum cache size. cap := calculateResultCacheMax( @@ -622,6 +671,7 @@ func (c *GTSCaches) initStatus() { {Name: "ID"}, {Name: "URI"}, {Name: "URL"}, + {Name: "BoostOfID.AccountID"}, }, func(s1 *gtsmodel.Status) *gtsmodel.Status { s2 := new(gtsmodel.Status) *s2 = *s1 @@ -643,6 +693,7 @@ func (c *GTSCaches) initStatusFave() { c.statusFave = result.New([]result.Lookup{ {Name: "ID"}, {Name: "AccountID.StatusID"}, + {Name: "StatusID", Multi: true}, }, func(f1 *gtsmodel.StatusFave) *gtsmodel.StatusFave { f2 := new(gtsmodel.StatusFave) *f2 = *f1 @@ -652,6 +703,20 @@ func (c *GTSCaches) initStatusFave() { c.statusFave.IgnoreErrors(ignoreErrors) } +func (c *GTSCaches) initStatusFaveIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheStatusFaveIDsMemRatio(), + ) + + log.Infof(nil, "StatusFave IDs cache size = %d", cap) + + c.statusFaveIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + func (c *GTSCaches) initTag() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/config/config.go b/internal/config/config.go index 50508e40b..ef79d4e12 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -180,12 +180,14 @@ type CacheConfiguration struct { AccountNoteMemRatio float64 `name:"account-note-mem-ratio"` BlockMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"` + BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` EmojiMemRatio float64 `name:"emoji-mem-ratio"` EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"` FollowMemRatio float64 `name:"follow-mem-ratio"` FollowIDsMemRatio float64 `name:"follow-ids-mem-ratio"` FollowRequestMemRatio float64 `name:"follow-request-mem-ratio"` FollowRequestIDsMemRatio float64 `name:"follow-request-ids-mem-ratio"` + InReplyToIDsMemRatio float64 `name:"in-reply-to-ids-mem-ratio"` InstanceMemRatio float64 `name:"instance-mem-ratio"` ListMemRatio float64 `name:"list-mem-ratio"` ListEntryMemRatio float64 `name:"list-entry-mem-ratio"` @@ -196,6 +198,7 @@ type CacheConfiguration struct { ReportMemRatio float64 `name:"report-mem-ratio"` StatusMemRatio float64 `name:"status-mem-ratio"` StatusFaveMemRatio float64 `name:"status-fave-mem-ratio"` + StatusFaveIDsMemRatio float64 `name:"status-fave-ids-mem-ratio"` TagMemRatio float64 `name:"tag-mem-ratio"` TombstoneMemRatio float64 `name:"tombstone-mem-ratio"` UserMemRatio float64 `name:"user-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index e740d1c98..2bc95f6f1 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -149,12 +149,14 @@ AccountNoteMemRatio: 0.1, BlockMemRatio: 3, BlockIDsMemRatio: 3, + BoostOfIDsMemRatio: 3, EmojiMemRatio: 3, EmojiCategoryMemRatio: 0.1, FollowMemRatio: 4, FollowIDsMemRatio: 4, FollowRequestMemRatio: 2, FollowRequestIDsMemRatio: 2, + InReplyToIDsMemRatio: 3, InstanceMemRatio: 1, ListMemRatio: 3, ListEntryMemRatio: 3, @@ -165,6 +167,7 @@ ReportMemRatio: 1, StatusMemRatio: 18, StatusFaveMemRatio: 5, + StatusFaveIDsMemRatio: 3, TagMemRatio: 3, TombstoneMemRatio: 2, UserMemRatio: 0.1, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index c29b5c38f..0a299e7d0 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2549,6 +2549,31 @@ func GetCacheBlockIDsMemRatio() float64 { return global.GetCacheBlockIDsMemRatio // SetCacheBlockIDsMemRatio safely sets the value for global configuration 'Cache.BlockIDsMemRatio' field func SetCacheBlockIDsMemRatio(v float64) { global.SetCacheBlockIDsMemRatio(v) } +// GetCacheBoostOfIDsMemRatio safely fetches the Configuration value for state's 'Cache.BoostOfIDsMemRatio' field +func (st *ConfigState) GetCacheBoostOfIDsMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.BoostOfIDsMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheBoostOfIDsMemRatio safely sets the Configuration value for state's 'Cache.BoostOfIDsMemRatio' field +func (st *ConfigState) SetCacheBoostOfIDsMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.BoostOfIDsMemRatio = v + st.reloadToViper() +} + +// CacheBoostOfIDsMemRatioFlag returns the flag name for the 'Cache.BoostOfIDsMemRatio' field +func CacheBoostOfIDsMemRatioFlag() string { return "cache-boost-of-ids-mem-ratio" } + +// GetCacheBoostOfIDsMemRatio safely fetches the value for global configuration 'Cache.BoostOfIDsMemRatio' field +func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemRatio() } + +// SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field +func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } + // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) { st.mutex.RLock() @@ -2699,6 +2724,31 @@ func GetCacheFollowRequestIDsMemRatio() float64 { return global.GetCacheFollowRe // SetCacheFollowRequestIDsMemRatio safely sets the value for global configuration 'Cache.FollowRequestIDsMemRatio' field func SetCacheFollowRequestIDsMemRatio(v float64) { global.SetCacheFollowRequestIDsMemRatio(v) } +// GetCacheInReplyToIDsMemRatio safely fetches the Configuration value for state's 'Cache.InReplyToIDsMemRatio' field +func (st *ConfigState) GetCacheInReplyToIDsMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.InReplyToIDsMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheInReplyToIDsMemRatio safely sets the Configuration value for state's 'Cache.InReplyToIDsMemRatio' field +func (st *ConfigState) SetCacheInReplyToIDsMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.InReplyToIDsMemRatio = v + st.reloadToViper() +} + +// CacheInReplyToIDsMemRatioFlag returns the flag name for the 'Cache.InReplyToIDsMemRatio' field +func CacheInReplyToIDsMemRatioFlag() string { return "cache-in-reply-to-ids-mem-ratio" } + +// GetCacheInReplyToIDsMemRatio safely fetches the value for global configuration 'Cache.InReplyToIDsMemRatio' field +func GetCacheInReplyToIDsMemRatio() float64 { return global.GetCacheInReplyToIDsMemRatio() } + +// SetCacheInReplyToIDsMemRatio safely sets the value for global configuration 'Cache.InReplyToIDsMemRatio' field +func SetCacheInReplyToIDsMemRatio(v float64) { global.SetCacheInReplyToIDsMemRatio(v) } + // GetCacheInstanceMemRatio safely fetches the Configuration value for state's 'Cache.InstanceMemRatio' field func (st *ConfigState) GetCacheInstanceMemRatio() (v float64) { st.mutex.RLock() @@ -2949,6 +2999,31 @@ func GetCacheStatusFaveMemRatio() float64 { return global.GetCacheStatusFaveMemR // SetCacheStatusFaveMemRatio safely sets the value for global configuration 'Cache.StatusFaveMemRatio' field func SetCacheStatusFaveMemRatio(v float64) { global.SetCacheStatusFaveMemRatio(v) } +// GetCacheStatusFaveIDsMemRatio safely fetches the Configuration value for state's 'Cache.StatusFaveIDsMemRatio' field +func (st *ConfigState) GetCacheStatusFaveIDsMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.StatusFaveIDsMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheStatusFaveIDsMemRatio safely sets the Configuration value for state's 'Cache.StatusFaveIDsMemRatio' field +func (st *ConfigState) SetCacheStatusFaveIDsMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.StatusFaveIDsMemRatio = v + st.reloadToViper() +} + +// CacheStatusFaveIDsMemRatioFlag returns the flag name for the 'Cache.StatusFaveIDsMemRatio' field +func CacheStatusFaveIDsMemRatioFlag() string { return "cache-status-fave-ids-mem-ratio" } + +// GetCacheStatusFaveIDsMemRatio safely fetches the value for global configuration 'Cache.StatusFaveIDsMemRatio' field +func GetCacheStatusFaveIDsMemRatio() float64 { return global.GetCacheStatusFaveIDsMemRatio() } + +// SetCacheStatusFaveIDsMemRatio safely sets the value for global configuration 'Cache.StatusFaveIDsMemRatio' field +func SetCacheStatusFaveIDsMemRatio(v float64) { global.SetCacheStatusFaveIDsMemRatio(v) } + // GetCacheTagMemRatio safely fetches the Configuration value for state's 'Cache.TagMemRatio' field func (st *ConfigState) GetCacheTagMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 25b773dfa..c6091e2c9 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -20,7 +20,6 @@ import ( "container/list" "context" - "database/sql" "errors" "time" @@ -96,6 +95,26 @@ func(status *gtsmodel.Status) error { ) } +func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) { + return s.getStatus( + ctx, + "BoostOfID.AccountID", + func(status *gtsmodel.Status) error { + return s.newStatusQ(status). + Where("status.boost_of_id = ?", boostOfID). + Where("status.account_id = ?", byAccountID). + + // Our old code actually allowed a status to + // be boosted multiple times by the same author, + // so limit our query + order to fetch latest. + Order("status.id DESC"). // our IDs are timestamped + Limit(1). + Scan(ctx) + }, + boostOfID, byAccountID, + ) +} + func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) { // Fetch status from database cache with loader callback status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { @@ -245,11 +264,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } } - if err := errs.Combine(); err != nil { - return gtserror.Newf("%w", err) - } - - return nil + return errs.Combine() } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { @@ -506,25 +521,17 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu } func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - var childIDs []string - - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Column("status.id"). - Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID) - if minID != "" { - q = q.Where("? > ?", bun.Ident("status.id"), minID) - } - - if err := q.Scan(ctx, &childIDs); err != nil { - if err != sql.ErrNoRows { - log.Errorf(ctx, "error getting children for %q: %v", status.ID, err) - } + childIDs, err := s.getStatusReplyIDs(ctx, status.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + log.Errorf(ctx, "error getting status %s children: %v", status.ID, err) return } for _, id := range childIDs { + if id <= minID { + continue + } + // Fetch child with ID from database child, err := s.GetStatusByID(ctx, id) if err != nil { @@ -553,48 +560,80 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, } } -func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). - Count(ctx) +func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + statusIDs, err := s.getStatusReplyIDs(ctx, statusID) + if err != nil { + return nil, err + } + return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). - Count(ctx) +func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int, error) { + statusIDs, err := s.getStatusReplyIDs(ctx, statusID) + return len(statusIDs), err } -func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) { - return s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). - Count(ctx) +func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) { + var statusIDs []string + + // Status reply IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("statuses"). + Column("id"). + Where("? = ?", bun.Ident("in_reply_to_id"), statusID). + Order("id DESC"). + Scan(ctx, &statusIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return statusIDs, nil + }) } -func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). - Where("? = ?", bun.Ident("status_fave.account_id"), accountID) - - return s.db.Exists(ctx, q) +func (s *statusDB) GetStatusBoosts(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + statusIDs, err := s.getStatusBoostIDs(ctx, statusID) + if err != nil { + return nil, err + } + return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { - q := s.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). - Where("? = ?", bun.Ident("status.account_id"), accountID) +func (s *statusDB) IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) { + boost, err := s.GetStatusBoost( + gtscontext.SetBarebones(ctx), + statusID, + accountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (boost != nil), nil +} - return s.db.Exists(ctx, q) +func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int, error) { + statusIDs, err := s.getStatusBoostIDs(ctx, statusID) + return len(statusIDs), err +} + +func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) { + var statusIDs []string + + // Status boost IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("statuses"). + Column("id"). + Where("? = ?", bun.Ident("boost_of_id"), statusID). + Order("id DESC"). + Scan(ctx, &statusIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return statusIDs, nil + }) } func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { @@ -616,16 +655,3 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St return s.db.Exists(ctx, q) } - -func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) { - reblogs := []*gtsmodel.Status{} - - q := s. - newStatusQ(&reblogs). - Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) - - if err := q.Scan(ctx); err != nil { - return nil, s.db.ProcessError(err) - } - return reblogs, nil -} diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 7aff543fd..ab09fb1ba 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -19,6 +19,7 @@ import ( "context" + "database/sql" "errors" "fmt" @@ -44,8 +45,14 @@ func(fave *gtsmodel.StatusFave) error { return s.db. NewSelect(). Model(fave). - Where("? = ?", bun.Ident("account_id"), accountID). - Where("? = ?", bun.Ident("status_id"), statusID). + Where("status_fave.account_id = ?", accountID). + Where("status_fave.status_id = ?", statusID). + + // Our old code actually allowed a status to + // be faved multiple times by the same author, + // so limit our query + order to fetch latest. + Order("status_fave.id DESC"). // our IDs are timestamped + Limit(1). Scan(ctx) }, accountID, @@ -89,63 +96,68 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery return fave, nil } - // Fetch the status fave author account. - fave.Account, err = s.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - fave.AccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err) - } - - // Fetch the status fave target account. - fave.TargetAccount, err = s.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - fave.TargetAccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err) - } - - // Fetch the status fave target status. - fave.Status, err = s.state.DB.GetStatusByID( - gtscontext.SetBarebones(ctx), - fave.StatusID, - ) - if err != nil { - return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err) + // Populate the status favourite model. + if err := s.PopulateStatusFave(ctx, fave); err != nil { + return nil, fmt.Errorf("error(s) populating status fave: %w", err) } return fave, nil } -func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) { - ids := []string{} - - if err := s.db. - NewSelect(). - Table("status_faves"). - Column("id"). - Where("? = ?", bun.Ident("status_id"), statusID). - Scan(ctx, &ids); err != nil { - return nil, s.db.ProcessError(err) +func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) { + // Fetch the status fave IDs for status. + faveIDs, err := s.getStatusFaveIDs(ctx, statusID) + if err != nil { + return nil, err } - faves := make([]*gtsmodel.StatusFave, 0, len(ids)) + // Preallocate a slice of expected status fave capacity. + faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs)) - for _, id := range ids { + for _, id := range faveIDs { + // Fetch status fave model for each ID. fave, err := s.GetStatusFaveByID(ctx, id) if err != nil { log.Errorf(ctx, "error getting status fave %q: %v", id, err) continue } - faves = append(faves, fave) } return faves, nil } +func (s *statusFaveDB) IsStatusFavedBy(ctx context.Context, statusID string, accountID string) (bool, error) { + fave, err := s.GetStatusFave(ctx, accountID, statusID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (fave != nil), nil +} + +func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (int, error) { + faveIDs, err := s.getStatusFaveIDs(ctx, statusID) + return len(faveIDs), err +} + +func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) { + return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) { + var faveIDs []string + + // Status fave IDs not in cache, perform DB query! + if err := s.db. + NewSelect(). + Table("status_faves"). + Column("id"). + Where("? = ?", bun.Ident("status_id"), statusID). + Scan(ctx, &faveIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + return faveIDs, nil + }) +} + func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error { var ( err error @@ -203,26 +215,32 @@ func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusF } func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error { - defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + var statusID string - // Load fave into cache before attempting a delete, - // as we need it cached in order to trigger the invalidate - // callback. This in turn invalidates others. - _, err := s.GetStatusFaveByID(gtscontext.SetBarebones(ctx), id) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // not an issue. + // Perform DELETE on status fave, + // returning the status ID it was for. + if _, err := s.db.NewDelete(). + Table("status_faves"). + Where("id = ?", id). + Returning("status_id"). + Exec(ctx, &statusID); err != nil { + if err == sql.ErrNoRows { + // Not an issue, only due + // to us doing a RETURNING. err = nil } - return err + return s.db.ProcessError(err) } - // Finally delete fave from DB. - _, err = s.db.NewDelete(). - Table("status_faves"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return s.db.ProcessError(err) + if statusID != "" { + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) + } + + return nil } func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error { @@ -230,12 +248,13 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set") } - var faveIDs []string + var statusIDs []string - q := s.db. - NewSelect(). - Column("id"). - Table("status_faves") + // Prepare DELETE query returning + // the deleted faves for status IDs. + q := s.db.NewDelete(). + Table("status_faves"). + Returning("status_id") if targetAccountID != "" { q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID) @@ -245,69 +264,46 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st q = q.Where("? = ?", bun.Ident("account_id"), originAccountID) } - if _, err := q.Exec(ctx, &faveIDs); err != nil { + // Execute query, store favourited status IDs. + if _, err := q.Exec(ctx, &statusIDs); err != nil { + if err == sql.ErrNoRows { + // Not an issue, only due + // to us doing a RETURNING. + err = nil + } return s.db.ProcessError(err) } - defer func() { - // Invalidate all IDs on return. - for _, id := range faveIDs { - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) - } - }() + // Collate (deduplicating) status IDs. + statusIDs = collate(func(i int) string { + return statusIDs[i] + }, len(statusIDs)) - // Load all faves into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range faveIDs { - _, err := s.GetStatusFaveByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + for _, id := range statusIDs { + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(id) } - // Finally delete all from DB. - _, err := s.db.NewDelete(). - Table("status_faves"). - Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). - Exec(ctx) - return s.db.ProcessError(err) + return nil } func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error { - // Capture fave IDs in a RETURNING statement. - var faveIDs []string - - q := s.db. - NewSelect(). - Column("id"). + // Delete all status faves for status. + if _, err := s.db.NewDelete(). Table("status_faves"). - Where("? = ?", bun.Ident("status_id"), statusID) - if _, err := q.Exec(ctx, &faveIDs); err != nil { + Where("status_id = ?", statusID). + Exec(ctx); err != nil { return s.db.ProcessError(err) } - defer func() { - // Invalidate all IDs on return. - for _, id := range faveIDs { - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) - } - }() + // Invalidate any cached status faves for this status. + s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID) - // Load all faves into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range faveIDs { - _, err := s.GetStatusFaveByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } - } + // Invalidate any cached status fave IDs for this status. + s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) - // Finally delete all from DB. - _, err := s.db.NewDelete(). - Table("status_faves"). - Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). - Exec(ctx) - return s.db.ProcessError(err) + return nil } diff --git a/internal/db/bundb/statusfave_test.go b/internal/db/bundb/statusfave_test.go index 7218390bc..9c99d795b 100644 --- a/internal/db/bundb/statusfave_test.go +++ b/internal/db/bundb/statusfave_test.go @@ -35,7 +35,7 @@ type StatusFaveTestSuite struct { func (suite *StatusFaveTestSuite) TestGetStatusFaves() { testStatus := suite.testStatuses["admin_account_status_1"] - faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID) + faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID) if err != nil { suite.FailNow(err.Error()) } @@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() { func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() { testStatus := suite.testStatuses["admin_account_status_4"] - faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID) + faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID) if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/db/media.go b/internal/db/media.go index 66fa258fe..94a365c26 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -41,10 +41,10 @@ type Media interface { // DeleteAttachment deletes the attachment with given ID from the database. DeleteAttachment(ctx context.Context, id string) error - // GetAttachments ... + // GetAttachments fetches media attachments up to a given max ID, and at most limit. GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) - // GetRemoteAttachments ... + // GetRemoteAttachments fetches media attachments with a non-empty domain, up to a given max ID, and at most limit. GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) // GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than diff --git a/internal/db/status.go b/internal/db/status.go index 6f9848f57..f4421fa2e 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -34,6 +34,9 @@ type Status interface { // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error) + // GetStatusBoost fetches the status whose boost_of_id column refers to boostOfID, authored by given account ID. + GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) + // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc). PopulateStatus(ctx context.Context, status *gtsmodel.Status) error @@ -46,21 +49,27 @@ type Status interface { // DeleteStatusByID deletes one status from the database. DeleteStatusByID(ctx context.Context, id string) error - // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong - CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) - - // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) - - // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong - CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) - // GetStatuses gets a slice of statuses corresponding to the given status IDs. GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) // GetStatusesUsingEmoji fetches all status models using emoji with given ID stored in their 'emojis' column. GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Status, error) + // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID. + GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) + + // CountStatusReplies returns the number of stored *direct* (i.e. in_reply_to_id column) replies to this status ID. + CountStatusReplies(ctx context.Context, statusID string) (int, error) + + // GetStatusBoosts returns all statuses whose boost_of_id column refer to given status ID. + GetStatusBoosts(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) + + // CountStatusBoosts returns the number of stored boosts for status ID. + CountStatusBoosts(ctx context.Context, statusID string) (int, error) + + // IsStatusBoostedBy checks whether the given status ID is boosted by account ID. + IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) + // GetStatusParents gets the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. @@ -71,19 +80,9 @@ type Status interface { // If onlyDirect is true, only the immediate children will be returned. GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) - // IsStatusFavedBy checks if a given status has been faved by a given account ID - IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) - - // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) - // IsStatusMutedBy checks if a given status has been muted by a given account ID IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) - - // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. - // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) } diff --git a/internal/db/statusfave.go b/internal/db/statusfave.go index 37769ff79..343a80caa 100644 --- a/internal/db/statusfave.go +++ b/internal/db/statusfave.go @@ -24,16 +24,15 @@ ) type StatusFave interface { - // GetStatusFaveByAccountID gets one status fave created by the given - // accountID, targeting the given statusID. + // GetStatusFaveByAccountID gets one status fave created by the given accountID, targeting the given statusID. GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) // GetStatusFave returns one status fave with the given id. GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error) - // GetStatusFaves returns a slice of faves/likes of the given status. + // GetStatusFaves returns a slice of faves/likes of the status with given ID. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) + GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) // PopulateStatusFave ensures that all sub-models of a fave are populated (account, status, etc). PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error @@ -59,8 +58,13 @@ type StatusFave interface { // At least one parameter must not be an empty string. DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error - // DeleteStatusFavesForStatus deletes all status faves that target the - // given status ID. This is useful when a status has been deleted, and you need - // to clean up after it. + // DeleteStatusFavesForStatus deletes all status faves that target the given status ID. + // This is useful when a status has been deleted, and you need to clean up after it. DeleteStatusFavesForStatus(ctx context.Context, statusID string) error + + // CountStatusFaves returns the number of status favourites registered for status with ID. + CountStatusFaves(ctx context.Context, statusID string) (int, error) + + // IsStatusFavedBy returns whether the status with ID has been favourited by account with ID. + IsStatusFavedBy(ctx context.Context, statusID string, accountID string) (bool, error) } diff --git a/internal/gtserror/multi.go b/internal/gtserror/multi.go index 1c533b285..3d39333b6 100644 --- a/internal/gtserror/multi.go +++ b/internal/gtserror/multi.go @@ -19,16 +19,13 @@ import ( "errors" - "fmt" ) // MultiError allows encapsulating multiple // errors under a singular instance, which // is useful when you only want to log on // errors, not return early / bubble up. -type MultiError struct { - e []error -} +type MultiError []error // NewMultiError returns a *MultiError with // the capacity of its underlying error slice @@ -40,15 +37,13 @@ type MultiError struct { // // If you don't know in advance what the capacity // must be, just use new(MultiError) instead. -func NewMultiError(capacity int) *MultiError { - return &MultiError{ - e: make([]error, 0, capacity), - } +func NewMultiError(capacity int) MultiError { + return make([]error, 0, capacity) } // Append the given error to the MultiError. func (m *MultiError) Append(err error) { - m.e = append(m.e, err) + (*m) = append((*m), err) } // Append the given format string to the MultiError. @@ -56,12 +51,13 @@ func (m *MultiError) Append(err error) { // It is valid to use %w in the format string // to wrap any other errors. func (m *MultiError) Appendf(format string, args ...any) { - m.e = append(m.e, fmt.Errorf(format, args...)) + err := newfAt(3, format, args...) + (*m) = append((*m), err) } // Combine the MultiError into a single error. // // Unwrap will work on the returned error as expected. func (m MultiError) Combine() error { - return errors.Join(m.e...) + return errors.Join(m...) } diff --git a/internal/gtserror/multi_test.go b/internal/gtserror/multi_test.go index 9c16c1a53..10c342415 100644 --- a/internal/gtserror/multi_test.go +++ b/internal/gtserror/multi_test.go @@ -15,22 +15,22 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package gtserror +package gtserror_test import ( "errors" "testing" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" ) func TestMultiError(t *testing.T) { - errs := MultiError{ - e: []error{ - db.ErrNoEntries, - errors.New("oopsie woopsie we did a fucky wucky etc"), - }, - } + errs := gtserror.MultiError([]error{ + db.ErrNoEntries, + errors.New("oopsie woopsie we did a fucky wucky etc"), + }) + errs.Appendf("appended + wrapped error: %w", db.ErrAlreadyExists) err := errs.Combine() @@ -50,14 +50,14 @@ func TestMultiError(t *testing.T) { errString := err.Error() expected := `sql: no rows in result set oopsie woopsie we did a fucky wucky etc -appended + wrapped error: already exists` +TestMultiError: appended + wrapped error: already exists` if errString != expected { t.Errorf("errString '%s' should be '%s'", errString, expected) } } func TestMultiErrorEmpty(t *testing.T) { - err := new(MultiError).Combine() + err := new(gtserror.MultiError).Combine() if err != nil { t.Errorf("should be nil") } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index a613ba485..dd5957531 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -330,7 +330,7 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel }) // Look for any boosts of this status in DB. - boosts, err := p.state.DB.GetStatusReblogs(ctx, status) + boosts, err := p.state.DB.GetStatusBoosts(ctx, status.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("error fetching status reblogs for %s: %w", status.ID, err) } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 030ff506c..07895b6ba 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -380,6 +380,8 @@ func (p *Processor) notify( // wipeStatus contains common logic used to totally delete a status // + all its attachments, notifications, boosts, and timeline entries. func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Status, deleteAttachments bool) error { + var errs gtserror.MultiError + // either delete all attachments for this status, or simply // unattach all attachments for this status, so they'll be // cleaned later by a separate process; reason to unattach rather @@ -389,14 +391,14 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta // todo: p.state.DB.DeleteAttachmentsForStatus for _, a := range statusToDelete.AttachmentIDs { if err := p.media.Delete(ctx, a); err != nil { - return err + errs.Appendf("error deleting media: %w", err) } } } else { // todo: p.state.DB.UnattachAttachmentsForStatus for _, a := range statusToDelete.AttachmentIDs { if _, err := p.media.Unattach(ctx, statusToDelete.Account, a); err != nil { - return err + errs.Appendf("error unattaching media: %w", err) } } } @@ -405,44 +407,55 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta // todo: p.state.DB.DeleteMentionsForStatus for _, id := range statusToDelete.MentionIDs { if err := p.state.DB.DeleteMentionByID(ctx, id); err != nil { - return err + errs.Appendf("error deleting status mention: %w", err) } } // delete all notification entries generated by this status if err := p.state.DB.DeleteNotificationsForStatus(ctx, statusToDelete.ID); err != nil { - return err + errs.Appendf("error deleting status notifications: %w", err) } // delete all bookmarks that point to this status if err := p.state.DB.DeleteStatusBookmarksForStatus(ctx, statusToDelete.ID); err != nil { - return err + errs.Appendf("error deleting status bookmarks: %w", err) } // delete all faves of this status if err := p.state.DB.DeleteStatusFavesForStatus(ctx, statusToDelete.ID); err != nil { - return err + errs.Appendf("error deleting status faves: %w", err) } // delete all boosts for this status + remove them from timelines - if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil { - for _, b := range boosts { - if err := p.deleteStatusFromTimelines(ctx, b.ID); err != nil { - return err - } - if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil { - return err - } + boosts, err := p.state.DB.GetStatusBoosts( + // we MUST set a barebones context here, + // as depending on where it came from the + // original BoostOf may already be gone. + gtscontext.SetBarebones(ctx), + statusToDelete.ID) + if err != nil { + errs.Appendf("error fetching status boosts: %w", err) + } + for _, b := range boosts { + if err := p.deleteStatusFromTimelines(ctx, b.ID); err != nil { + errs.Appendf("error deleting boost from timelines: %w", err) + } + if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil { + errs.Appendf("error deleting boost: %w", err) } } // delete this status from any and all timelines if err := p.deleteStatusFromTimelines(ctx, statusToDelete.ID); err != nil { - return err + errs.Appendf("error deleting status from timelines: %w", err) } - // delete the status itself - return p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID) + // finally, delete the status itself + if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil { + errs.Appendf("error deleting status: %w", err) + } + + return errs.Combine() } // deleteStatusFromTimelines completely removes the given status from all timelines. diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index e5d38d9d2..eccd81886 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -106,47 +106,24 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - // check if we actually have a boost for this status - var toUnboost bool - - gtsBoost := >smodel.Status{} - where := []db.Where{ - { - Key: "boost_of_id", - Value: targetStatusID, - }, - { - Key: "account_id", - Value: requestingAccount.ID, - }, - } - err = p.state.DB.GetWhere(ctx, where, gtsBoost) - if err == nil { - // we have a boost - toUnboost = true - } - + // Check whether the requesting account has boosted the given status ID. + boost, err := p.state.DB.GetStatusBoost(ctx, targetStatusID, requestingAccount.ID) if err != nil { - // something went wrong in the db finding the boost - if err != db.ErrNoEntries { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing boost from database: %s", err)) - } - // we just don't have a boost - toUnboost = false + return nil, gtserror.NewErrorNotFound(fmt.Errorf("error checking status boost %s: %w", targetStatusID, err)) } - if toUnboost { + if boost != nil { // pin some stuff onto the boost while we have it out of the db - gtsBoost.Account = requestingAccount - gtsBoost.BoostOf = targetStatus - gtsBoost.BoostOfAccount = targetStatus.Account - gtsBoost.BoostOf.Account = targetStatus.Account + boost.Account = requestingAccount + boost.BoostOf = targetStatus + boost.BoostOfAccount = targetStatus.Account + boost.BoostOf.Account = targetStatus.Account // send it back to the processor for async processing p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityUndo, - GTSModel: gtsBoost, + GTSModel: boost, OriginAccount: requestingAccount, TargetAccount: targetStatus.Account, }) @@ -189,15 +166,15 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm return nil, gtserror.NewErrorNotFound(err) } - statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus) + statusBoosts, err := p.state.DB.GetStatusBoosts(ctx, targetStatus.ID) if err != nil { err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err) return nil, gtserror.NewErrorNotFound(err) } // filter account IDs so the user doesn't see accounts they blocked or which blocked them - accountIDs := make([]string, 0, len(statusReblogs)) - for _, s := range statusReblogs { + accountIDs := make([]string, 0, len(statusBoosts)) + for _, s := range statusBoosts { blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, s.AccountID) if err != nil { err = fmt.Errorf("BoostedBy: error checking blocks: %s", err) diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index 77d3f67e9..9da243312 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -112,7 +112,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc return nil, errWithCode } - statusFaves, err := p.state.DB.GetStatusFavesForStatus(ctx, targetStatus.ID) + statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus.ID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("FavedBy: error seeing who faved status: %s", err)) } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 8ad1681d0..2dc0e4dd5 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -600,17 +600,17 @@ func (c *converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r return nil, fmt.Errorf("error converting status author: %w", err) } - repliesCount, err := c.db.CountStatusReplies(ctx, s) + repliesCount, err := c.db.CountStatusReplies(ctx, s.ID) if err != nil { return nil, fmt.Errorf("error counting replies: %w", err) } - reblogsCount, err := c.db.CountStatusReblogs(ctx, s) + reblogsCount, err := c.db.CountStatusBoosts(ctx, s.ID) if err != nil { return nil, fmt.Errorf("error counting reblogs: %w", err) } - favesCount, err := c.db.CountStatusFaves(ctx, s) + favesCount, err := c.db.CountStatusFaves(ctx, s.ID) if err != nil { return nil, fmt.Errorf("error counting faves: %w", err) } diff --git a/internal/typeutils/util.go b/internal/typeutils/util.go index 0100200dc..42387afec 100644 --- a/internal/typeutils/util.go +++ b/internal/typeutils/util.go @@ -40,13 +40,13 @@ func (c *converter) interactionsWithStatusForAccount(ctx context.Context, s *gts si := &statusInteractions{} if requestingAccount != nil { - faved, err := c.db.IsStatusFavedBy(ctx, s, requestingAccount.ID) + faved, err := c.db.IsStatusFavedBy(ctx, s.ID, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has faved status: %s", err) } si.Faved = faved - reblogged, err := c.db.IsStatusRebloggedBy(ctx, s, requestingAccount.ID) + reblogged, err := c.db.IsStatusBoostedBy(ctx, s.ID, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has reblogged status: %s", err) } diff --git a/test/envparsing.sh b/test/envparsing.sh index f75b6fd3f..59c69e1b5 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -21,12 +21,14 @@ EXPECT=$(cat << "EOF" "account-mem-ratio": 18, "account-note-mem-ratio": 0.1, "block-mem-ratio": 3, + "boost-of-ids-mem-ratio": 3, "emoji-category-mem-ratio": 0.1, "emoji-mem-ratio": 3, "follow-ids-mem-ratio": 4, "follow-mem-ratio": 4, "follow-request-ids-mem-ratio": 2, "follow-request-mem-ratio": 2, + "in-reply-to-ids-mem-ratio": 3, "instance-mem-ratio": 1, "list-entry-mem-ratio": 3, "list-mem-ratio": 3, @@ -36,6 +38,7 @@ EXPECT=$(cat << "EOF" "mention-mem-ratio": 5, "notification-mem-ratio": 5, "report-mem-ratio": 1, + "status-fave-ids-mem-ratio": 3, "status-fave-mem-ratio": 5, "status-mem-ratio": 18, "tag-mem-ratio": 3, diff --git a/vendor/codeberg.org/gruf/go-cache/v3/result/cache.go b/vendor/codeberg.org/gruf/go-cache/v3/result/cache.go index f31e6604a..665481d55 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/result/cache.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/result/cache.go @@ -11,28 +11,7 @@ "codeberg.org/gruf/go-errors/v2" ) -type result struct { - // Result primary key - PKey int64 - - // keys accessible under - Keys cacheKeys - - // cached value - Value any - - // cached error - Error error -} - -// getResultValue is a safe way of casting and fetching result value. -func getResultValue[T any](res *result) T { - v, ok := res.Value.(T) - if !ok { - fmt.Fprintf(os.Stderr, "!! BUG: unexpected value type in result: %T\n", res.Value) - } - return v -} +var ErrUnsupportedZero = errors.New("") // Lookup represents a struct object lookup method in the cache. type Lookup struct { @@ -255,13 +234,15 @@ func (c *Cache[T]) Load(lookup string, load func() (T, error), keyParts ...any) evict = c.store(res) } - // Catch and return error - if res.Error != nil { - return zero, res.Error + // Catch and return cached error + if err := res.Error; err != nil { + return zero, err } - // Return a copy of value from cache - return c.copy(getResultValue[T](res)), nil + // Copy value from cached result. + v := c.copy(getResultValue[T](res)) + + return v, nil } // Store will call the given store function, and on success store the value in the cache as a positive result. @@ -332,11 +313,13 @@ func (c *Cache[T]) Has(lookup string, keyParts ...any) bool { } } + // Check for result AND non-error result. + ok := (res != nil && res.Error == nil) + // Done with lock c.cache.Unlock() - // Check for result AND non-error result. - return (res != nil && res.Error == nil) + return ok } // Invalidate will invalidate any result from the cache found under given lookup and key parts. @@ -407,13 +390,18 @@ func (c *Cache[T]) store(res *result) (evict func()) { key.info.pkeys[key.key] = pkeys } + // Acquire new cache entry. + entry := simple.GetEntry() + entry.Key = res.PKey + entry.Value = res + + evictFn := func(_ int64, entry *simple.Entry) { + // on evict during set, store evicted result. + toEvict = append(toEvict, entry.Value.(*result)) + } + // Store main entry under primary key, catch evicted. - c.cache.Cache.SetWithHook(res.PKey, &simple.Entry{ - Key: res.PKey, - Value: res, - }, func(_ int64, item *simple.Entry) { - toEvict = append(toEvict, item.Value.(*result)) - }) + c.cache.Cache.SetWithHook(res.PKey, entry, evictFn) if len(toEvict) == 0 { // none evicted. @@ -421,9 +409,35 @@ func (c *Cache[T]) store(res *result) (evict func()) { } return func() { - for _, res := range toEvict { + for i := range toEvict { + // Rescope result. + res := toEvict[i] + // Call evict hook on each entry. c.cache.Evict(res.PKey, res) } } } + +type result struct { + // Result primary key + PKey int64 + + // keys accessible under + Keys cacheKeys + + // cached value + Value any + + // cached error + Error error +} + +// getResultValue is a safe way of casting and fetching result value. +func getResultValue[T any](res *result) T { + v, ok := res.Value.(T) + if !ok { + fmt.Fprintf(os.Stderr, "!! BUG: unexpected value type in result: %T\n", res.Value) + } + return v +} diff --git a/vendor/codeberg.org/gruf/go-cache/v3/result/key.go b/vendor/codeberg.org/gruf/go-cache/v3/result/key.go index cf86c7c30..5e10e6fa1 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/result/key.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/result/key.go @@ -47,27 +47,32 @@ func (sk structKeys) generate(a any) []cacheKey { buf := getBuf() defer putBuf(buf) +outer: for i := range sk { // Reset buffer - buf.B = buf.B[:0] + buf.Reset() // Append each field value to buffer. for _, field := range sk[i].fields { fv := v.Field(field.index) fi := fv.Interface() - buf.B = field.mangle(buf.B, fi) + + // Mangle this key part into buffer. + ok := field.manglePart(buf, fi) + + if !ok { + // don't generate keys + // for zero value parts. + continue outer + } + + // Append part separator. buf.B = append(buf.B, '.') } // Drop last '.' buf.Truncate(1) - // Don't generate keys for zero values - if allowZero := sk[i].zero == ""; // nocollapse - !allowZero && buf.String() == sk[i].zero { - continue - } - // Append new cached key to slice keys = append(keys, cacheKey{ info: &sk[i], @@ -114,14 +119,6 @@ type structKey struct { // period ('.') separated struct field names. name string - // zero is the possible zero value for this key. - // if set, this will _always_ be non-empty, as - // the mangled cache key will never be empty. - // - // i.e. zero = "" --> allow zero value keys - // zero != "" --> don't allow zero value keys - zero string - // unique determines whether this structKey supports // multiple or just the singular unique result. unique bool @@ -135,47 +132,10 @@ type structKey struct { pkeys map[string][]int64 } -type structField struct { - // index is the reflect index of this struct field. - index int - - // mangle is the mangler function for - // serializing values of this struct field. - mangle mangler.Mangler -} - -// genKey generates a cache key string for given key parts (i.e. serializes them using "go-mangler"). -func (sk *structKey) genKey(parts []any) string { - // Check this expected no. key parts. - if len(parts) != len(sk.fields) { - panic(fmt.Sprintf("incorrect no. key parts provided: want=%d received=%d", len(parts), len(sk.fields))) - } - - // Acquire byte buffer - buf := getBuf() - defer putBuf(buf) - buf.Reset() - - // Encode each key part - for i, part := range parts { - buf.B = sk.fields[i].mangle(buf.B, part) - buf.B = append(buf.B, '.') - } - - // Drop last '.' - buf.Truncate(1) - - // Return string copy - return string(buf.B) -} - // newStructKey will generate a structKey{} information object for user-given lookup // key information, and the receiving generic paramter's type information. Panics on error. func newStructKey(lk Lookup, t reflect.Type) structKey { - var ( - sk structKey - zeros []any - ) + var sk structKey // Set the lookup name sk.name = lk.Name @@ -183,9 +143,6 @@ func newStructKey(lk Lookup, t reflect.Type) structKey { // Split dot-separated lookup to get // the individual struct field names names := strings.Split(lk.Name, ".") - if len(names) == 0 { - panic("no key fields specified") - } // Allocate the mangler and field indices slice. sk.fields = make([]structField, len(names)) @@ -213,16 +170,12 @@ func newStructKey(lk Lookup, t reflect.Type) structKey { sk.fields[i].mangle = mangler.Get(ft.Type) if !lk.AllowZero { - // Append the zero value interface - zeros = append(zeros, v.Interface()) + // Append the mangled zero value interface + zero := sk.fields[i].mangle(nil, v.Interface()) + sk.fields[i].zero = string(zero) } } - if len(zeros) > 0 { - // Generate zero value string - sk.zero = sk.genKey(zeros) - } - // Set unique lookup flag. sk.unique = !lk.Multi @@ -232,6 +185,68 @@ func newStructKey(lk Lookup, t reflect.Type) structKey { return sk } +// genKey generates a cache key string for given key parts (i.e. serializes them using "go-mangler"). +func (sk *structKey) genKey(parts []any) string { + // Check this expected no. key parts. + if len(parts) != len(sk.fields) { + panic(fmt.Sprintf("incorrect no. key parts provided: want=%d received=%d", len(parts), len(sk.fields))) + } + + // Acquire byte buffer + buf := getBuf() + defer putBuf(buf) + buf.Reset() + + for i, part := range parts { + // Mangle this key part into buffer. + // specifically ignoring whether this + // is returning a zero value key part. + _ = sk.fields[i].manglePart(buf, part) + + // Append part separator. + buf.B = append(buf.B, '.') + } + + // Drop last '.' + buf.Truncate(1) + + // Return string copy + return string(buf.B) +} + +type structField struct { + // index is the reflect index of this struct field. + index int + + // zero is the possible zero value for this + // key part. if set, this will _always_ be + // non-empty due to how the mangler works. + // + // i.e. zero = "" --> allow zero value keys + // zero != "" --> don't allow zero value keys + zero string + + // mangle is the mangler function for + // serializing values of this struct field. + mangle mangler.Mangler +} + +// manglePart ... +func (field *structField) manglePart(buf *byteutil.Buffer, part any) bool { + // Start of part bytes. + start := len(buf.B) + + // Mangle this key part into buffer. + buf.B = field.mangle(buf.B, part) + + // End of part bytes. + end := len(buf.B) + + // Return whether this is zero value. + return (field.zero == "" || + string(buf.B[start:end]) != field.zero) +} + // isExported checks whether function name is exported. func isExported(fnName string) bool { r, _ := utf8.DecodeRuneInString(fnName) @@ -246,12 +261,12 @@ func isExported(fnName string) bool { }, } -// getBuf ... +// getBuf acquires a byte buffer from memory pool. func getBuf() *byteutil.Buffer { return bufPool.Get().(*byteutil.Buffer) } -// putBuf ... +// putBuf replaces a byte buffer back in memory pool. func putBuf(buf *byteutil.Buffer) { if buf.Cap() > int(^uint16(0)) { return // drop large bufs diff --git a/vendor/codeberg.org/gruf/go-cache/v3/simple/cache.go b/vendor/codeberg.org/gruf/go-cache/v3/simple/cache.go index 0224871bc..1452a0648 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/simple/cache.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/simple/cache.go @@ -102,7 +102,7 @@ func (c *Cache[K, V]) Add(key K, value V) bool { } // Alloc new entry. - new := getEntry() + new := GetEntry() new.Key = key new.Value = value @@ -111,7 +111,7 @@ func (c *Cache[K, V]) Add(key K, value V) bool { evcK = item.Key.(K) evcV = item.Value.(V) ev = true - putEntry(item) + PutEntry(item) }) // Set hook func ptr. @@ -161,7 +161,7 @@ func (c *Cache[K, V]) Set(key K, value V) { item.Value = value } else { // Alloc new entry. - new := getEntry() + new := GetEntry() new.Key = key new.Value = value @@ -170,7 +170,7 @@ func (c *Cache[K, V]) Set(key K, value V) { evcK = item.Key.(K) evcV = item.Value.(V) ev = true - putEntry(item) + PutEntry(item) }) } @@ -311,7 +311,7 @@ func (c *Cache[K, V]) Invalidate(key K) (ok bool) { _ = c.Cache.Delete(key) // Free entry - putEntry(item) + PutEntry(item) // Set hook func ptrs. invalid = c.Invalid @@ -367,7 +367,7 @@ func (c *Cache[K, V]) InvalidateAll(keys ...K) (ok bool) { invalid(k, v) // Free this entry. - putEntry(items[x]) + PutEntry(items[x]) } } @@ -410,7 +410,7 @@ func (c *Cache[K, V]) Trim(perc float64) { invalid(k, v) // Free this entry. - putEntry(items[x]) + PutEntry(items[x]) } } } @@ -438,7 +438,7 @@ func (c *Cache[K, V]) locked(fn func()) { func (c *Cache[K, V]) truncate(sz int, hook func(K, V)) []*Entry { if hook == nil { // No hook to execute, simply release all truncated entries. - c.Cache.Truncate(sz, func(_ K, item *Entry) { putEntry(item) }) + c.Cache.Truncate(sz, func(_ K, item *Entry) { PutEntry(item) }) return nil } diff --git a/vendor/codeberg.org/gruf/go-cache/v3/simple/pool.go b/vendor/codeberg.org/gruf/go-cache/v3/simple/pool.go index 2fc99ab0f..34ae17546 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/simple/pool.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/simple/pool.go @@ -6,8 +6,8 @@ // objects, regardless of cache type. var entryPool sync.Pool -// getEntry fetches an Entry from pool, or allocates new. -func getEntry() *Entry { +// GetEntry fetches an Entry from pool, or allocates new. +func GetEntry() *Entry { v := entryPool.Get() if v == nil { return new(Entry) @@ -15,8 +15,8 @@ func getEntry() *Entry { return v.(*Entry) } -// putEntry replaces an Entry in the pool. -func putEntry(e *Entry) { +// PutEntry replaces an Entry in the pool. +func PutEntry(e *Entry) { e.Key = nil e.Value = nil entryPool.Put(e) diff --git a/vendor/modules.txt b/vendor/modules.txt index a711fcba7..54425e3d3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -13,7 +13,7 @@ codeberg.org/gruf/go-bytesize # codeberg.org/gruf/go-byteutil v1.1.2 ## explicit; go 1.16 codeberg.org/gruf/go-byteutil -# codeberg.org/gruf/go-cache/v3 v3.5.3 +# codeberg.org/gruf/go-cache/v3 v3.5.5 ## explicit; go 1.19 codeberg.org/gruf/go-cache/v3 codeberg.org/gruf/go-cache/v3/result