[feature/performance] Store account stats in separate table (#2831)

* [feature/performance] Store account stats in separate table, get stats from remote

* test account stats

* add some missing increment / decrement calls

* change stats function signatures

* rejig logging a bit

* use lock when updating stats
This commit is contained in:
tobi 2024-04-16 13:10:13 +02:00 committed by GitHub
parent f79d50b9b2
commit 3cceed11b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 1285 additions and 450 deletions

View file

@ -105,6 +105,15 @@ func (iter *regularCollectionIterator) PrevItem() TypeOrIRI {
return cur return cur
} }
func (iter *regularCollectionIterator) TotalItems() int {
totalItems := iter.GetActivityStreamsTotalItems()
if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() {
return -1
}
return totalItems.Get()
}
func (iter *regularCollectionIterator) initItems() bool { func (iter *regularCollectionIterator) initItems() bool {
if iter.once { if iter.once {
return (iter.items != nil) return (iter.items != nil)
@ -147,6 +156,15 @@ func (iter *orderedCollectionIterator) PrevItem() TypeOrIRI {
return cur return cur
} }
func (iter *orderedCollectionIterator) TotalItems() int {
totalItems := iter.GetActivityStreamsTotalItems()
if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() {
return -1
}
return totalItems.Get()
}
func (iter *orderedCollectionIterator) initItems() bool { func (iter *orderedCollectionIterator) initItems() bool {
if iter.once { if iter.once {
return (iter.items != nil) return (iter.items != nil)
@ -203,6 +221,15 @@ func (iter *regularCollectionPageIterator) PrevItem() TypeOrIRI {
return cur return cur
} }
func (iter *regularCollectionPageIterator) TotalItems() int {
totalItems := iter.GetActivityStreamsTotalItems()
if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() {
return -1
}
return totalItems.Get()
}
func (iter *regularCollectionPageIterator) initItems() bool { func (iter *regularCollectionPageIterator) initItems() bool {
if iter.once { if iter.once {
return (iter.items != nil) return (iter.items != nil)
@ -259,6 +286,15 @@ func (iter *orderedCollectionPageIterator) PrevItem() TypeOrIRI {
return cur return cur
} }
func (iter *orderedCollectionPageIterator) TotalItems() int {
totalItems := iter.GetActivityStreamsTotalItems()
if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() {
return -1
}
return totalItems.Get()
}
func (iter *orderedCollectionPageIterator) initItems() bool { func (iter *orderedCollectionPageIterator) initItems() bool {
if iter.once { if iter.once {
return (iter.items != nil) return (iter.items != nil)

View file

@ -307,6 +307,12 @@ type CollectionIterator interface {
NextItem() TypeOrIRI NextItem() TypeOrIRI
PrevItem() TypeOrIRI PrevItem() TypeOrIRI
// TotalItems returns the total items
// present in the collection, derived
// from the totalItems property, or -1
// if totalItems not present / readable.
TotalItems() int
} }
// CollectionPageIterator represents the minimum interface for interacting with a wrapped // CollectionPageIterator represents the minimum interface for interacting with a wrapped
@ -319,6 +325,12 @@ type CollectionPageIterator interface {
NextItem() TypeOrIRI NextItem() TypeOrIRI
PrevItem() TypeOrIRI PrevItem() TypeOrIRI
// TotalItems returns the total items
// present in the collection, derived
// from the totalItems property, or -1
// if totalItems not present / readable.
TotalItems() int
} }
// Flaggable represents the minimum interface for an activitystreams 'Flag' activity. // Flaggable represents the minimum interface for an activitystreams 'Flag' activity.

View file

@ -48,11 +48,12 @@ func (suite *StatusPinTestSuite) createPin(
expectedHTTPStatus int, expectedHTTPStatus int,
expectedBody string, expectedBody string,
targetStatusID string, targetStatusID string,
requestingAcct *gtsmodel.Account,
) (*apimodel.Status, error) { ) (*apimodel.Status, error) {
// instantiate recorder + test context // instantiate recorder + test context
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil) ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) ctx.Set(oauth.SessionAuthorizedAccount, requestingAcct)
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
@ -101,8 +102,10 @@ func (suite *StatusPinTestSuite) createPin(
func (suite *StatusPinTestSuite) TestPinStatusPublicOK() { func (suite *StatusPinTestSuite) TestPinStatusPublicOK() {
// Pin an unpinned public status that this account owns. // Pin an unpinned public status that this account owns.
targetStatus := suite.testStatuses["local_account_1_status_1"] targetStatus := suite.testStatuses["local_account_1_status_1"]
testAccount := new(gtsmodel.Account)
*testAccount = *suite.testAccounts["local_account_1"]
resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID) resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID, testAccount)
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
@ -113,8 +116,10 @@ func (suite *StatusPinTestSuite) TestPinStatusPublicOK() {
func (suite *StatusPinTestSuite) TestPinStatusFollowersOnlyOK() { func (suite *StatusPinTestSuite) TestPinStatusFollowersOnlyOK() {
// Pin an unpinned followers only status that this account owns. // Pin an unpinned followers only status that this account owns.
targetStatus := suite.testStatuses["local_account_1_status_5"] targetStatus := suite.testStatuses["local_account_1_status_5"]
testAccount := new(gtsmodel.Account)
*testAccount = *suite.testAccounts["local_account_1"]
resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID) resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID, testAccount)
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
@ -127,6 +132,8 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() {
targetStatus := &gtsmodel.Status{} targetStatus := &gtsmodel.Status{}
*targetStatus = *suite.testStatuses["local_account_1_status_5"] *targetStatus = *suite.testStatuses["local_account_1_status_5"]
targetStatus.PinnedAt = time.Now() targetStatus.PinnedAt = time.Now()
testAccount := new(gtsmodel.Account)
*testAccount = *suite.testAccounts["local_account_1"]
if err := suite.db.UpdateStatus(context.Background(), targetStatus, "pinned_at"); err != nil { if err := suite.db.UpdateStatus(context.Background(), targetStatus, "pinned_at"); err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
@ -136,6 +143,7 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() {
http.StatusUnprocessableEntity, http.StatusUnprocessableEntity,
`{"error":"Unprocessable Entity: status already pinned"}`, `{"error":"Unprocessable Entity: status already pinned"}`,
targetStatus.ID, targetStatus.ID,
testAccount,
); err != nil { ); err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
@ -144,11 +152,14 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() {
func (suite *StatusPinTestSuite) TestPinStatusOtherAccountError() { func (suite *StatusPinTestSuite) TestPinStatusOtherAccountError() {
// Try to pin a status that doesn't belong to us. // Try to pin a status that doesn't belong to us.
targetStatus := suite.testStatuses["admin_account_status_1"] targetStatus := suite.testStatuses["admin_account_status_1"]
testAccount := new(gtsmodel.Account)
*testAccount = *suite.testAccounts["local_account_1"]
if _, err := suite.createPin( if _, err := suite.createPin(
http.StatusUnprocessableEntity, http.StatusUnprocessableEntity,
`{"error":"Unprocessable Entity: status 01F8MH75CBF9JFX4ZAD54N0W0R does not belong to account 01F8MH1H7YV1Z7D2C8K2730QBF"}`, `{"error":"Unprocessable Entity: status 01F8MH75CBF9JFX4ZAD54N0W0R does not belong to account 01F8MH1H7YV1Z7D2C8K2730QBF"}`,
targetStatus.ID, targetStatus.ID,
testAccount,
); err != nil { ); err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
@ -156,7 +167,8 @@ func (suite *StatusPinTestSuite) TestPinStatusOtherAccountError() {
func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() { func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() {
// Test pinning too many statuses. // Test pinning too many statuses.
testAccount := suite.testAccounts["local_account_1"] testAccount := new(gtsmodel.Account)
*testAccount = *suite.testAccounts["local_account_1"]
// Spam 10 pinned statuses into the database. // Spam 10 pinned statuses into the database.
ctx := context.Background() ctx := context.Background()
@ -181,12 +193,18 @@ func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() {
} }
} }
// Regenerate account stats to set pinned count.
if err := suite.db.RegenerateAccountStats(ctx, testAccount); err != nil {
suite.FailNow(err.Error())
}
// Try to pin one more status as a treat. // Try to pin one more status as a treat.
targetStatus := suite.testStatuses["local_account_1_status_1"] targetStatus := suite.testStatuses["local_account_1_status_1"]
if _, err := suite.createPin( if _, err := suite.createPin(
http.StatusUnprocessableEntity, http.StatusUnprocessableEntity,
`{"error":"Unprocessable Entity: status pin limit exceeded, you've already pinned 10 status(es) out of 10"}`, `{"error":"Unprocessable Entity: status pin limit exceeded, you've already pinned 10 status(es) out of 10"}`,
targetStatus.ID, targetStatus.ID,
testAccount,
); err != nil { ); err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }

View file

@ -52,9 +52,9 @@ func (c *Caches) Init() {
log.Infof(nil, "init: %p", c) log.Infof(nil, "init: %p", c)
c.initAccount() c.initAccount()
c.initAccountCounts()
c.initAccountNote() c.initAccountNote()
c.initAccountSettings() c.initAccountSettings()
c.initAccountStats()
c.initApplication() c.initApplication()
c.initBlock() c.initBlock()
c.initBlockIDs() c.initBlockIDs()
@ -124,6 +124,7 @@ func (c *Caches) Sweep(threshold float64) {
c.GTS.Account.Trim(threshold) c.GTS.Account.Trim(threshold)
c.GTS.AccountNote.Trim(threshold) c.GTS.AccountNote.Trim(threshold)
c.GTS.AccountSettings.Trim(threshold) c.GTS.AccountSettings.Trim(threshold)
c.GTS.AccountStats.Trim(threshold)
c.GTS.Block.Trim(threshold) c.GTS.Block.Trim(threshold)
c.GTS.BlockIDs.Trim(threshold) c.GTS.BlockIDs.Trim(threshold)
c.GTS.Emoji.Trim(threshold) c.GTS.Emoji.Trim(threshold)

51
internal/cache/db.go vendored
View file

@ -20,7 +20,6 @@
import ( import (
"time" "time"
"codeberg.org/gruf/go-cache/v3/simple"
"codeberg.org/gruf/go-cache/v3/ttl" "codeberg.org/gruf/go-cache/v3/ttl"
"codeberg.org/gruf/go-structr" "codeberg.org/gruf/go-structr"
"github.com/superseriousbusiness/gotosocial/internal/cache/domain" "github.com/superseriousbusiness/gotosocial/internal/cache/domain"
@ -36,16 +35,12 @@ type GTSCaches struct {
// AccountNote provides access to the gtsmodel Note database cache. // AccountNote provides access to the gtsmodel Note database cache.
AccountNote StructCache[*gtsmodel.AccountNote] AccountNote StructCache[*gtsmodel.AccountNote]
// TEMPORARY CACHE TO ALLEVIATE SLOW COUNT QUERIES,
// (in time will be removed when these IDs are cached).
AccountCounts *simple.Cache[string, struct {
Statuses int
Pinned int
}]
// AccountSettings provides access to the gtsmodel AccountSettings database cache. // AccountSettings provides access to the gtsmodel AccountSettings database cache.
AccountSettings StructCache[*gtsmodel.AccountSettings] AccountSettings StructCache[*gtsmodel.AccountSettings]
// AccountStats provides access to the gtsmodel AccountStats database cache.
AccountStats StructCache[*gtsmodel.AccountStats]
// Application provides access to the gtsmodel Application database cache. // Application provides access to the gtsmodel Application database cache.
Application StructCache[*gtsmodel.Application] Application StructCache[*gtsmodel.Application]
@ -200,6 +195,7 @@ func (c *Caches) initAccount() {
a2.AlsoKnownAs = nil a2.AlsoKnownAs = nil
a2.Move = nil a2.Move = nil
a2.Settings = nil a2.Settings = nil
a2.Stats = nil
return a2 return a2
} }
@ -223,22 +219,6 @@ func (c *Caches) initAccount() {
}) })
} }
func (c *Caches) initAccountCounts() {
// Simply use size of accounts cache,
// as this cache will be very small.
cap := c.GTS.Account.Cap()
if cap == 0 {
panic("must be initialized before accounts")
}
log.Infof(nil, "cache size = %d", cap)
c.GTS.AccountCounts = simple.New[string, struct {
Statuses int
Pinned int
}](0, cap)
}
func (c *Caches) initAccountNote() { func (c *Caches) initAccountNote() {
// Calculate maximum cache size. // Calculate maximum cache size.
cap := calculateResultCacheMax( cap := calculateResultCacheMax(
@ -295,6 +275,29 @@ func (c *Caches) initAccountSettings() {
}) })
} }
func (c *Caches) initAccountStats() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofAccountStats(), // model in-mem size.
config.GetCacheAccountStatsMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
c.GTS.AccountStats.Init(structr.CacheConfig[*gtsmodel.AccountStats]{
Indices: []structr.IndexConfig{
{Fields: "AccountID"},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
Copy: func(s1 *gtsmodel.AccountStats) *gtsmodel.AccountStats {
s2 := new(gtsmodel.AccountStats)
*s2 = *s1
return s2
},
})
}
func (c *Caches) initApplication() { func (c *Caches) initApplication() {
// Calculate maximum cache size. // Calculate maximum cache size.
cap := calculateResultCacheMax( cap := calculateResultCacheMax(

View file

@ -27,8 +27,8 @@
// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE. // HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE.
func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {
// Invalidate status counts for this account. // Invalidate stats for this account.
c.GTS.AccountCounts.Invalidate(account.ID) c.GTS.AccountStats.Invalidate("AccountID", account.ID)
// Invalidate account ID cached visibility. // Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID) c.Visibility.Invalidate("ItemID", account.ID)
@ -168,8 +168,8 @@ func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) {
} }
func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) { func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) {
// Invalidate status counts for this account. // Invalidate stats for this account.
c.GTS.AccountCounts.Invalidate(status.AccountID) c.GTS.AccountStats.Invalidate("AccountID", status.AccountID)
// Invalidate status ID cached visibility. // Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID) c.Visibility.Invalidate("ItemID", status.ID)

View file

@ -264,6 +264,17 @@ func sizeofAccountSettings() uintptr {
})) }))
} }
func sizeofAccountStats() uintptr {
return uintptr(size.Of(&gtsmodel.AccountStats{
AccountID: exampleID,
FollowersCount: util.Ptr(100),
FollowingCount: util.Ptr(100),
StatusesCount: util.Ptr(100),
StatusesPinnedCount: util.Ptr(100),
LastStatusAt: exampleTime,
}))
}
func sizeofApplication() uintptr { func sizeofApplication() uintptr {
return uintptr(size.Of(&gtsmodel.Application{ return uintptr(size.Of(&gtsmodel.Application{
ID: exampleID, ID: exampleID,

View file

@ -195,6 +195,7 @@ type CacheConfiguration struct {
AccountMemRatio float64 `name:"account-mem-ratio"` AccountMemRatio float64 `name:"account-mem-ratio"`
AccountNoteMemRatio float64 `name:"account-note-mem-ratio"` AccountNoteMemRatio float64 `name:"account-note-mem-ratio"`
AccountSettingsMemRatio float64 `name:"account-settings-mem-ratio"` AccountSettingsMemRatio float64 `name:"account-settings-mem-ratio"`
AccountStatsMemRatio float64 `name:"account-stats-mem-ratio"`
ApplicationMemRatio float64 `name:"application-mem-ratio"` ApplicationMemRatio float64 `name:"application-mem-ratio"`
BlockMemRatio float64 `name:"block-mem-ratio"` BlockMemRatio float64 `name:"block-mem-ratio"`
BlockIDsMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"`

View file

@ -159,6 +159,7 @@
AccountMemRatio: 5, AccountMemRatio: 5,
AccountNoteMemRatio: 1, AccountNoteMemRatio: 1,
AccountSettingsMemRatio: 0.1, AccountSettingsMemRatio: 0.1,
AccountStatsMemRatio: 2,
ApplicationMemRatio: 0.1, ApplicationMemRatio: 0.1,
BlockMemRatio: 2, BlockMemRatio: 2,
BlockIDsMemRatio: 3, BlockIDsMemRatio: 3,

View file

@ -2825,6 +2825,31 @@ func GetCacheAccountSettingsMemRatio() float64 { return global.GetCacheAccountSe
// SetCacheAccountSettingsMemRatio safely sets the value for global configuration 'Cache.AccountSettingsMemRatio' field // SetCacheAccountSettingsMemRatio safely sets the value for global configuration 'Cache.AccountSettingsMemRatio' field
func SetCacheAccountSettingsMemRatio(v float64) { global.SetCacheAccountSettingsMemRatio(v) } func SetCacheAccountSettingsMemRatio(v float64) { global.SetCacheAccountSettingsMemRatio(v) }
// GetCacheAccountStatsMemRatio safely fetches the Configuration value for state's 'Cache.AccountStatsMemRatio' field
func (st *ConfigState) GetCacheAccountStatsMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.AccountStatsMemRatio
st.mutex.RUnlock()
return
}
// SetCacheAccountStatsMemRatio safely sets the Configuration value for state's 'Cache.AccountStatsMemRatio' field
func (st *ConfigState) SetCacheAccountStatsMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.AccountStatsMemRatio = v
st.reloadToViper()
}
// CacheAccountStatsMemRatioFlag returns the flag name for the 'Cache.AccountStatsMemRatio' field
func CacheAccountStatsMemRatioFlag() string { return "cache-account-stats-mem-ratio" }
// GetCacheAccountStatsMemRatio safely fetches the value for global configuration 'Cache.AccountStatsMemRatio' field
func GetCacheAccountStatsMemRatio() float64 { return global.GetCacheAccountStatsMemRatio() }
// SetCacheAccountStatsMemRatio safely sets the value for global configuration 'Cache.AccountStatsMemRatio' field
func SetCacheAccountStatsMemRatio(v float64) { global.SetCacheAccountStatsMemRatio(v) }
// GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field // GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field
func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) { func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) {
st.mutex.RLock() st.mutex.RLock()

View file

@ -20,7 +20,6 @@
import ( import (
"context" "context"
"net/netip" "net/netip"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
@ -100,12 +99,6 @@ type Account interface {
// GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column. // GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column.
GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(ctx context.Context, accountID string) (int, error)
// CountAccountPinned returns the total number of pinned statuses owned by account with the given id.
CountAccountPinned(ctx context.Context, accountID string) (int, error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this! // be very memory intensive so you probably shouldn't do this!
@ -128,13 +121,6 @@ type Account interface {
// In the case of no statuses, this function will return db.ErrNoEntries. // In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
@ -150,4 +136,24 @@ type Account interface {
// Update local account settings. // Update local account settings.
UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error
// PopulateAccountStats gets (or creates and gets) account stats for
// the given account, and attaches them to the account model.
PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error
// RegenerateAccountStats creates, upserts, and returns stats
// for the given account, and attaches them to the account model.
//
// Unlike GetAccountStats, it will always get the database stats fresh.
// This can be used to "refresh" stats.
//
// Because this involves database calls that can be expensive (on Postgres
// specifically), callers should prefer GetAccountStats in 99% of cases.
RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error
// Update account stats.
UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error
// DeleteAccountStats deletes the accountStats entry for the given accountID.
DeleteAccountStats(ctx context.Context, accountID string) error
} }

View file

@ -630,6 +630,13 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
} }
} }
if account.Stats == nil {
// Get / Create stats for this account.
if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil {
errs.Appendf("error populating account stats: %w", err)
}
}
return errs.Combine() return errs.Combine()
} }
@ -735,31 +742,6 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
}) })
} }
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
createdAt := time.Time{}
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at").
Where("? = ?", bun.Ident("status.account_id"), accountID).
Order("status.id DESC").
Limit(1)
if webOnly {
q = q.
Where("? IS NULL", bun.Ident("status.in_reply_to_uri")).
Where("? IS NULL", bun.Ident("status.boost_of_id")).
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
Where("? = ?", bun.Ident("status.federated"), true)
}
if err := q.Scan(ctx, &createdAt); err != nil {
return time.Time{}, err
}
return createdAt, nil
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header { if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar") return errors.New("one media attachment cannot be both header and avatar")
@ -845,59 +827,6 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
return *faves, nil return *faves, nil
} }
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
counts, err := a.getAccountStatusCounts(ctx, accountID)
return counts.Statuses, err
}
func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
counts, err := a.getAccountStatusCounts(ctx, accountID)
return counts.Pinned, err
}
func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct {
Statuses int
Pinned int
}, error) {
// Check for an already cached copy of account status counts.
counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID)
if ok {
return counts, nil
}
if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
var err error
// Scan database for account statuses.
counts.Statuses, err = tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), accountID).
Count(ctx)
if err != nil {
return err
}
// Scan database for pinned statuses.
counts.Pinned, err = tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), accountID).
Where("? IS NOT NULL", bun.Ident("pinned_at")).
Count(ctx)
if err != nil {
return err
}
return nil
}); err != nil {
return counts, err
}
// Store this account counts result in the cache.
a.state.Caches.GTS.AccountCounts.Set(accountID, counts)
return counts, nil
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
@ -1147,3 +1076,185 @@ func (a *accountDB) UpdateAccountSettings(
return nil return nil
}) })
} }
func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
// Fetch stats from db cache with loader callback.
stats, err := a.state.Caches.GTS.AccountStats.LoadOne(
"AccountID",
func() (*gtsmodel.AccountStats, error) {
// Not cached! Perform database query.
var stats gtsmodel.AccountStats
if err := a.db.
NewSelect().
Model(&stats).
Where("? = ?", bun.Ident("account_stats.account_id"), account.ID).
Scan(ctx); err != nil {
return nil, err
}
return &stats, nil
},
account.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
return err
}
if stats == nil {
// Don't have stats yet, generate them.
return a.RegenerateAccountStats(ctx, account)
}
// We have a stats, attach
// it to the account.
account.Stats = stats
// Check if this is a local
// stats by looking at the
// account they pertain to.
if account.IsRemote() {
// Account is remote. Updating
// stats for remote accounts is
// handled in the dereferencer.
//
// Nothing more to do!
return nil
}
// Stats account is local, check
// if we need to regenerate.
const statsFreshness = 48 * time.Hour
expiry := stats.RegeneratedAt.Add(statsFreshness)
if time.Now().After(expiry) {
// Stats have expired, regenerate them.
return a.RegenerateAccountStats(ctx, account)
}
// Stats are still fresh.
return nil
}
func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
// Initialize a new stats struct.
stats := &gtsmodel.AccountStats{
AccountID: account.ID,
RegeneratedAt: time.Now(),
}
// Count followers outside of transaction since
// it uses a cache + requires its own db calls.
followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil)
if err != nil {
return err
}
stats.FollowersCount = util.Ptr(len(followerIDs))
// Count following outside of transaction since
// it uses a cache + requires its own db calls.
followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil)
if err != nil {
return err
}
stats.FollowingCount = util.Ptr(len(followIDs))
// Count follow requests outside of transaction since
// it uses a cache + requires its own db calls.
followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil)
if err != nil {
return err
}
stats.FollowRequestsCount = util.Ptr(len(followRequestIDs))
// Populate remaining stats struct fields.
// This can be done inside a transaction.
if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
var err error
// Scan database for account statuses.
statusesCount, err := tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), account.ID).
Count(ctx)
if err != nil {
return err
}
stats.StatusesCount = &statusesCount
// Scan database for pinned statuses.
statusesPinnedCount, err := tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), account.ID).
Where("? IS NOT NULL", bun.Ident("pinned_at")).
Count(ctx)
if err != nil {
return err
}
stats.StatusesPinnedCount = &statusesPinnedCount
// Scan database for last status.
lastStatusAt := time.Time{}
err = tx.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at").
Where("? = ?", bun.Ident("status.account_id"), account.ID).
Order("status.id DESC").
Limit(1).
Scan(ctx, &lastStatusAt)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
stats.LastStatusAt = lastStatusAt
return nil
}); err != nil {
return err
}
// Upsert this stats in case a race
// meant someone else inserted it first.
if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error {
if _, err := NewUpsert(a.db).
Model(stats).
Constraint("account_id").
Exec(ctx); err != nil {
return err
}
return nil
}); err != nil {
return err
}
account.Stats = stats
return nil
}
func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error {
return a.state.Caches.GTS.AccountStats.Store(stats, func() error {
if _, err := a.db.
NewUpdate().
Model(stats).
Column(columns...).
Where("? = ?", bun.Ident("account_stats.account_id"), stats.AccountID).
Exec(ctx); err != nil {
return err
}
return nil
})
}
func (a *accountDB) DeleteAccountStats(ctx context.Context, accountID string) error {
defer a.state.Caches.GTS.AccountStats.Invalidate("AccountID", accountID)
if _, err := a.db.
NewDelete().
Table("account_stats").
Where("? = ?", bun.Ident("account_id"), accountID).
Exec(ctx); err != nil {
return err
}
return nil
}

