[feature] persist worker queues to db (#3042)

* persist queued worker tasks to database on shutdown, fill worker queues from database on startup

* ensure the tasks are sorted by creation time before pushing them

* add migration to insert WorkerTask{} into database, add test for worker task persistence

* add test for recovering worker queues from database

* quick tweak

* whoops we ended up with double cleaner job scheduling

* insert each task separately, because bun is throwing some reflection error??

* add specific checking of cancelled worker contexts

* add http request signing to deliveries recovered from database

* add test for outgoing public key ID being correctly set on delivery

* replace select with Queue.PopCtx()

* get rid of loop now we don't use it

* remove field now we don't use it

* ensure that signing func is set

* header values weren't being copied over 🤦

* use ptr for httpclient.Request in delivery

* move worker queue filling to later in server init process

* fix rebase issues

* make logging less shouty

* use slices.Delete() instead of copying / reslicing

* have database return tasks in ascending order instead of sorting them

* add a 1 minute timeout to persisting worker queues
This commit is contained in:
kim 2024-07-30 11:58:31 +00:00 committed by GitHub
parent 42932f9820
commit 87cff71af9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1191 additions and 93 deletions

View file

@ -87,9 +87,9 @@
// defer function for safe shutdown
// depending on what services were
// managed to be started.
state = new(state.State)
route *router.Router
state = new(state.State)
route *router.Router
process *processing.Processor
)
defer func() {
@ -125,6 +125,23 @@
}
}
if process != nil {
const timeout = time.Minute
// Use a new timeout context to ensure
// persisting queued tasks does not fail!
// The main ctx is very likely canceled.
ctx := context.WithoutCancel(ctx)
ctx, cncl := context.WithTimeout(ctx, timeout)
defer cncl()
// Now that all the "moving" components have been stopped,
// persist any remaining queued worker tasks to the database.
if err := process.Admin().PersistWorkerQueues(ctx); err != nil {
log.Errorf(ctx, "error persisting worker queues: %v", err)
}
}
if state.DB != nil {
// Lastly, if database service was started,
// ensure it gets closed now all else stopped.
@ -270,7 +287,7 @@ func(context.Context, time.Time) {
// Create the processor using all the
// other services we've created so far.
processor := processing.NewProcessor(
process = processing.NewProcessor(
cleaner,
typeConverter,
federator,
@ -286,14 +303,14 @@ func(context.Context, time.Time) {
state.Workers.Client.Init(messages.ClientMsgIndices())
state.Workers.Federator.Init(messages.FederatorMsgIndices())
state.Workers.Delivery.Init(client)
state.Workers.Client.Process = processor.Workers().ProcessFromClientAPI
state.Workers.Federator.Process = processor.Workers().ProcessFromFediAPI
state.Workers.Client.Process = process.Workers().ProcessFromClientAPI
state.Workers.Federator.Process = process.Workers().ProcessFromFediAPI
// Now start workers!
state.Workers.Start()
// Schedule notif tasks for all existing poll expiries.
if err := processor.Polls().ScheduleAll(ctx); err != nil {
if err := process.Polls().ScheduleAll(ctx); err != nil {
return fmt.Errorf("error scheduling poll expiries: %w", err)
}
@ -303,7 +320,7 @@ func(context.Context, time.Time) {
}
// Run advanced migrations.
if err := processor.AdvancedMigrations().Migrate(ctx); err != nil {
if err := process.AdvancedMigrations().Migrate(ctx); err != nil {
return err
}
@ -370,7 +387,7 @@ func(context.Context, time.Time) {
// attach global no route / 404 handler to the router
route.AttachNoRouteHandler(func(c *gin.Context) {
apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGetV1)
apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), process.InstanceGetV1)
})
// build router modules
@ -393,15 +410,15 @@ func(context.Context, time.Time) {
}
var (
authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(state, processor) // api client endpoints
metricsModule = api.NewMetrics() // Metrics endpoints
healthModule = api.NewHealth(dbService.Ready) // Health check endpoints
fileserverModule = api.NewFileserver(processor) // fileserver endpoints
wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints
webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc
authModule = api.NewAuth(dbService, process, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(state, process) // api client endpoints
metricsModule = api.NewMetrics() // Metrics endpoints
healthModule = api.NewHealth(dbService.Ready) // Health check endpoints
fileserverModule = api.NewFileserver(process) // fileserver endpoints
wellKnownModule = api.NewWellKnown(process) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(process) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, process) // ActivityPub endpoints
webModule = web.New(dbService, process) // web pages + user profiles + settings panels etc
)
// create required middleware
@ -416,10 +433,11 @@ func(context.Context, time.Time) {
// throttling
cpuMultiplier := config.GetAdvancedThrottlingMultiplier()
retryAfter := config.GetAdvancedThrottlingRetryAfter()
clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api
s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP)
fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis
pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately
clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api
s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter)
// server-to-server (AP)
fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis
pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately
gzip := middleware.Gzip() // applied to all except fileserver
@ -442,6 +460,11 @@ func(context.Context, time.Time) {
return fmt.Errorf("error starting router: %w", err)
}
// Fill worker queues from persisted task data in database.
if err := process.Admin().FillWorkerQueues(ctx); err != nil {
return fmt.Errorf("error filling worker queues: %w", err)
}
// catch shutdown signals from the operating system
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

View file

@ -84,6 +84,7 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
db.WorkerTask
db *bun.DB
}
@ -302,6 +303,9 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
WorkerTask: &workerTaskDB{
db: db,
},
db: db,
}

View file

@ -0,0 +1,51 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// WorkerTask table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.WorkerTask{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -0,0 +1,58 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type workerTaskDB struct{ db *bun.DB }
func (w *workerTaskDB) GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) {
var tasks []*gtsmodel.WorkerTask
if err := w.db.NewSelect().
Model(&tasks).
OrderExpr("? ASC", bun.Ident("created_at")).
Scan(ctx); err != nil {
return nil, err
}
return tasks, nil
}
func (w *workerTaskDB) PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error {
var errs []error
for _, task := range tasks {
_, err := w.db.NewInsert().Model(task).Exec(ctx)
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
func (w *workerTaskDB) DeleteWorkerTaskByID(ctx context.Context, id uint) error {
_, err := w.db.NewDelete().
Table("worker_tasks").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return err
}

View file

@ -56,4 +56,5 @@ type DB interface {
Timeline
User
Tombstone
WorkerTask
}

35
internal/db/workertask.go Normal file
View file

@ -0,0 +1,35 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type WorkerTask interface {
// GetWorkerTasks fetches all persisted worker tasks from the database.
GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error)
// PutWorkerTasks persists the given worker tasks to the database.
PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error
// DeleteWorkerTask deletes worker task with given ID from database.
DeleteWorkerTaskByID(ctx context.Context, id uint) error
}

View file

@ -34,8 +34,8 @@
// queued tasks from being lost. It is simply a
// means to store a blob of serialized task data.
type WorkerTask struct {
ID uint `bun:""`
WorkerType uint8 `bun:""`
TaskData []byte `bun:""`
CreatedAt time.Time `bun:""`
ID uint `bun:",pk,autoincrement"`
WorkerType WorkerType `bun:",notnull"`
TaskData []byte `bun:",nullzero,notnull"`
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
}

View file

@ -197,7 +197,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
// If the fast-fail flag was set, just
// attempt a single iteration instead of
// following the below retry-backoff loop.
rsp, _, err = c.DoOnce(&req)
rsp, _, err = c.DoOnce(req)
if err != nil {
return nil, fmt.Errorf("%w (fast fail)", err)
}
@ -208,7 +208,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
var retry bool
// Perform the http request.
rsp, retry, err = c.DoOnce(&req)
rsp, retry, err = c.DoOnce(req)
if err == nil {
return rsp, nil
}

View file

@ -47,8 +47,8 @@ type Request struct {
// WrapRequest wraps an existing http.Request within
// our own httpclient.Request with retry / backoff tracking.
func WrapRequest(r *http.Request) Request {
var rr Request
func WrapRequest(r *http.Request) *Request {
rr := new(Request)
rr.Request = r
entry := log.WithContext(r.Context())
entry = entry.WithField("method", r.Method)

View file

@ -352,7 +352,7 @@ func resolveAPObject(data map[string]interface{}) (interface{}, error) {
// we then need to wrangle back into the original type. So we also store the type name
// and use this to determine the appropriate Go structure type to unmarshal into to.
func resolveGTSModel(typ string, data []byte) (interface{}, error) {
if typ == "" && data == nil {
if typ == "" {
// No data given.
return nil, nil
}

View file

@ -0,0 +1,426 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"context"
"fmt"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
)
// NOTE:
// Having these functions in the processor, which is
// usually the intermediary that performs *processing*
// between the HTTP route handlers and the underlying
// database / storage layers is a little odd, so this
// may be subject to change!
//
// For now at least, this is a useful place that has
// access to the underlying database, workers and
// causes no dependency cycles with this use case!
// FillWorkerQueues recovers all serialized worker tasks from the database
// (if any!), and pushes them to each of their relevant worker queues.
func (p *Processor) FillWorkerQueues(ctx context.Context) error {
log.Info(ctx, "rehydrate!")
// Get all persisted worker tasks from db.
//
// (database returns these as ASCENDING, i.e.
// returned in the order they were inserted).
tasks, err := p.state.DB.GetWorkerTasks(ctx)
if err != nil {
return gtserror.Newf("error fetching worker tasks from db: %w", err)
}
var (
// Counts of each task type
// successfully recovered.
delivery int
federator int
client int
// Failed recoveries.
errors int
)
loop:
// Handle each persisted task, removing
// all those we can't handle. Leaving us
// with a slice of tasks we can safely
// delete from being persisted in the DB.
for i := 0; i < len(tasks); {
var err error
// Task at index.
task := tasks[i]
// Appropriate task count
// pointer to increment.
var counter *int
// Attempt to recovery persisted
// task depending on worker type.
switch task.WorkerType {
case gtsmodel.DeliveryWorker:
err = p.pushDelivery(ctx, task)
counter = &delivery
case gtsmodel.FederatorWorker:
err = p.pushFederator(ctx, task)
counter = &federator
case gtsmodel.ClientWorker:
err = p.pushClient(ctx, task)
counter = &client
default:
err = fmt.Errorf("invalid worker type %d", task.WorkerType)
}
if err != nil {
log.Errorf(ctx, "error pushing task %d: %v", task.ID, err)
// Drop error'd task from slice.
tasks = slices.Delete(tasks, i, i+1)
// Incr errors.
errors++
continue loop
}
// Increment slice
// index & counter.
(*counter)++
i++
}
// Tasks that worker successfully pushed
// to their appropriate workers, we can
// safely now remove from the database.
for _, task := range tasks {
if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil {
log.Errorf(ctx, "error deleting task from db: %v", err)
}
}
// Log recovered tasks.
log.WithContext(ctx).
WithField("delivery", delivery).
WithField("federator", federator).
WithField("client", client).
WithField("errors", errors).
Info("recovered queued tasks")
return nil
}
// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not
// dereference tasks which are just function ptrs), serializes and persists them to the database.
func (p *Processor) PersistWorkerQueues(ctx context.Context) error {
log.Info(ctx, "dehydrate!")
var (
// Counts of each task type
// successfully persisted.
delivery int
federator int
client int
// Failed persists.
errors int
// Serialized tasks to persist.
tasks []*gtsmodel.WorkerTask
)
for {
// Pop all queued deliveries.
task, err := p.popDelivery()
if err != nil {
log.Errorf(ctx, "error popping delivery: %v", err)
errors++ // incr error count.
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
delivery++ // incr count
}
for {
// Pop queued federator msgs.
task, err := p.popFederator()
if err != nil {
log.Errorf(ctx, "error popping federator message: %v", err)
errors++ // incr count
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
federator++ // incr count
}
for {
// Pop queued client msgs.
task, err := p.popClient()
if err != nil {
log.Errorf(ctx, "error popping client message: %v", err)
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
client++ // incr count
}
// Persist all serialized queued worker tasks to database.
if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil {
return gtserror.Newf("error putting tasks in db: %w", err)
}
// Log recovered tasks.
log.WithContext(ctx).
WithField("delivery", delivery).
WithField("federator", federator).
WithField("client", client).
WithField("errors", errors).
Info("persisted queued tasks")
return nil
}
// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue.
func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error {
dlv := new(delivery.Delivery)
// Deserialize the raw worker task data into delivery.
if err := dlv.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing delivery: %w", err)
}
var tsport transport.Transport
if uri := dlv.ActorID; uri != "" {
// Fetch the actor account by provided URI from db.
account, err := p.state.DB.GetAccountByURI(ctx, uri)
if err != nil {
return gtserror.Newf("error getting actor account %s from db: %w", uri, err)
}
// Fetch a transport for request signing for actor's account username.
tsport, err = p.transport.NewTransportForUsername(ctx, account.Username)
if err != nil {
return gtserror.Newf("error getting transport for actor %s: %w", uri, err)
}
} else {
var err error
// No actor was given, will be signed by instance account.
tsport, err = p.transport.NewTransportForUsername(ctx, "")
if err != nil {
return gtserror.Newf("error getting instance account transport: %w", err)
}
}
// Using transport, add actor signature to delivery.
if err := tsport.SignDelivery(dlv); err != nil {
return gtserror.Newf("error signing delivery: %w", err)
}
// Push deserialized task to delivery queue.
p.state.Workers.Delivery.Queue.Push(dlv)
return nil
}
// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data.
func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) {
// Pop waiting delivery from the delivery worker.
delivery, ok := p.state.Workers.Delivery.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize the delivery task data.
data, err := delivery.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing delivery: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.DeliveryWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}
// pushClient parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue.
func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error {
var msg messages.FromFediAPI
// Deserialize the raw worker task data into message.
if err := msg.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing federator message: %w", err)
}
if rcv := msg.Receiving; rcv != nil {
// Only a placeholder receiving account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, rcv.ID)
if err != nil {
return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err)
}
// Set the now populated
// receiving account model.
msg.Receiving = account
}
if req := msg.Requesting; req != nil {
// Only a placeholder requesting account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, req.ID)
if err != nil {
return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err)
}
// Set the now populated
// requesting account model.
msg.Requesting = account
}
// Push populated task to the federator queue.
p.state.Workers.Federator.Queue.Push(&msg)
return nil
}
// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data.
func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) {
// Pop waiting message from the federator worker.
msg, ok := p.state.Workers.Federator.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize message task data.
data, err := msg.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing federator message: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.FederatorWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}
// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue.
func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error {
var msg messages.FromClientAPI
// Deserialize the raw worker task data into message.
if err := msg.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing client message: %w", err)
}
if org := msg.Origin; org != nil {
// Only a placeholder origin account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, org.ID)
if err != nil {
return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err)
}
// Set the now populated
// origin account model.
msg.Origin = account
}
if trg := msg.Target; trg != nil {
// Only a placeholder target account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, trg.ID)
if err != nil {
return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err)
}
// Set the now populated
// target account model.
msg.Target = account
}
// Push populated task to the federator queue.
p.state.Workers.Client.Queue.Push(&msg)
return nil
}
// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data.
func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) {
// Pop waiting message from the client worker.
msg, ok := p.state.Workers.Client.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize message task data.
data, err := msg.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing client message: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.ClientWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}

View file

@ -0,0 +1,421 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/gotosocial/testrig"
)
var (
// TODO: move these test values into
// the testrig test models area. They'll
// need to be as both WorkerTask and as
// the raw types themselves.
testDeliveries = []*delivery.Delivery{
{
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}),
},
{
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}),
},
}
testFederatorMsgs = []*messages.FromFediAPI{
{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
TargetURI: "https://gotosocial.org",
Requesting: &gtsmodel.Account{ID: "654321"},
Receiving: &gtsmodel.Account{ID: "123456"},
},
{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
TargetURI: "https://uk-queen-is-dead.org",
Requesting: &gtsmodel.Account{ID: "123456"},
Receiving: &gtsmodel.Account{ID: "654321"},
},
}
testClientMsgs = []*messages.FromClientAPI{
{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
TargetURI: "https://gotosocial.org",
Origin: &gtsmodel.Account{ID: "654321"},
Target: &gtsmodel.Account{ID: "123456"},
},
{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
TargetURI: "https://uk-queen-is-dead.org",
Origin: &gtsmodel.Account{ID: "123456"},
Target: &gtsmodel.Account{ID: "654321"},
},
}
)
type WorkerTaskTestSuite struct {
AdminStandardTestSuite
}
func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
var tasks []*gtsmodel.WorkerTask
for _, dlv := range testDeliveries {
// Serialize all test deliveries.
data, err := dlv.Serialize()
if err != nil {
panic(err)
}
// Append each serialized delivery to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.DeliveryWorker,
TaskData: data,
})
}
for _, msg := range testFederatorMsgs {
// Serialize all test messages.
data, err := msg.Serialize()
if err != nil {
panic(err)
}
if msg.Receiving != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Receiving)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Receiving.ID,
})
}
if msg.Requesting != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Requesting)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Requesting.ID,
})
}
// Append each serialized message to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.FederatorWorker,
TaskData: data,
})
}
for _, msg := range testClientMsgs {
// Serialize all test messages.
data, err := msg.Serialize()
if err != nil {
panic(err)
}
if msg.Origin != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Origin)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Origin.ID,
})
}
if msg.Target != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Target)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Target.ID,
})
}
// Append each serialized message to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.ClientWorker,
TaskData: data,
})
}
// Persist all test worker tasks to the database.
err := suite.state.DB.PutWorkerTasks(ctx, tasks)
suite.NoError(err)
// Fill the worker queues from persisted task data.
err = suite.adminProcessor.FillWorkerQueues(ctx)
suite.NoError(err)
var (
// Recovered
// task counts.
ndelivery int
nfederator int
nclient int
)
// Fetch current gotosocial instance account, for later checks.
instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "")
suite.NoError(err)
for {
// Pop all queued delivery tasks from worker queue.
dlv, ok := suite.state.Workers.Delivery.Queue.Pop()
if !ok {
break
}
// Incr count.
ndelivery++
// Check that we have this message in slice.
err = containsSerializable(testDeliveries, dlv)
suite.NoError(err)
// Check that delivery request context has instance account pubkey.
pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context())
suite.Equal(instanceAcc.PublicKeyURI, pubKeyID)
signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context())
suite.NotNil(signfn)
}
for {
// Pop all queued federator messages from worker queue.
msg, ok := suite.state.Workers.Federator.Queue.Pop()
if !ok {
break
}
// Incr count.
nfederator++
// Check that we have this message in slice.
err = containsSerializable(testFederatorMsgs, msg)
suite.NoError(err)
}
for {
// Pop all queued client messages from worker queue.
msg, ok := suite.state.Workers.Client.Queue.Pop()
if !ok {
break
}
// Incr count.
nclient++
// Check that we have this message in slice.
err = containsSerializable(testClientMsgs, msg)
suite.NoError(err)
}
// Ensure recovered task counts as expected.
suite.Equal(len(testDeliveries), ndelivery)
suite.Equal(len(testFederatorMsgs), nfederator)
suite.Equal(len(testClientMsgs), nclient)
}
func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Push all test worker tasks to their respective queues.
suite.state.Workers.Delivery.Queue.Push(testDeliveries...)
suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...)
suite.state.Workers.Client.Queue.Push(testClientMsgs...)
// Persist the worker queued tasks to database.
err := suite.adminProcessor.PersistWorkerQueues(ctx)
suite.NoError(err)
// Fetch all the persisted tasks from database.
tasks, err := suite.state.DB.GetWorkerTasks(ctx)
suite.NoError(err)
var (
// Persisted
// task counts.
ndelivery int
nfederator int
nclient int
)
// Check persisted task data.
for _, task := range tasks {
switch task.WorkerType {
case gtsmodel.DeliveryWorker:
var dlv delivery.Delivery
// Incr count.
ndelivery++
// Deserialize the persisted task data.
err := dlv.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this delivery in slice.
err = containsSerializable(testDeliveries, &dlv)
suite.NoError(err)
case gtsmodel.FederatorWorker:
var msg messages.FromFediAPI
// Incr count.
nfederator++
// Deserialize the persisted task data.
err := msg.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this message in slice.
err = containsSerializable(testFederatorMsgs, &msg)
suite.NoError(err)
case gtsmodel.ClientWorker:
var msg messages.FromClientAPI
// Incr count.
nclient++
// Deserialize the persisted task data.
err := msg.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this message in slice.
err = containsSerializable(testClientMsgs, &msg)
suite.NoError(err)
default:
suite.T().Errorf("unexpected worker type: %d", task.WorkerType)
}
}
// Ensure persisted task counts as expected.
suite.Equal(len(testDeliveries), ndelivery)
suite.Equal(len(testFederatorMsgs), nfederator)
suite.Equal(len(testClientMsgs), nclient)
}
func (suite *WorkerTaskTestSuite) SetupTest() {
suite.AdminStandardTestSuite.SetupTest()
// we don't want workers running
testrig.StopWorkers(&suite.state)
}
func TestWorkerTaskTestSuite(t *testing.T) {
suite.Run(t, new(WorkerTaskTestSuite))
}
// containsSerializeable returns whether slice of serializables contains given serializable entry.
func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error {
// Serialize wanted value.
bh, err := have.Serialize()
if err != nil {
panic(err)
}
var strings []string
for _, t := range expect {
// Serialize expected value.
be, err := t.Serialize()
if err != nil {
panic(err)
}
// Alloc as string.
se := string(be)
if se == string(bh) {
// We have this entry!
return nil
}
// Add to serialized strings.
strings = append(strings, se)
}
return fmt.Errorf("could not find %s in %s", string(bh), strings)
}
// urlStr simply returns u.String() or "" if nil.
func urlStr(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
// accountID simply returns account.ID or "" if nil.
func accountID(account *gtsmodel.Account) string {
if account == nil {
return ""
}
return account.ID
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
}
req, err := http.NewRequest(method, url, rbody)
if err != nil {
panic(err)
}
for key, values := range hdr {
for _, value := range values {
req.Header.Add(key, value)
}
}
return httpclient.WrapRequest(req)
}
// toJSON marshals input type as JSON data.
func toJSON(a any) []byte {
b, err := json.Marshal(a)
if err != nil {
panic(err)
}
return b
}