View file

@ -220,6 +220,8 @@ func (suite *AccountTestSuite) TestGetAccountBy() {
a2.Emojis = nil a2.Emojis = nil
a1.Settings = nil a1.Settings = nil
a2.Settings = nil a2.Settings = nil
a1.Stats = nil
a2.Stats = nil
// Clear database-set fields. // Clear database-set fields.
a1.CreatedAt = time.Time{} a1.CreatedAt = time.Time{}
@ -413,18 +415,6 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second)
} }
func (suite *AccountTestSuite) TestGetAccountLastPosted() {
lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)
suite.NoError(err)
suite.EqualValues(1702200240, lastPosted.Unix())
}
func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() {
lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, true)
suite.NoError(err)
suite.EqualValues(1702200240, lastPosted.Unix())
}
func (suite *AccountTestSuite) TestInsertAccountWithDefaults() { func (suite *AccountTestSuite) TestInsertAccountWithDefaults() {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
suite.NoError(err) suite.NoError(err)
@ -466,22 +456,6 @@ func (suite *AccountTestSuite) TestGetAccountPinnedStatusesNothingPinned() {
suite.Empty(statuses) // This account has nothing pinned. suite.Empty(statuses) // This account has nothing pinned.
} }
func (suite *AccountTestSuite) TestCountAccountPinnedSomeResults() {
testAccount := suite.testAccounts["admin_account"]
pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal(pinned, 2) // This account has 2 statuses pinned.
}
func (suite *AccountTestSuite) TestCountAccountPinnedNothingPinned() {
testAccount := suite.testAccounts["local_account_1"]
pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal(pinned, 0) // This account has nothing pinned.
}
func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() { func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() {
testAccount := &gtsmodel.Account{} testAccount := &gtsmodel.Account{}
*testAccount = *suite.testAccounts["local_account_1"] *testAccount = *suite.testAccounts["local_account_1"]
@ -676,6 +650,55 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() {
suite.Len(accounts, 1) suite.Len(accounts, 1)
} }
func (suite *AccountTestSuite) TestAccountStatsAll() {
ctx := context.Background()
for _, account := range suite.testAccounts {
// Get stats for the first time. They
// should all be generated now since
// they're not stored in the test rig.
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
stats := account.Stats
suite.NotNil(stats)
suite.WithinDuration(time.Now(), stats.RegeneratedAt, 5*time.Second)
// Get stats a second time. They shouldn't
// be regenerated since we just did it.
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
stats2 := account.Stats
suite.NotNil(stats2)
suite.Equal(stats2.RegeneratedAt, stats.RegeneratedAt)
// Update the stats to indicate they're out of date.
stats2.RegeneratedAt = time.Now().Add(-72 * time.Hour)
if err := suite.db.UpdateAccountStats(ctx, stats2, "regenerated_at"); err != nil {
suite.FailNow(err.Error())
}
// Get stats for a third time, they
// should get regenerated now, but
// only for local accounts.
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
stats3 := account.Stats
suite.NotNil(stats3)
if account.IsLocal() {
suite.True(stats3.RegeneratedAt.After(stats.RegeneratedAt))
} else {
suite.False(stats3.RegeneratedAt.After(stats.RegeneratedAt))
}
// Now delete the stats.
if err := suite.db.DeleteAccountStats(ctx, account.ID); err != nil {
suite.FailNow(err.Error())
}
}
}
func TestAccountTestSuite(t *testing.T) { func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite)) suite.Run(t, new(AccountTestSuite))
} }

View file

@ -0,0 +1,52 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Create new AccountStats table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.AccountStats{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -112,7 +112,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
} }
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
followIDs, err := r.getAccountFollowIDs(ctx, accountID, page) followIDs, err := r.GetAccountFollowIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,7 +120,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
} }
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) followIDs, err := r.GetAccountLocalFollowIDs(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,7 +128,7 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s
} }
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page) followerIDs, err := r.GetAccountFollowerIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -136,7 +136,7 @@ func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID stri
} }
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) followerIDs, err := r.GetAccountLocalFollowerIDs(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -144,7 +144,7 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID
} }
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page) followReqIDs, err := r.GetAccountFollowRequestIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,7 +152,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
} }
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page) followReqIDs, err := r.GetAccountFollowRequestingIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,49 +160,14 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account
} }
func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page) blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.GetBlocksByIDs(ctx, blockIDs) return r.GetBlocksByIDs(ctx, blockIDs)
} }
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil)
return len(followIDs), err
}
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID)
return len(followIDs), err
}
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil)
return len(followerIDs), err
}
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID)
return len(followerIDs), err
}
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil)
return len(followReqIDs), err
}
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil)
return len(followReqIDs), err
}
func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) {
blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil)
return len(blockIDs), err
}
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
@ -217,7 +182,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
}) })
} }
func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {
var followIDs []string var followIDs []string
@ -232,7 +197,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
}) })
} }
func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
@ -247,7 +212,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
}) })
} }
func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {
var followIDs []string var followIDs []string
@ -262,7 +227,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
}) })
} }
func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
@ -277,7 +242,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
}) })
} }
func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
@ -292,7 +257,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
}) })
} }
func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string var blockIDs []string

View file

@ -773,20 +773,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollows() {
suite.Len(follows, 2) suite.Len(follows, 2)
} }
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowers() { func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"] account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil) follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil)
@ -794,20 +780,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
suite.Len(follows, 2) suite.Len(follows, 2)
} }
func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestUnfollowExisting() { func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"] originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"] targetAccount := suite.testAccounts["admin_account"]

View file

@ -189,14 +189,14 @@ func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
constraintIDPlaceholders = append(constraintIDPlaceholders, "?") constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
constraintIDs = append(constraintIDs, bun.Ident(constraint)) constraintIDs = append(constraintIDs, bun.Ident(constraint))
} }
onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" onSQL := "CONFLICT (" + strings.Join(constraintIDPlaceholders, ", ") + ") DO UPDATE"
setClauses := make([]string, 0, len(columns)) setClauses := make([]string, 0, len(columns))
setIDs := make([]interface{}, 0, 2*len(columns)) setIDs := make([]interface{}, 0, 2*len(columns))
for _, column := range columns { for _, column := range columns {
setClauses = append(setClauses, "? = ?")
// "excluded" is a special table that contains only the row involved in a conflict. // "excluded" is a special table that contains only the row involved in a conflict.
setClauses = append(setClauses, "? = excluded.?") setIDs = append(setIDs, bun.Ident(column), bun.Ident("excluded."+column))
setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
} }
setSQL := strings.Join(setClauses, ", ") setSQL := strings.Join(setClauses, ", ")

View file

@ -140,44 +140,44 @@ type Relationship interface {
// GetAccountFollows returns a slice of follows owned by the given accountID. // GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
// GetAccountFollowIDs is like GetAccountFollows, but returns just IDs.
GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollowIDs is like GetAccountLocalFollows, but returns just IDs.
GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error)
// GetAccountFollowers fetches follows that target given accountID. // GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
// GetAccountFollowerIDs is like GetAccountFollowers, but returns just IDs.
GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollowerIDs is like GetAccountLocalFollowers, but returns just IDs.
GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error)
// GetAccountFollowRequests returns all follow requests targeting the given account. // GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequestIDs is like GetAccountFollowRequests, but returns just IDs.
GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// GetAccountFollowRequesting returns all follow requests originating from the given account. // GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequestingIDs is like GetAccountFollowRequesting, but returns just IDs.
GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
// CountAccountFollows returns the amount of accounts that the given accountID is following. // GetAccountBlockIDs is like GetAccountBlocks, but returns just IDs.
CountAccountFollows(ctx context.Context, accountID string) (int, error) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
// CountAccountFollowers returns the amounts that the given ID is followed by.
CountAccountFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountFollowRequests returns number of follow requests targeting the given account.
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
// CountAccountBlocks ...
CountAccountBlocks(ctx context.Context, accountID string) (int, error)
// GetNote gets a private note from a source account on a target account, if it exists. // GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)