View file

@ -21,6 +21,7 @@
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
@ -169,6 +170,38 @@ func (t *transport) prepare(
}, nil
}
func (t *transport) SignDelivery(dlv *delivery.Delivery) error {
if dlv.Request.GetBody == nil {
return gtserror.New("delivery request body not rewindable")
}
// Get a new copy of the request body.
body, err := dlv.Request.GetBody()
if err != nil {
return gtserror.Newf("error getting request body: %w", err)
}
// Read body data into memory.
data, err := io.ReadAll(body)
if err != nil {
return gtserror.Newf("error reading request body: %w", err)
}
// Get signing function for POST data.
// (note that delivery is ALWAYS POST).
sign := t.signPOST(data)
// Extract delivery context.
ctx := dlv.Request.Context()
// Update delivery request context with signing details.
ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return nil
}
// getObjectID extracts an object ID from 'serialized' ActivityPub object map.
func getObjectID(obj map[string]interface{}) string {
switch t := obj["object"].(type) {

View file

@ -33,10 +33,6 @@
// be indexed (and so, dropped from queue)
// by any of these possible ID IRIs.
type Delivery struct {
// PubKeyID is the signing public key
// ID of the actor performing request.
PubKeyID string
// ActorID contains the ActivityPub
// actor ID IRI (if any) of the activity
// being sent out by this request.
@ -55,7 +51,7 @@ type Delivery struct {
// Request is the prepared (+ wrapped)
// httpclient.Client{} request that
// constitutes this ActivtyPub delivery.
Request httpclient.Request
Request *httpclient.Request
// internal fields.
next time.Time
@ -66,7 +62,6 @@ type Delivery struct {
// a json serialize / deserialize
// able shape that minimizes data.
type delivery struct {
PubKeyID string `json:"pub_key_id,omitempty"`
ActorID string `json:"actor_id,omitempty"`
ObjectID string `json:"object_id,omitempty"`
TargetID string `json:"target_id,omitempty"`
@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {
// Marshal as internal JSON type.
return json.Marshal(delivery{
PubKeyID: dlv.PubKeyID,
ActorID: dlv.ActorID,
ObjectID: dlv.ObjectID,
TargetID: dlv.TargetID,
@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {
}
// Copy over simplest fields.
dlv.PubKeyID = idlv.PubKeyID
dlv.ActorID = idlv.ActorID
dlv.ObjectID = idlv.ObjectID
dlv.TargetID = idlv.TargetID
@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {
return err
}
// Copy over any stored header values.
for key, values := range idlv.Header {
for _, value := range values {
r.Header.Add(key, value)
}
}
// Wrap request in httpclient type.
dlv.Request = httpclient.WrapRequest(r)

View file

@ -35,32 +35,30 @@
}{
{
msg: delivery.Delivery{
PubKeyID: "https://google.com/users/bigboy#pubkey",
ActorID: "https://google.com/users/bigboy",
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")),
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),
},
data: toJSON(map[string]any{
"pub_key_id": "https://google.com/users/bigboy#pubkey",
"actor_id": "https://google.com/users/bigboy",
"object_id": "https://google.com/users/bigboy/follow/1",
"target_id": "https://askjeeves.com/users/smallboy",
"method": "POST",
"url": "https://askjeeves.com/users/smallboy/inbox",
"body": []byte("data!"),
// "header": map[string][]string{},
"actor_id": "https://google.com/users/bigboy",
"object_id": "https://google.com/users/bigboy/follow/1",
"target_id": "https://askjeeves.com/users/smallboy",
"method": "POST",
"url": "https://askjeeves.com/users/smallboy/inbox",
"body": []byte("data!"),
"header": map[string][]string{"Hello": {"world1", "world2"}},
}),
},
{
msg: delivery.Delivery{
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")),
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),
},
data: toJSON(map[string]any{
"method": "GET",
"url": "https://google.com",
"body": []byte("uwu im just a wittle seawch engwin"),
// "header": map[string][]string{},
// "header": map[string][]string{},
}),
},
}
@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {
}
// Check that delivery fields are as expected.
assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)
assert.Equal(t, test.msg.ActorID, msg.ActorID)
assert.Equal(t, test.msg.ObjectID, msg.ObjectID)
assert.Equal(t, test.msg.TargetID, msg.TargetID)
assert.Equal(t, test.msg.Request.Method, msg.Request.Method)
assert.Equal(t, test.msg.Request.URL, msg.Request.URL)
assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))
assert.Equal(t, test.msg.Request.Header, msg.Request.Header)
}
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
func toRequest(method string, url string, body []byte) httpclient.Request {
func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {
if err != nil {
panic(err)
}
for key, values := range hdr {
for _, value := range values {
req.Header.Add(key, value)
}
}
return httpclient.WrapRequest(req)
}

View file

@ -19,6 +19,7 @@
import (
"context"
"errors"
"slices"
"time"
@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {
loop:
for {
// Before trying to get
// next delivery, check
// context still valid.
if ctx.Err() != nil {
return true
}
// Get next delivery.
dlv, ok := w.next(ctx)
if !ok {
@ -195,16 +203,30 @@ func (w *Worker) process(ctx context.Context) bool {
// Attempt delivery of AP request.
rsp, retry, err := w.Client.DoOnce(
&dlv.Request,
dlv.Request,
)
if err == nil {
switch {
case err == nil:
// Ensure body closed.
_ = rsp.Body.Close()
continue loop
}
if !retry {
case errors.Is(err, context.Canceled) &&
ctx.Err() != nil:
// In the case of our own context
// being cancelled, push delivery
// back onto queue for persisting.
//
// Note we specifically check against
// context.Canceled here as it will
// be faster than the mutex lock of
// ctx.Err(), so gives an initial
// faster check in the if-clause.
w.Queue.Push(dlv)
continue loop
case !retry:
// Drop deliveries when no
// retry requested, or they
// reached max (either).
@ -222,42 +244,36 @@ func (w *Worker) process(ctx context.Context) bool {
// next gets the next available delivery, blocking until available if necessary.
func (w *Worker) next(ctx context.Context) (*Delivery, bool) {
loop:
for {
// Try pop next queued.
dlv, ok := w.Queue.Pop()
// Try a fast-pop of queued
// delivery before anything.
dlv, ok := w.Queue.Pop()
if !ok {
// Check the backlog.
if len(w.backlog) > 0 {
if !ok {
// Check the backlog.
if len(w.backlog) > 0 {
// Sort by 'next' time.
sortDeliveries(w.backlog)
// Sort by 'next' time.
sortDeliveries(w.backlog)
// Pop next delivery.
dlv := w.popBacklog()
// Pop next delivery.
dlv := w.popBacklog()
return dlv, true
}
select {
// Backlog is empty, we MUST
// block until next enqueued.
case <-w.Queue.Wait():
continue loop
// Worker was stopped.
case <-ctx.Done():
return nil, false
}
return dlv, true
}
// Replace request context for worker state canceling.
ctx := gtscontext.WithValues(ctx, dlv.Request.Context())
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return dlv, true
// Block on next delivery push
// OR worker context canceled.
dlv, ok = w.Queue.PopCtx(ctx)
if !ok {
return nil, false
}
}
// Replace request context for worker state canceling.
ctx = gtscontext.WithValues(ctx, dlv.Request.Context())
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return dlv, true
}
// popBacklog pops next available from the backlog.

View file

@ -30,6 +30,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/httpsig"
)
@ -50,6 +51,10 @@ type Transport interface {
// transport client, retrying on certain preset errors.
POST(*http.Request, []byte) (*http.Response, error)
// SignDelivery adds HTTP request signing client "middleware"
// to the request context within given delivery.Delivery{}.
SignDelivery(*delivery.Delivery) error
// Deliver sends an ActivityStreams object.
Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error

View file

@ -19,6 +19,7 @@
import (
"context"
"errors"
"codeberg.org/gruf/go-runners"
"codeberg.org/gruf/go-structr"
@ -147,9 +148,25 @@ func (w *MsgWorker[T]) process(ctx context.Context) {
return
}
// Attempt to process popped message type.
if err := w.Process(ctx, msg); err != nil {
// Attempt to process message.
err := w.Process(ctx, msg)
if err != nil {
log.Errorf(ctx, "%p: error processing: %v", w, err)
if errors.Is(err, context.Canceled) &&
ctx.Err() != nil {
// In the case of our own context
// being cancelled, push message
// back onto queue for persisting.
//
// Note we specifically check against
// context.Canceled here as it will
// be faster than the mutex lock of
// ctx.Err(), so gives an initial
// faster check in the if-clause.
w.Queue.Push(msg)
break
}
}
}
}

View file

@ -55,7 +55,8 @@ type Workers struct {
// StartScheduler starts the job scheduler.
func (w *Workers) StartScheduler() {
_ = w.Scheduler.Start() // false = already running
_ = w.Scheduler.Start()
// false = already running
log.Info(nil, "started scheduler")
}
@ -82,9 +83,12 @@ func (w *Workers) Start() {
log.Infof(nil, "started %d dereference workers", n)
}
// Stop will stop all of the contained worker pools (and global scheduler).
// Stop will stop all of the contained
// worker pools (and global scheduler).
func (w *Workers) Stop() {
_ = w.Scheduler.Stop() // false = not running
_ = w.Scheduler.Stop()
// false = not running
log.Info(nil, "stopped scheduler")
w.Delivery.Stop()
log.Info(nil, "stopped delivery workers")

View file

@ -29,6 +29,8 @@
var testModels = []interface{}{
&gtsmodel.Account{},
&gtsmodel.AccountNote{},
&gtsmodel.AccountSettings{},
&gtsmodel.AccountToEmoji{},
&gtsmodel.Application{},
&gtsmodel.Block{},
@ -67,8 +69,7 @@
&gtsmodel.Tombstone{},
&gtsmodel.Report{},
&gtsmodel.Rule{},
&gtsmodel.AccountNote{},
&gtsmodel.AccountSettings{},
&gtsmodel.WorkerTask{},
}
// NewTestDB returns a new initialized, empty database for testing.