View file

@ -695,7 +695,7 @@ func (d *Dereferencer) enrichAccount(
representation of the target account, derived from representation of the target account, derived from
a combination of webfinger lookups and dereferencing. a combination of webfinger lookups and dereferencing.
Further fetching beyond this point is for peripheral Further fetching beyond this point is for peripheral
things like account avatar, header, emojis. things like account avatar, header, emojis, stats.
*/ */
// Ensure internal db ID is // Ensure internal db ID is
@ -718,6 +718,11 @@ func (d *Dereferencer) enrichAccount(
log.Errorf(ctx, "error fetching remote emojis for account %s: %v", uri, err) log.Errorf(ctx, "error fetching remote emojis for account %s: %v", uri, err)
} }
// Fetch followers/following count for this account.
if err := d.fetchRemoteAccountStats(ctx, latestAcc, requestUser); err != nil {
log.Errorf(ctx, "error fetching remote stats for account %s: %v", uri, err)
}
if account.IsNew() { if account.IsNew() {
// Prefer published/created time from // Prefer published/created time from
// apubAcc, fall back to FetchedAt value. // apubAcc, fall back to FetchedAt value.
@ -1036,6 +1041,113 @@ func (d *Dereferencer) fetchRemoteAccountEmojis(ctx context.Context, targetAccou
return changed, nil return changed, nil
} }
func (d *Dereferencer) fetchRemoteAccountStats(ctx context.Context, account *gtsmodel.Account, requestUser string) error {
// Ensure we have a stats model for this account.
if account.Stats == nil {
if err := d.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// We want to update stats by getting remote
// followers/following/statuses counts for
// this account.
//
// If we fail getting any particular stat,
// it will just fall back to counting local.
// Followers first.
if count, err := d.countCollection(
ctx,
account.FollowersURI,
requestUser,
); err != nil {
// Log this but don't bail.
log.Warnf(ctx,
"couldn't count followers for @%s@%s: %v",
account.Username, account.Domain, err,
)
} else if count > 0 {
// Positive integer is useful!
account.Stats.FollowersCount = &count
}
// Now following.
if count, err := d.countCollection(
ctx,
account.FollowingURI,
requestUser,
); err != nil {
// Log this but don't bail.
log.Warnf(ctx,
"couldn't count following for @%s@%s: %v",
account.Username, account.Domain, err,
)
} else if count > 0 {
// Positive integer is useful!
account.Stats.FollowingCount = &count
}
// Now statuses count.
if count, err := d.countCollection(
ctx,
account.OutboxURI,
requestUser,
); err != nil {
// Log this but don't bail.
log.Warnf(ctx,
"couldn't count statuses for @%s@%s: %v",
account.Username, account.Domain, err,
)
} else if count > 0 {
// Positive integer is useful!
account.Stats.StatusesCount = &count
}
// Update stats now.
if err := d.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"followers_count",
"following_count",
"statuses_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
// countCollection parses the given uriStr,
// dereferences the result as a collection
// type, and returns total items as 0, or
// a positive integer, or -1 if total items
// cannot be counted.
//
// Error will be returned for invalid non-empty
// URIs or dereferencing isses.
func (d *Dereferencer) countCollection(
ctx context.Context,
uriStr string,
requestUser string,
) (int, error) {
if uriStr == "" {
return -1, nil
}
uri, err := url.Parse(uriStr)
if err != nil {
return -1, err
}
collect, err := d.dereferenceCollection(ctx, requestUser, uri)
if err != nil {
return -1, err
}
return collect.TotalItems(), nil
}
// dereferenceAccountFeatured dereferences an account's featuredCollectionURI (if not empty). For each discovered status, this status will // dereferenceAccountFeatured dereferences an account's featuredCollectionURI (if not empty). For each discovered status, this status will
// be dereferenced (if necessary) and marked as pinned (if necessary). Then, old pins will be removed if they're not included in new pins. // be dereferenced (if necessary) and marked as pinned (if necessary). Then, old pins will be removed if they're not included in new pins.
func (d *Dereferencer) dereferenceAccountFeatured(ctx context.Context, requestUser string, account *gtsmodel.Account) error { func (d *Dereferencer) dereferenceAccountFeatured(ctx context.Context, requestUser string, account *gtsmodel.Account) error {

View file

@ -40,7 +40,7 @@ func (d *Dereferencer) dereferenceCollection(ctx context.Context, username strin
rsp, err := transport.Dereference(ctx, pageIRI) rsp, err := transport.Dereference(ctx, pageIRI)
if err != nil { if err != nil {
return nil, gtserror.Newf("error deferencing %s: %w", pageIRI.String(), err) return nil, gtserror.Newf("error dereferencing %s: %w", pageIRI.String(), err)
} }
collect, err := ap.ResolveCollection(ctx, rsp.Body) collect, err := ap.ResolveCollection(ctx, rsp.Body)

View file

@ -89,11 +89,13 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// Process side effects asynchronously.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept, APActivityType: ap.ActivityAccept,
GTSModel: follow, GTSModel: follow,
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
} }
@ -136,11 +138,13 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// Process side effects asynchronously.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept, APActivityType: ap.ActivityAccept,
GTSModel: follow, GTSModel: follow,
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
continue continue

View file

@ -82,10 +82,11 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre
// This is a new boost. Process side effects asynchronously. // This is a new boost. Process side effects asynchronously.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: boost, GTSModel: boost,
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
return nil return nil

View file

@ -131,10 +131,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec
} }
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityBlock, APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: block, GTSModel: block,
ReceivingAccount: receiving, ReceivingAccount: receiving,
RequestingAccount: requestingAccount,
}) })
return nil return nil
@ -307,7 +308,8 @@ func (f *federatingDB) createPollOptionables(
PollID: inReplyTo.PollID, PollID: inReplyTo.PollID,
Poll: inReplyTo.Poll, Poll: inReplyTo.Poll,
}, },
ReceivingAccount: receiver, ReceivingAccount: receiver,
RequestingAccount: requester,
}) })
return nil return nil
@ -376,12 +378,13 @@ func (f *federatingDB) createStatusable(
// Pass the statusable URI (APIri) into the processor // Pass the statusable URI (APIri) into the processor
// worker and do the rest of the processing asynchronously. // worker and do the rest of the processing asynchronously.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
APIri: ap.GetJSONLDId(statusable), APIri: ap.GetJSONLDId(statusable),
APObjectModel: nil, APObjectModel: nil,
GTSModel: nil, GTSModel: nil,
ReceivingAccount: receiver, ReceivingAccount: receiver,
RequestingAccount: requester,
}) })
return nil return nil
} }
@ -389,12 +392,13 @@ func (f *federatingDB) createStatusable(
// Do the rest of the processing asynchronously. The processor // Do the rest of the processing asynchronously. The processor
// will handle inserting/updating + further dereferencing the status. // will handle inserting/updating + further dereferencing the status.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
APIri: nil, APIri: nil,
GTSModel: nil, GTSModel: nil,
APObjectModel: statusable, APObjectModel: statusable,
ReceivingAccount: receiver, ReceivingAccount: receiver,
RequestingAccount: requester,
}) })
return nil return nil
@ -436,10 +440,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re
} }
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: followRequest, GTSModel: followRequest,
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
RequestingAccount: requestingAccount,
}) })
return nil return nil
@ -480,10 +485,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece
} }
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityLike, APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: fave, GTSModel: fave,
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
RequestingAccount: requestingAccount,
}) })
return nil return nil
@ -531,10 +537,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece
} }
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFlag, APObjectType: ap.ActivityFlag,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: report, GTSModel: report,
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
RequestingAccount: requestingAccount,
}) })
return nil return nil

View file

@ -63,10 +63,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAcct.ID == a.ID { if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAcct.ID == a.ID {
l.Debugf("deleting account: %s", a.ID) l.Debugf("deleting account: %s", a.ID)
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: a, GTSModel: a,
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
} }

View file

@ -99,11 +99,12 @@ func (f *federatingDB) updateAccountable(ctx context.Context, receivingAcct *gts
// updating of eg., avatar/header, emojis, etc. The actual db // updating of eg., avatar/header, emojis, etc. The actual db
// inserts/updates will take place there. // inserts/updates will take place there.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate, APActivityType: ap.ActivityUpdate,
GTSModel: requestingAcct, GTSModel: requestingAcct,
APObjectModel: accountable, APObjectModel: accountable,
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
return nil return nil
@ -155,11 +156,12 @@ func (f *federatingDB) updateStatusable(ctx context.Context, receivingAcct *gtsm
// Queue an UPDATE NOTE activity to our fedi API worker, // Queue an UPDATE NOTE activity to our fedi API worker,
// this will handle necessary database insertions, etc. // this will handle necessary database insertions, etc.
f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityUpdate, APActivityType: ap.ActivityUpdate,
GTSModel: status, // original status GTSModel: status, // original status
APObjectModel: (ap.Statusable)(statusable), APObjectModel: (ap.Statusable)(statusable),
ReceivingAccount: receivingAcct, ReceivingAccount: receivingAcct,
RequestingAccount: requestingAcct,
}) })
return nil return nil

View file

@ -80,6 +80,7 @@ type Account struct {
SuspendedAt time.Time `bun:"type:timestamptz,nullzero"` // When was this account suspended (eg., don't allow it to log in/post, don't accept media/posts from this account) SuspendedAt time.Time `bun:"type:timestamptz,nullzero"` // When was this account suspended (eg., don't allow it to log in/post, don't accept media/posts from this account)
SuspensionOrigin string `bun:"type:CHAR(26),nullzero"` // id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID SuspensionOrigin string `bun:"type:CHAR(26),nullzero"` // id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID
Settings *AccountSettings `bun:"-"` // gtsmodel.AccountSettings for this account. Settings *AccountSettings `bun:"-"` // gtsmodel.AccountSettings for this account.
Stats *AccountStats `bun:"-"` // gtsmodel.AccountStats for this account.
} }
// IsLocal returns whether account is a local user account. // IsLocal returns whether account is a local user account.

View file

@ -0,0 +1,33 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// AccountStats models statistics
// for a remote or local account.
type AccountStats struct {
AccountID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // AccountID of this AccountStats.
RegeneratedAt time.Time `bun:"type:timestamptz,nullzero"` // Time this stats model was last regenerated (ie., created from scratch using COUNTs).
FollowersCount *int `bun:",nullzero,notnull"` // Number of accounts following AccountID.
FollowingCount *int `bun:",nullzero,notnull"` // Number of accounts followed by AccountID.
FollowRequestsCount *int `bun:",nullzero,notnull"` // Number of pending follow requests aimed at AccountID.
StatusesCount *int `bun:",nullzero,notnull"` // Number of statuses created by AccountID.
StatusesPinnedCount *int `bun:",nullzero,notnull"` // Number of statuses pinned by AccountID.
LastStatusAt time.Time `bun:"type:timestamptz,nullzero"` // Time of most recent status created by AccountID.
}

View file

@ -485,6 +485,11 @@ func (p *Processor) deleteAccountPeripheral(ctx context.Context, account *gtsmod
return gtserror.Newf("error deleting poll votes by account: %w", err) return gtserror.Newf("error deleting poll votes by account: %w", err)
} }
// Delete account stats model.
if err := p.state.DB.DeleteAccountStats(ctx, account.ID); err != nil {
return gtserror.Newf("error deleting stats for account: %w", err)
}
return nil return nil
} }

View file

@ -113,7 +113,7 @@ func (p *Processor) MoveSelf(
// in quick succession, so get a lock on // in quick succession, so get a lock on
// this account. // this account.
lockKey := originAcct.URI lockKey := originAcct.URI
unlock := p.state.ClientLocks.Lock(lockKey) unlock := p.state.AccountLocks.Lock(lockKey)
defer unlock() defer unlock()
// Ensure we have a valid, up-to-date representation of the target account. // Ensure we have a valid, up-to-date representation of the target account.

View file

@ -69,14 +69,18 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
return nil, never, gtserror.NewErrorNotFound(err) return nil, never, gtserror.NewErrorNotFound(err)
} }
// Ensure account stats populated.
if account.Stats == nil {
if err := p.state.DB.PopulateAccountStats(ctx, account); err != nil {
err = gtserror.Newf("db error getting account stats %s: %w", username, err)
return nil, never, gtserror.NewErrorInternalError(err)
}
}
// LastModified time is needed by callers to check freshness for cacheing. // LastModified time is needed by callers to check freshness for cacheing.
// This might be a zero time.Time if account has never posted a status that's // This might be a zero time.Time if account has never posted a status that's
// eligible to appear in the RSS feed; that's fine. // eligible to appear in the RSS feed; that's fine.
lastPostAt, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true) lastPostAt := account.Stats.LastStatusAt
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = gtserror.Newf("db error getting account %s last posted: %w", username, err)
return nil, never, gtserror.NewErrorInternalError(err)
}
return func() (string, gtserror.WithCode) { return func() (string, gtserror.WithCode) {
// Assemble author namestring once only. // Assemble author namestring once only.

View file

@ -19,7 +19,6 @@
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -32,14 +31,11 @@ type GetRSSTestSuite struct {
func (suite *GetRSSTestSuite) TestGetAccountRSSAdmin() { func (suite *GetRSSTestSuite) TestGetAccountRSSAdmin() {
getFeed, lastModified, err := suite.accountProcessor.GetRSSFeedForUsername(context.Background(), "admin") getFeed, lastModified, err := suite.accountProcessor.GetRSSFeedForUsername(context.Background(), "admin")
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1634733405, lastModified.Unix()) suite.EqualValues(1634726497, lastModified.Unix())
feed, err := getFeed() feed, err := getFeed()
suite.NoError(err) suite.NoError(err)
suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @admin@localhost:8080</title>\n <link>http://localhost:8080/@admin</link>\n <description>Posts from @admin@localhost:8080</description>\n <pubDate>Wed, 20 Oct 2021 10:41:37 +0000</pubDate>\n <lastBuildDate>Wed, 20 Oct 2021 10:41:37 +0000</lastBuildDate>\n <item>\n <title>open to see some puppies</title>\n <link>http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37</link>\n <description>@admin@localhost:8080 made a new post: &#34;🐕🐕🐕🐕🐕&#34;</description>\n <content:encoded><![CDATA[🐕🐕🐕🐕🐕]]></content:encoded>\n <author>@admin@localhost:8080</author>\n <guid>http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37</guid>\n <pubDate>Wed, 20 Oct 2021 12:36:45 +0000</pubDate>\n <source>http://localhost:8080/@admin/feed.rss</source>\n </item>\n <item>\n <title>hello world! #welcome ! first post on the instance :rainbow: !</title>\n <link>http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R</link>\n <description>@admin@localhost:8080 posted 1 attachment: &#34;hello world! #welcome ! first post on the instance :rainbow: !&#34;</description>\n <content:encoded><![CDATA[hello world! #welcome ! first post on the instance <img src=\"http://localhost:8080/fileserver/01AY6P665V14JJR0AFVRT7311Y/emoji/original/01F8MH9H8E4VG3KDYJR9EGPXCQ.png\" title=\":rainbow:\" alt=\":rainbow:\" width=\"25\" height=\"25\"/> !]]></content:encoded>\n <author>@admin@localhost:8080</author>\n <enclosure url=\"http://localhost:8080/fileserver/01F8MH17FWEB39HZJ76B6VXSKF/attachment/original/01F8MH6NEM8D7527KZAECTCR76.jpg\" length=\"62529\" type=\"image/jpeg\"></enclosure>\n <guid>http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R</guid>\n <pubDate>Wed, 20 Oct 2021 11:36:45 +0000</pubDate>\n <source>http://localhost:8080/@admin/feed.rss</source>\n </item>\n </channel>\n</rss>", feed)
fmt.Println(feed)
suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @admin@localhost:8080</title>\n <link>http://localhost:8080/@admin</link>\n <description>Posts from @admin@localhost:8080</description>\n <pubDate>Wed, 20 Oct 2021 12:36:45 +0000</pubDate>\n <lastBuildDate>Wed, 20 Oct 2021 12:36:45 +0000</lastBuildDate>\n <item>\n <title>open to see some puppies</title>\n <link>http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37</link>\n <description>@admin@localhost:8080 made a new post: &#34;🐕🐕🐕🐕🐕&#34;</description>\n <content:encoded><![CDATA[🐕🐕🐕🐕🐕]]></content:encoded>\n <author>@admin@localhost:8080</author>\n <guid>http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37</guid>\n <pubDate>Wed, 20 Oct 2021 12:36:45 +0000</pubDate>\n <source>http://localhost:8080/@admin/feed.rss</source>\n </item>\n <item>\n <title>hello world! #welcome ! first post on the instance :rainbow: !</title>\n <link>http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R</link>\n <description>@admin@localhost:8080 posted 1 attachment: &#34;hello world! #welcome ! first post on the instance :rainbow: !&#34;</description>\n <content:encoded><![CDATA[hello world! #welcome ! first post on the instance <img src=\"http://localhost:8080/fileserver/01AY6P665V14JJR0AFVRT7311Y/emoji/original/01F8MH9H8E4VG3KDYJR9EGPXCQ.png\" title=\":rainbow:\" alt=\":rainbow:\" width=\"25\" height=\"25\"/> !]]></content:encoded>\n <author>@admin@localhost:8080</author>\n <enclosure url=\"http://localhost:8080/fileserver/01F8MH17FWEB39HZJ76B6VXSKF/attachment/original/01F8MH6NEM8D7527KZAECTCR76.jpg\" length=\"62529\" type=\"image/jpeg\"></enclosure>\n <guid>http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R</guid>\n <pubDate>Wed, 20 Oct 2021 11:36:45 +0000</pubDate>\n <source>http://localhost:8080/@admin/feed.rss</source>\n </item>\n </channel>\n</rss>", feed)
} }
func (suite *GetRSSTestSuite) TestGetAccountRSSZork() { func (suite *GetRSSTestSuite) TestGetAccountRSSZork() {
@ -49,9 +45,6 @@ func (suite *GetRSSTestSuite) TestGetAccountRSSZork() {
feed, err := getFeed() feed, err := getFeed()
suite.NoError(err) suite.NoError(err)
fmt.Println(feed)
suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n <description>Posts from @the_mighty_zork@localhost:8080</description>\n <pubDate>Sun, 10 Dec 2023 09:24:00 +0000</pubDate>\n <lastBuildDate>Sun, 10 Dec 2023 09:24:00 +0000</lastBuildDate>\n <image>\n <url>http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg</url>\n <title>Avatar for @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n </image>\n <item>\n <title>HTML in post</title>\n <link>http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40</link>\n <description>@the_mighty_zork@localhost:8080 made a new post: &#34;Here&#39;s a bunch of HTML, read it and weep, weep then!&#xA;&#xA;```html&#xA;&lt;section class=&#34;about-user&#34;&gt;&#xA; &lt;div class=&#34;col-header&#34;&gt;&#xA; &lt;h2&gt;About&lt;/h2&gt;&#xA; &lt;/div&gt; &#xA; &lt;div class=&#34;fields&#34;&gt;&#xA; &lt;h3 class=&#34;sr-only&#34;&gt;Fields&lt;/h3&gt;&#xA; &lt;dl&gt;&#xA;...</description>\n <content:encoded><![CDATA[<p>Here's a bunch of HTML, read it and weep, weep then!</p><pre><code class=\"language-html\">&lt;section class=&#34;about-user&#34;&gt;\n &lt;div class=&#34;col-header&#34;&gt;\n &lt;h2&gt;About&lt;/h2&gt;\n &lt;/div&gt; \n &lt;div class=&#34;fields&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Fields&lt;/h3&gt;\n &lt;dl&gt;\n &lt;div class=&#34;field&#34;&gt;\n &lt;dt&gt;should you follow me?&lt;/dt&gt;\n &lt;dd&gt;maybe!&lt;/dd&gt;\n &lt;/div&gt;\n &lt;div class=&#34;field&#34;&gt;\n &lt;dt&gt;age&lt;/dt&gt;\n &lt;dd&gt;120&lt;/dd&gt;\n &lt;/div&gt;\n &lt;/dl&gt;\n &lt;/div&gt;\n &lt;div class=&#34;bio&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Bio&lt;/h3&gt;\n &lt;p&gt;i post about things that concern me&lt;/p&gt;\n &lt;/div&gt;\n &lt;div class=&#34;sr-only&#34; role=&#34;group&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Stats&lt;/h3&gt;\n &lt;span&gt;Joined in Jun, 2022.&lt;/span&gt;\n &lt;span&gt;8 posts.&lt;/span&gt;\n &lt;span&gt;Followed by 1.&lt;/span&gt;\n &lt;span&gt;Following 1.&lt;/span&gt;\n &lt;/div&gt;\n &lt;div class=&#34;accountstats&#34; aria-hidden=&#34;true&#34;&gt;\n &lt;b&gt;Joined&lt;/b&gt;&lt;time datetime=&#34;2022-06-04T13:12:00.000Z&#34;&gt;Jun, 2022&lt;/time&gt;\n &lt;b&gt;Posts&lt;/b&gt;&lt;span&gt;8&lt;/span&gt;\n &lt;b&gt;Followed by&lt;/b&gt;&lt;span&gt;1&lt;/span&gt;\n &lt;b&gt;Following&lt;/b&gt;&lt;span&gt;1&lt;/span&gt;\n &lt;/div&gt;\n&lt;/section&gt;\n</code></pre><p>There, hope you liked that!</p>]]></content:encoded>\n <author>@the_mighty_zork@localhost:8080</author>\n <guid>http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40</guid>\n <pubDate>Sun, 10 Dec 2023 09:24:00 +0000</pubDate>\n <source>http://localhost:8080/@the_mighty_zork/feed.rss</source>\n </item>\n <item>\n <title>introduction post</title>\n <link>http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY</link>\n <description>@the_mighty_zork@localhost:8080 made a new post: &#34;hello everyone!&#34;</description>\n <content:encoded><![CDATA[hello everyone!]]></content:encoded>\n <author>@the_mighty_zork@localhost:8080</author>\n <guid>http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY</guid>\n <pubDate>Wed, 20 Oct 2021 10:40:37 +0000</pubDate>\n <source>http://localhost:8080/@the_mighty_zork/feed.rss</source>\n </item>\n </channel>\n</rss>", feed) suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n <description>Posts from @the_mighty_zork@localhost:8080</description>\n <pubDate>Sun, 10 Dec 2023 09:24:00 +0000</pubDate>\n <lastBuildDate>Sun, 10 Dec 2023 09:24:00 +0000</lastBuildDate>\n <image>\n <url>http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg</url>\n <title>Avatar for @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n </image>\n <item>\n <title>HTML in post</title>\n <link>http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40</link>\n <description>@the_mighty_zork@localhost:8080 made a new post: &#34;Here&#39;s a bunch of HTML, read it and weep, weep then!&#xA;&#xA;```html&#xA;&lt;section class=&#34;about-user&#34;&gt;&#xA; &lt;div class=&#34;col-header&#34;&gt;&#xA; &lt;h2&gt;About&lt;/h2&gt;&#xA; &lt;/div&gt; &#xA; &lt;div class=&#34;fields&#34;&gt;&#xA; &lt;h3 class=&#34;sr-only&#34;&gt;Fields&lt;/h3&gt;&#xA; &lt;dl&gt;&#xA;...</description>\n <content:encoded><![CDATA[<p>Here's a bunch of HTML, read it and weep, weep then!</p><pre><code class=\"language-html\">&lt;section class=&#34;about-user&#34;&gt;\n &lt;div class=&#34;col-header&#34;&gt;\n &lt;h2&gt;About&lt;/h2&gt;\n &lt;/div&gt; \n &lt;div class=&#34;fields&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Fields&lt;/h3&gt;\n &lt;dl&gt;\n &lt;div class=&#34;field&#34;&gt;\n &lt;dt&gt;should you follow me?&lt;/dt&gt;\n &lt;dd&gt;maybe!&lt;/dd&gt;\n &lt;/div&gt;\n &lt;div class=&#34;field&#34;&gt;\n &lt;dt&gt;age&lt;/dt&gt;\n &lt;dd&gt;120&lt;/dd&gt;\n &lt;/div&gt;\n &lt;/dl&gt;\n &lt;/div&gt;\n &lt;div class=&#34;bio&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Bio&lt;/h3&gt;\n &lt;p&gt;i post about things that concern me&lt;/p&gt;\n &lt;/div&gt;\n &lt;div class=&#34;sr-only&#34; role=&#34;group&#34;&gt;\n &lt;h3 class=&#34;sr-only&#34;&gt;Stats&lt;/h3&gt;\n &lt;span&gt;Joined in Jun, 2022.&lt;/span&gt;\n &lt;span&gt;8 posts.&lt;/span&gt;\n &lt;span&gt;Followed by 1.&lt;/span&gt;\n &lt;span&gt;Following 1.&lt;/span&gt;\n &lt;/div&gt;\n &lt;div class=&#34;accountstats&#34; aria-hidden=&#34;true&#34;&gt;\n &lt;b&gt;Joined&lt;/b&gt;&lt;time datetime=&#34;2022-06-04T13:12:00.000Z&#34;&gt;Jun, 2022&lt;/time&gt;\n &lt;b&gt;Posts&lt;/b&gt;&lt;span&gt;8&lt;/span&gt;\n &lt;b&gt;Followed by&lt;/b&gt;&lt;span&gt;1&lt;/span&gt;\n &lt;b&gt;Following&lt;/b&gt;&lt;span&gt;1&lt;/span&gt;\n &lt;/div&gt;\n&lt;/section&gt;\n</code></pre><p>There, hope you liked that!</p>]]></content:encoded>\n <author>@the_mighty_zork@localhost:8080</author>\n <guid>http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40</guid>\n <pubDate>Sun, 10 Dec 2023 09:24:00 +0000</pubDate>\n <source>http://localhost:8080/@the_mighty_zork/feed.rss</source>\n </item>\n <item>\n <title>introduction post</title>\n <link>http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY</link>\n <description>@the_mighty_zork@localhost:8080 made a new post: &#34;hello everyone!&#34;</description>\n <content:encoded><![CDATA[hello everyone!]]></content:encoded>\n <author>@the_mighty_zork@localhost:8080</author>\n <guid>http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY</guid>\n <pubDate>Wed, 20 Oct 2021 10:40:37 +0000</pubDate>\n <source>http://localhost:8080/@the_mighty_zork/feed.rss</source>\n </item>\n </channel>\n</rss>", feed)
} }
@ -77,9 +70,6 @@ func (suite *GetRSSTestSuite) TestGetAccountRSSZorkNoPosts() {
feed, err := getFeed() feed, err := getFeed()
suite.NoError(err) suite.NoError(err)
fmt.Println(feed)
suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n <description>Posts from @the_mighty_zork@localhost:8080</description>\n <pubDate>Fri, 20 May 2022 11:09:18 +0000</pubDate>\n <lastBuildDate>Fri, 20 May 2022 11:09:18 +0000</lastBuildDate>\n <image>\n <url>http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg</url>\n <title>Avatar for @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n </image>\n </channel>\n</rss>", feed) suite.Equal("<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\" xmlns:content=\"http://purl.org/rss/1.0/modules/content/\">\n <channel>\n <title>Posts from @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n <description>Posts from @the_mighty_zork@localhost:8080</description>\n <pubDate>Fri, 20 May 2022 11:09:18 +0000</pubDate>\n <lastBuildDate>Fri, 20 May 2022 11:09:18 +0000</lastBuildDate>\n <image>\n <url>http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg</url>\n <title>Avatar for @the_mighty_zork@localhost:8080</title>\n <link>http://localhost:8080/@the_mighty_zork</link>\n </image>\n </channel>\n</rss>", feed)
} }

View file

@ -49,7 +49,7 @@ func (p *Processor) AccountApprove(
// Get a lock on the account URI, // Get a lock on the account URI,
// to ensure it's not also being // to ensure it's not also being
// rejected at the same time! // rejected at the same time!
unlock := p.state.ClientLocks.Lock(user.Account.URI) unlock := p.state.AccountLocks.Lock(user.Account.URI)
defer unlock() defer unlock()
if !*user.Approved { if !*user.Approved {

View file

@ -52,7 +52,7 @@ func (p *Processor) AccountReject(
// Get a lock on the account URI, // Get a lock on the account URI,
// since we're going to be deleting // since we're going to be deleting
// it and its associated user. // it and its associated user.
unlock := p.state.ClientLocks.Lock(user.Account.URI) unlock := p.state.AccountLocks.Lock(user.Account.URI)
defer unlock() defer unlock()
// Can't reject an account with a // Can't reject an account with a

View file

@ -126,11 +126,12 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Calculate total number of followers available for account. // Ensure we have stats for this account.
total, err := p.state.DB.CountAccountFollowers(ctx, receiver.ID) if receiver.Stats == nil {
if err != nil { if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil {
err := gtserror.Newf("error counting followers: %w", err) err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
}
} }
var obj vocab.Type var obj vocab.Type
@ -138,7 +139,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page
// Start the AS collection params. // Start the AS collection params.
var params ap.CollectionParams var params ap.CollectionParams
params.ID = collectionID params.ID = collectionID
params.Total = total params.Total = *receiver.Stats.FollowersCount
switch { switch {
@ -235,11 +236,12 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Calculate total number of following available for account. // Ensure we have stats for this account.
total, err := p.state.DB.CountAccountFollows(ctx, receiver.ID) if receiver.Stats == nil {
if err != nil { if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil {
err := gtserror.Newf("error counting follows: %w", err) err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
}
} }
var obj vocab.Type var obj vocab.Type
@ -247,7 +249,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page
// Start AS collection params. // Start AS collection params.
var params ap.CollectionParams var params ap.CollectionParams
params.ID = collectionID params.ID = collectionID
params.Total = total params.Total = *receiver.Stats.FollowingCount
switch { switch {
case receiver.IsInstance() || case receiver.IsInstance() ||

View file

@ -82,18 +82,26 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
return nil, errWithCode return nil, errWithCode
} }
// Get a lock on this account.
unlock := p.state.AccountLocks.Lock(requestingAccount.URI)
defer unlock()
if !targetStatus.PinnedAt.IsZero() { if !targetStatus.PinnedAt.IsZero() {
err := errors.New("status already pinned") err := errors.New("status already pinned")
return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error())
} }
pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID) // Ensure account stats populated.
if err != nil { if requestingAccount.Stats == nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err)) if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil {
err = gtserror.Newf("db error getting account stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
} }
pinnedCount := *requestingAccount.Stats.StatusesPinnedCount
if pinnedCount >= allowedPinnedCount { if pinnedCount >= allowedPinnedCount {
err = fmt.Errorf("status pin limit exceeded, you've already pinned %d status(es) out of %d", pinnedCount, allowedPinnedCount) err := fmt.Errorf("status pin limit exceeded, you've already pinned %d status(es) out of %d", pinnedCount, allowedPinnedCount)
return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error())
} }
@ -103,6 +111,17 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Update account stats.
*requestingAccount.Stats.StatusesPinnedCount++
if err := p.state.DB.UpdateAccountStats(
ctx,
requestingAccount.Stats,
"statuses_pinned_count",
); err != nil {
err = gtserror.Newf("db error updating stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil { if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil {
err = gtserror.Newf("error invalidating status from timelines: %w", err) err = gtserror.Newf("error invalidating status from timelines: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -128,16 +147,45 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A
return nil, errWithCode return nil, errWithCode
} }
// Get a lock on this account.
unlock := p.state.AccountLocks.Lock(requestingAccount.URI)
defer unlock()
if targetStatus.PinnedAt.IsZero() { if targetStatus.PinnedAt.IsZero() {
// Status already not pinned.
return p.c.GetAPIStatus(ctx, requestingAccount, targetStatus) return p.c.GetAPIStatus(ctx, requestingAccount, targetStatus)
} }
// Ensure account stats populated.
if requestingAccount.Stats == nil {
if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil {
err = gtserror.Newf("db error getting account stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
}
targetStatus.PinnedAt = time.Time{} targetStatus.PinnedAt = time.Time{}
if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
err = gtserror.Newf("db error unpinning status: %w", err) err = gtserror.Newf("db error unpinning status: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Update account stats.
//
// Clamp to 0 to avoid funny business.
*requestingAccount.Stats.StatusesPinnedCount--
if *requestingAccount.Stats.StatusesPinnedCount < 0 {
*requestingAccount.Stats.StatusesPinnedCount = 0
}
if err := p.state.DB.UpdateAccountStats(
ctx,
requestingAccount.Stats,
"statuses_pinned_count",
); err != nil {
err = gtserror.Newf("db error updating stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil { if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil {
err = gtserror.Newf("error invalidating status from timelines: %w", err) err = gtserror.Newf("error invalidating status from timelines: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)

View file

@ -247,6 +247,11 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP
return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel)
} }
// Update stats for the actor account.
if err := p.utilF.incrementStatusesCount(ctx, cMsg.OriginAccount, status); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil {
log.Errorf(ctx, "error timelining and notifying status: %v", err) log.Errorf(ctx, "error timelining and notifying status: %v", err)
} }
@ -311,6 +316,11 @@ func (p *clientAPI) CreateFollowReq(ctx context.Context, cMsg messages.FromClien
return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel)
} }
// Update stats for the target account.
if err := p.utilF.incrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil {
log.Errorf(ctx, "error notifying follow request: %v", err) log.Errorf(ctx, "error notifying follow request: %v", err)
} }
@ -360,6 +370,11 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient
return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel)
} }
// Update stats for the actor account.
if err := p.utilF.incrementStatusesCount(ctx, cMsg.OriginAccount, boost); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Timeline and notify the boost wrapper status. // Timeline and notify the boost wrapper status.
if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil {
log.Errorf(ctx, "error timelining and notifying status: %v", err) log.Errorf(ctx, "error timelining and notifying status: %v", err)
@ -485,6 +500,20 @@ func (p *clientAPI) AcceptFollow(ctx context.Context, cMsg messages.FromClientAP
return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel)
} }
// Update stats for the target account.
if err := p.utilF.decrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.utilF.incrementFollowersCount(ctx, cMsg.TargetAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Update stats for the origin account.
if err := p.utilF.incrementFollowingCount(ctx, cMsg.OriginAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.surface.notifyFollow(ctx, follow); err != nil { if err := p.surface.notifyFollow(ctx, follow); err != nil {
log.Errorf(ctx, "error notifying follow: %v", err) log.Errorf(ctx, "error notifying follow: %v", err)
} }
@ -502,6 +531,11 @@ func (p *clientAPI) RejectFollowRequest(ctx context.Context, cMsg messages.FromC
return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel)
} }
// Update stats for the target account.
if err := p.utilF.decrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.federate.RejectFollow( if err := p.federate.RejectFollow(
ctx, ctx,
p.converter.FollowRequestToFollow(ctx, followReq), p.converter.FollowRequestToFollow(ctx, followReq),
@ -518,6 +552,16 @@ func (p *clientAPI) UndoFollow(ctx context.Context, cMsg messages.FromClientAPI)
return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel) return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel)
} }
// Update stats for the origin account.
if err := p.utilF.decrementFollowingCount(ctx, cMsg.OriginAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Update stats for the target account.
if err := p.utilF.decrementFollowersCount(ctx, cMsg.TargetAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.federate.UndoFollow(ctx, follow); err != nil { if err := p.federate.UndoFollow(ctx, follow); err != nil {
log.Errorf(ctx, "error federating follow undo: %v", err) log.Errorf(ctx, "error federating follow undo: %v", err)
} }
@ -565,6 +609,11 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP
return gtserror.Newf("db error deleting status: %w", err) return gtserror.Newf("db error deleting status: %w", err)
} }
// Update stats for the origin account.
if err := p.utilF.decrementStatusesCount(ctx, cMsg.OriginAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil { if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil {
log.Errorf(ctx, "error removing timelined status: %v", err) log.Errorf(ctx, "error removing timelined status: %v", err)
} }
@ -603,6 +652,11 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP
log.Errorf(ctx, "error wiping status: %v", err) log.Errorf(ctx, "error wiping status: %v", err)
} }
// Update stats for the origin account.
if err := p.utilF.decrementStatusesCount(ctx, cMsg.OriginAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if status.InReplyToID != "" { if status.InReplyToID != "" {
// Interaction counts changed on the replied status; // Interaction counts changed on the replied status;
// uncache the prepared version from all timelines. // uncache the prepared version from all timelines.

View file

@ -182,11 +182,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusWithNotification() {
nil, nil,
nil, nil,
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Update the follow from receiving account -> posting account so // Update the follow from receiving account -> posting account so
@ -212,6 +207,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusWithNotification() {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,
@ -285,11 +286,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReply() {
suite.testStatuses["local_account_2_status_1"], suite.testStatuses["local_account_2_status_1"],
nil, nil,
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Process the new status. // Process the new status.
@ -305,6 +301,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReply() {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,
@ -451,11 +453,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis
suite.testStatuses["local_account_2_status_1"], suite.testStatuses["local_account_2_status_1"],
nil, nil,
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Modify replies policy of test list to show replies // Modify replies policy of test list to show replies
@ -480,6 +477,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,
@ -518,11 +521,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis
suite.testStatuses["local_account_2_status_1"], suite.testStatuses["local_account_2_status_1"],
nil, nil,
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Modify replies policy of test list to show replies // Modify replies policy of test list to show replies
@ -552,6 +550,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,
@ -590,11 +594,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReplyListRepliesPoli
suite.testStatuses["local_account_2_status_1"], suite.testStatuses["local_account_2_status_1"],
nil, nil,
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Modify replies policy of test list. // Modify replies policy of test list.
@ -619,6 +618,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReplyListRepliesPoli
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,
@ -654,11 +659,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusBoost() {
nil, nil,
suite.testStatuses["local_account_2_status_1"], suite.testStatuses["local_account_2_status_1"],
) )
statusJSON = suite.statusJSON(
ctx,
status,
receivingAccount,
)
) )
// Process the new status. // Process the new status.
@ -674,6 +674,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusBoost() {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
statusJSON := suite.statusJSON(
ctx,
status,
receivingAccount,
)
// Check message in home stream. // Check message in home stream.
suite.checkStreamed( suite.checkStreamed(
homeStream, homeStream,

View file

@ -122,7 +122,7 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe
// UPDATE SOMETHING // UPDATE SOMETHING
case ap.ActivityUpdate: case ap.ActivityUpdate:
switch fMsg.APObjectType { //nolint:gocritic switch fMsg.APObjectType {
// UPDATE NOTE/STATUS // UPDATE NOTE/STATUS
case ap.ObjectNote: case ap.ObjectNote:
@ -133,6 +133,15 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe
return p.fediAPI.UpdateAccount(ctx, fMsg) return p.fediAPI.UpdateAccount(ctx, fMsg)
} }
// ACCEPT SOMETHING
case ap.ActivityAccept:
switch fMsg.APObjectType { //nolint:gocritic
// ACCEPT FOLLOW
case ap.ActivityFollow:
return p.fediAPI.AcceptFollow(ctx, fMsg)
}
// DELETE SOMETHING // DELETE SOMETHING
case ap.ActivityDelete: case ap.ActivityDelete:
switch fMsg.APObjectType { switch fMsg.APObjectType {
@ -220,6 +229,11 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e
return nil return nil
} }
// Update stats for the remote account.
if err := p.utilF.incrementStatusesCount(ctx, fMsg.RequestingAccount, status); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if status.InReplyToID != "" { if status.InReplyToID != "" {
// Interaction counts changed on the replied status; uncache the // Interaction counts changed on the replied status; uncache the
// prepared version from all timelines. The status dereferencer // prepared version from all timelines. The status dereferencer
@ -290,14 +304,20 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI
} }
if *followRequest.TargetAccount.Locked { if *followRequest.TargetAccount.Locked {
// Account on our instance is locked: just notify the follow request. // Local account is locked: just notify the follow request.
if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil {
log.Errorf(ctx, "error notifying follow request: %v", err) log.Errorf(ctx, "error notifying follow request: %v", err)
} }
// And update stats for the local account.
if err := p.utilF.incrementFollowRequestsCount(ctx, fMsg.ReceivingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
return nil return nil
} }
// Account on our instance is not locked: // Local account is not locked:
// Automatically accept the follow request // Automatically accept the follow request
// and notify about the new follower. // and notify about the new follower.
follow, err := p.state.DB.AcceptFollowRequest( follow, err := p.state.DB.AcceptFollowRequest(
@ -309,6 +329,16 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI
return gtserror.Newf("error accepting follow request: %w", err) return gtserror.Newf("error accepting follow request: %w", err)
} }
// Update stats for the local account.
if err := p.utilF.incrementFollowersCount(ctx, fMsg.ReceivingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Update stats for the remote account.
if err := p.utilF.incrementFollowingCount(ctx, fMsg.RequestingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.federate.AcceptFollow(ctx, follow); err != nil { if err := p.federate.AcceptFollow(ctx, follow); err != nil {
log.Errorf(ctx, "error federating follow request accept: %v", err) log.Errorf(ctx, "error federating follow request accept: %v", err)
} }
@ -369,6 +399,11 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI)
return gtserror.Newf("error dereferencing announce: %w", err) return gtserror.Newf("error dereferencing announce: %w", err)
} }
// Update stats for the remote account.
if err := p.utilF.incrementStatusesCount(ctx, fMsg.RequestingAccount, boost); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Timeline and notify the announce. // Timeline and notify the announce.
if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil {
log.Errorf(ctx, "error timelining and notifying status: %v", err) log.Errorf(ctx, "error timelining and notifying status: %v", err)
@ -509,6 +544,24 @@ func (p *fediAPI) UpdateAccount(ctx context.Context, fMsg messages.FromFediAPI)
return nil return nil
} }
func (p *fediAPI) AcceptFollow(ctx context.Context, fMsg messages.FromFediAPI) error {
// Update stats for the remote account.
if err := p.utilF.decrementFollowRequestsCount(ctx, fMsg.RequestingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if err := p.utilF.incrementFollowersCount(ctx, fMsg.RequestingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
// Update stats for the local account.
if err := p.utilF.incrementFollowingCount(ctx, fMsg.ReceivingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
return nil
}
func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) error { func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) error {
// Cast the existing Status model attached to msg. // Cast the existing Status model attached to msg.
existing, ok := fMsg.GTSModel.(*gtsmodel.Status) existing, ok := fMsg.GTSModel.(*gtsmodel.Status)
@ -567,6 +620,11 @@ func (p *fediAPI) DeleteStatus(ctx context.Context, fMsg messages.FromFediAPI) e
log.Errorf(ctx, "error wiping status: %v", err) log.Errorf(ctx, "error wiping status: %v", err)
} }
// Update stats for the remote account.
if err := p.utilF.decrementStatusesCount(ctx, fMsg.RequestingAccount); err != nil {
log.Errorf(ctx, "error updating account stats: %v", err)
}
if status.InReplyToID != "" { if status.InReplyToID != "" {
// Interaction counts changed on the replied status; // Interaction counts changed on the replied status;
// uncache the prepared version from all timelines. // uncache the prepared version from all timelines.

View file

@ -55,10 +55,11 @@ func (suite *FromFediAPITestSuite) TestProcessFederationAnnounce() {
announceStatus.Visibility = boostedStatus.Visibility announceStatus.Visibility = boostedStatus.Visibility
err := suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ err := suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: announceStatus, GTSModel: announceStatus,
ReceivingAccount: suite.testAccounts["local_account_1"], ReceivingAccount: suite.testAccounts["local_account_1"],
RequestingAccount: boostingAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -115,10 +116,11 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
// Send the replied status off to the fedi worker to be further processed. // Send the replied status off to the fedi worker to be further processed.
err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
APObjectModel: replyingStatusable, APObjectModel: replyingStatusable,
ReceivingAccount: suite.testAccounts["local_account_1"], ReceivingAccount: repliedAccount,
RequestingAccount: replyingAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -178,10 +180,11 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
suite.NoError(err) suite.NoError(err)
err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{
APObjectType: ap.ActivityLike, APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: fave, GTSModel: fave,
ReceivingAccount: favedAccount, ReceivingAccount: favedAccount,
RequestingAccount: favingAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -247,10 +250,11 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
suite.NoError(err) suite.NoError(err)
err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{
APObjectType: ap.ActivityLike, APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: fave, GTSModel: fave,
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
RequestingAccount: favingAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -318,10 +322,11 @@ func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
// now they are mufos! // now they are mufos!
err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: deletedAccount, GTSModel: deletedAccount,
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
RequestingAccount: deletedAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -398,10 +403,11 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
suite.NoError(err) suite.NoError(err)
err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: satanFollowRequestTurtle, GTSModel: satanFollowRequestTurtle,
ReceivingAccount: targetAccount, ReceivingAccount: targetAccount,
RequestingAccount: originAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -451,10 +457,11 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
suite.NoError(err) suite.NoError(err)
err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: satanFollowRequestTurtle, GTSModel: satanFollowRequestTurtle,
ReceivingAccount: targetAccount, ReceivingAccount: targetAccount,
RequestingAccount: originAccount,
}) })
suite.NoError(err) suite.NoError(err)
@ -526,11 +533,12 @@ func (suite *FromFediAPITestSuite) TestCreateStatusFromIRI() {
statusCreator := suite.testAccounts["remote_account_2"] statusCreator := suite.testAccounts["remote_account_2"]
err := suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ err := suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: nil, // gtsmodel is nil because this is a forwarded status -- we want to dereference it using the iri GTSModel: nil, // gtsmodel is nil because this is a forwarded status -- we want to dereference it using the iri
ReceivingAccount: receivingAccount, ReceivingAccount: receivingAccount,
APIri: testrig.URLMustParse("http://example.org/users/Some_User/statuses/afaba698-5740-4e32-a702-af61aa543bc1"), RequestingAccount: statusCreator,
APIri: testrig.URLMustParse("http://example.org/users/Some_User/statuses/afaba698-5740-4e32-a702-af61aa543bc1"),
}) })
suite.NoError(err) suite.NoError(err)

View file

@ -238,3 +238,258 @@ func (u *utilF) redirectFollowers(
return true return true
} }
func (u *utilF) incrementStatusesCount(
ctx context.Context,
account *gtsmodel.Account,
status *gtsmodel.Status,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by incrementing status
// count by one and setting last posted.
*account.Stats.StatusesCount++
account.Stats.LastStatusAt = status.CreatedAt
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"statuses_count",
"last_status_at",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) decrementStatusesCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by decrementing
// status count by one.
//
// Clamp to 0 to avoid funny business.
*account.Stats.StatusesCount--
if *account.Stats.StatusesCount < 0 {
*account.Stats.StatusesCount = 0
}
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"statuses_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) incrementFollowersCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by incrementing followers
// count by one and setting last posted.
*account.Stats.FollowersCount++
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"followers_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) decrementFollowersCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by decrementing
// followers count by one.
//
// Clamp to 0 to avoid funny business.
*account.Stats.FollowersCount--
if *account.Stats.FollowersCount < 0 {
*account.Stats.FollowersCount = 0
}
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"followers_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) incrementFollowingCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by incrementing
// followers count by one.
*account.Stats.FollowingCount++
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"following_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) decrementFollowingCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by decrementing
// following count by one.
//
// Clamp to 0 to avoid funny business.
*account.Stats.FollowingCount--
if *account.Stats.FollowingCount < 0 {
*account.Stats.FollowingCount = 0
}
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"following_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) incrementFollowRequestsCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by incrementing
// follow requests count by one.
*account.Stats.FollowRequestsCount++
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"follow_requests_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}
func (u *utilF) decrementFollowRequestsCount(
ctx context.Context,
account *gtsmodel.Account,
) error {
// Lock on this account since we're changing stats.
unlock := u.state.AccountLocks.Lock(account.URI)
defer unlock()
// Populate stats.
if account.Stats == nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
// Update stats by decrementing
// follow requests count by one.
//
// Clamp to 0 to avoid funny business.
*account.Stats.FollowRequestsCount--
if *account.Stats.FollowRequestsCount < 0 {
*account.Stats.FollowRequestsCount = 0
}
if err := u.state.DB.UpdateAccountStats(
ctx,
account.Stats,
"follow_requests_count",
); err != nil {
return gtserror.Newf("db error updating account stats: %w", err)
}
return nil
}

View file

@ -50,11 +50,12 @@ type State struct {
// functions, and by the go-fed/activity library. // functions, and by the go-fed/activity library.
FedLocks mutexes.MutexMap FedLocks mutexes.MutexMap
// ClientLocks provides access to this state's // AccountLocks provides access to this state's
// mutex map of per URI client locks. // mutex map of per URI locks, intended for use
// // when updating accounts, migrating, approving
// Used during account migration actions. // or rejecting an account, changing stats,
ClientLocks mutexes.MutexMap // pinned statuses, etc.
AccountLocks mutexes.MutexMap
// Storage provides access to the storage driver. // Storage provides access to the storage driver.
Storage *storage.Driver Storage *storage.Driver

View file

@ -63,18 +63,22 @@ func toMastodonVersion(in string) string {
// if something goes wrong. The returned application should be ready to serialize on an API level, and may have sensitive fields // if something goes wrong. The returned application should be ready to serialize on an API level, and may have sensitive fields
// (such as client id and client secret), so serve it only to an authorized user who should have permission to see it. // (such as client id and client secret), so serve it only to an authorized user who should have permission to see it.
func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) { func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) {
// we can build this sensitive account easily by first getting the public account.... // We can build this sensitive account model
// by first getting the public account, and
// then adding the Source object to it.
apiAccount, err := c.AccountToAPIAccountPublic(ctx, a) apiAccount, err := c.AccountToAPIAccountPublic(ctx, a)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// then adding the Source object to it... // Ensure account stats populated.
if a.Stats == nil {
// check pending follow requests aimed at this account if err := c.state.DB.PopulateAccountStats(ctx, a); err != nil {
frc, err := c.state.DB.CountAccountFollowRequests(ctx, a.ID) return nil, gtserror.Newf(
if err != nil { "error getting stats for account %s: %w",
return nil, fmt.Errorf("error counting follow requests: %s", err) a.ID, err,
)
}
} }
statusContentType := string(apimodel.StatusContentTypeDefault) statusContentType := string(apimodel.StatusContentTypeDefault)
@ -89,7 +93,7 @@ func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
StatusContentType: statusContentType, StatusContentType: statusContentType,
Note: a.NoteRaw, Note: a.NoteRaw,
Fields: c.fieldsToAPIFields(a.FieldsRaw), Fields: c.fieldsToAPIFields(a.FieldsRaw),
FollowRequestsCount: frc, FollowRequestsCount: *a.Stats.FollowRequestsCount,
AlsoKnownAsURIs: a.AlsoKnownAsURIs, AlsoKnownAsURIs: a.AlsoKnownAsURIs,
} }
@ -100,8 +104,22 @@ func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
// if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields. // if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields.
// In other words, this is the public record that the server has of an account. // In other words, this is the public record that the server has of an account.
func (c *Converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) { func (c *Converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) {
if err := c.state.DB.PopulateAccount(ctx, a); err != nil { // Populate account struct fields.
err := c.state.DB.PopulateAccount(ctx, a)
switch {
case err == nil:
// No problem.
case err != nil && a.Stats != nil:
// We have stats so that's
// *maybe* OK, try to continue.
log.Errorf(ctx, "error(s) populating account, will continue: %s", err) log.Errorf(ctx, "error(s) populating account, will continue: %s", err)
default:
// There was an error and we don't
// have stats, we can't continue.
return nil, gtserror.Newf("account stats not populated, could not continue: %w", err)
} }
// Basic account stats: // Basic account stats:
@ -110,30 +128,17 @@ func (c *Converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.A
// - Statuses count // - Statuses count
// - Last status time // - Last status time
followersCount, err := c.state.DB.CountAccountFollowers(ctx, a.ID) var (
if err != nil && !errors.Is(err, db.ErrNoEntries) { followersCount = *a.Stats.FollowersCount
return nil, gtserror.Newf("error counting followers: %w", err) followingCount = *a.Stats.FollowingCount
} statusesCount = *a.Stats.StatusesCount
lastStatusAt = func() *string {
followingCount, err := c.state.DB.CountAccountFollows(ctx, a.ID) if a.Stats.LastStatusAt.IsZero() {
if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil
return nil, gtserror.Newf("error counting following: %w", err) }
} return util.Ptr(util.FormatISO8601(a.Stats.LastStatusAt))
}()
statusesCount, err := c.state.DB.CountAccountStatuses(ctx, a.ID) )
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.Newf("error counting statuses: %w", err)
}
var lastStatusAt *string
lastPosted, err := c.state.DB.GetAccountLastPosted(ctx, a.ID, false)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.Newf("error getting last posted: %w", err)
}
if !lastPosted.IsZero() {
lastStatusAt = util.Ptr(util.FormatISO8601(lastPosted))
}
// Profile media + nice extras: // Profile media + nice extras:
// - Avatar // - Avatar

View file

@ -26,6 +26,7 @@ EXPECT=$(cat << "EOF"
"account-mem-ratio": 5, "account-mem-ratio": 5,
"account-note-mem-ratio": 1, "account-note-mem-ratio": 1,
"account-settings-mem-ratio": 0.1, "account-settings-mem-ratio": 0.1,
"account-stats-mem-ratio": 2,
"application-mem-ratio": 0.1, "application-mem-ratio": 0.1,
"block-mem-ratio": 3, "block-mem-ratio": 3,
"boost-of-ids-mem-ratio": 3, "boost-of-ids-mem-ratio": 3,