mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-10-31 22:40:01 +00:00
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed 🤦
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
This commit is contained in:
parent
8cafa6b74b
commit
291e180990
14 changed files with 535 additions and 451 deletions
|
@ -22,10 +22,10 @@
|
|||
"slices"
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-kv"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
|
@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
|||
// functions pass messages into a channel, which we can
|
||||
// then read from and put into a websockets connection.
|
||||
stream, errWithCode := m.processor.Stream().Open(
|
||||
c.Request.Context(),
|
||||
c.Request.Context(), // this ctx is only used for logging
|
||||
account,
|
||||
streamType,
|
||||
)
|
||||
|
@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
|||
|
||||
l := log.
|
||||
WithContext(c.Request.Context()).
|
||||
WithFields(kv.Fields{
|
||||
{"username", account.Username},
|
||||
{"streamID", stream.ID},
|
||||
}...)
|
||||
WithField("streamID", id.NewULID()).
|
||||
WithField("username", account.Username)
|
||||
|
||||
// Upgrade the incoming HTTP request. This hijacks the
|
||||
// underlying connection and reuses it for the websocket
|
||||
|
@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
|||
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
l.Errorf("error upgrading websocket connection: %v", err)
|
||||
close(stream.Hangup)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
l.Info("opened websocket connection")
|
||||
|
||||
// We perform the main websocket rw loops in a separate
|
||||
// goroutine in order to let the upgrade handler return.
|
||||
// This prevents the upgrade handler from holding open any
|
||||
// throttle / rate-limit request tokens which could become
|
||||
// problematic on instances with multiple users.
|
||||
go m.handleWSConn(account.Username, wsConn, stream)
|
||||
go m.handleWSConn(&l, wsConn, stream)
|
||||
}
|
||||
|
||||
// handleWSConn handles a two-way websocket streaming connection.
|
||||
|
@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
|||
// into the connection. If any errors are encountered while reading
|
||||
// or writing (including expected errors like clients leaving), the
|
||||
// connection will be closed.
|
||||
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
|
||||
// Create new context for the lifetime of this connection.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) {
|
||||
l.Info("opened websocket connection")
|
||||
|
||||
l := log.
|
||||
WithContext(ctx).
|
||||
WithFields(kv.Fields{
|
||||
{"username", username},
|
||||
{"streamID", stream.ID},
|
||||
}...)
|
||||
// Create new async context with cancel.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
|
||||
// Create ticker to send keepalive pings
|
||||
pinger := time.NewTicker(m.dTicker)
|
||||
|
||||
// Read messages coming from the Websocket client connection into the server.
|
||||
go func() {
|
||||
defer cancel()
|
||||
m.readFromWSConn(ctx, username, wsConn, stream)
|
||||
defer cncl()
|
||||
|
||||
// Read messages from websocket to server.
|
||||
m.readFromWSConn(ctx, wsConn, stream, l)
|
||||
}()
|
||||
|
||||
// Write messages coming from the processor into the Websocket client connection.
|
||||
go func() {
|
||||
defer cancel()
|
||||
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
|
||||
defer cncl()
|
||||
|
||||
// Write messages from processor in websocket conn.
|
||||
m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)
|
||||
}()
|
||||
|
||||
// Wait for either the read or write functions to close, to indicate
|
||||
// that the client has left, or something else has gone wrong.
|
||||
// Wait for ctx
|
||||
// to be closed.
|
||||
<-ctx.Done()
|
||||
|
||||
// Close stream
|
||||
// straightaway.
|
||||
stream.Close()
|
||||
|
||||
// Tidy up underlying websocket connection.
|
||||
if err := wsConn.Close(); err != nil {
|
||||
l.Errorf("error closing websocket connection: %v", err)
|
||||
}
|
||||
|
||||
// Close processor channel so the processor knows
|
||||
// not to send any more messages to this stream.
|
||||
close(stream.Hangup)
|
||||
|
||||
// Stop ping ticker (tiny resource saving).
|
||||
pinger.Stop()
|
||||
|
||||
l.Info("closed websocket connection")
|
||||
}
|
||||
|
||||
|
@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s
|
|||
// if the given context is canceled.
|
||||
func (m *Module) readFromWSConn(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
wsConn *websocket.Conn,
|
||||
stream *streampkg.Stream,
|
||||
l *log.Entry,
|
||||
) {
|
||||
l := log.
|
||||
WithContext(ctx).
|
||||
WithFields(kv.Fields{
|
||||
{"username", username},
|
||||
{"streamID", stream.ID},
|
||||
}...)
|
||||
|
||||
readLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Connection closed.
|
||||
break readLoop
|
||||
var msg struct {
|
||||
Type string `json:"type"`
|
||||
Stream string `json:"stream"`
|
||||
List string `json:"list,omitempty"`
|
||||
}
|
||||
|
||||
// Read JSON objects from the client and act on them.
|
||||
if err := wsConn.ReadJSON(&msg); err != nil {
|
||||
// Only log an error if something weird happened.
|
||||
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
|
||||
if !websocket.IsCloseError(err, []int{
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived,
|
||||
}...) {
|
||||
l.Errorf("error during websocket read: %v", err)
|
||||
}
|
||||
|
||||
// The connection is gone; no
|
||||
// further streaming possible.
|
||||
break
|
||||
}
|
||||
|
||||
// Messages *from* the WS connection are infrequent
|
||||
// and usually interesting, so log this at info.
|
||||
l.Infof("received websocket message: %+v", msg)
|
||||
|
||||
// Ignore if the updateStreamType is unknown (or missing),
|
||||
// so a bad client can't cause extra memory allocations
|
||||
if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) {
|
||||
l.Warnf("unknown 'stream' field: %v", msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.List != "" {
|
||||
// If a list is given, add this to
|
||||
// the stream name as this is how we
|
||||
// we track stream types internally.
|
||||
msg.Stream += ":" + msg.List
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "subscribe":
|
||||
stream.Subscribe(msg.Stream)
|
||||
case "unsubscribe":
|
||||
stream.Unsubscribe(msg.Stream)
|
||||
default:
|
||||
// Read JSON objects from the client and act on them.
|
||||
var msg map[string]string
|
||||
if err := wsConn.ReadJSON(&msg); err != nil {
|
||||
// Only log an error if something weird happened.
|
||||
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
|
||||
if websocket.IsUnexpectedCloseError(err, []int{
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived,
|
||||
}...) {
|
||||
l.Errorf("error reading from websocket: %v", err)
|
||||
}
|
||||
|
||||
// The connection is gone; no
|
||||
// further streaming possible.
|
||||
break readLoop
|
||||
}
|
||||
|
||||
// Messages *from* the WS connection are infrequent
|
||||
// and usually interesting, so log this at info.
|
||||
l.Infof("received message from websocket: %v", msg)
|
||||
|
||||
// If the message contains 'stream' and 'type' fields, we can
|
||||
// update the set of timelines that are subscribed for events.
|
||||
updateType, ok := msg["type"]
|
||||
if !ok {
|
||||
l.Warn("'type' field not provided")
|
||||
continue
|
||||
}
|
||||
|
||||
updateStream, ok := msg["stream"]
|
||||
if !ok {
|
||||
l.Warn("'stream' field not provided")
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignore if the updateStreamType is unknown (or missing),
|
||||
// so a bad client can't cause extra memory allocations
|
||||
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
|
||||
l.Warnf("unknown 'stream' field: %v", msg)
|
||||
continue
|
||||
}
|
||||
|
||||
updateList, ok := msg["list"]
|
||||
if ok {
|
||||
updateStream += ":" + updateList
|
||||
}
|
||||
|
||||
switch updateType {
|
||||
case "subscribe":
|
||||
stream.Lock()
|
||||
stream.StreamTypes[updateStream] = true
|
||||
stream.Unlock()
|
||||
case "unsubscribe":
|
||||
stream.Lock()
|
||||
delete(stream.StreamTypes, updateStream)
|
||||
stream.Unlock()
|
||||
default:
|
||||
l.Warnf("invalid 'type' field: %v", msg)
|
||||
}
|
||||
l.Warnf("invalid 'type' field: %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
l.Debug("finished reading from websocket connection")
|
||||
l.Debug("finished websocket read")
|
||||
}
|
||||
|
||||
// writeToWSConn receives messages coming from the processor via the
|
||||
|
@ -393,46 +355,47 @@ func (m *Module) readFromWSConn(
|
|||
// if the given context is canceled.
|
||||
func (m *Module) writeToWSConn(
|
||||
ctx context.Context,
|
||||
username string,
|
||||
wsConn *websocket.Conn,
|
||||
stream *streampkg.Stream,
|
||||
pinger *time.Ticker,
|
||||
ping time.Duration,
|
||||
l *log.Entry,
|
||||
) {
|
||||
l := log.
|
||||
WithContext(ctx).
|
||||
WithFields(kv.Fields{
|
||||
{"username", username},
|
||||
{"streamID", stream.ID},
|
||||
}...)
|
||||
|
||||
writeLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Connection closed.
|
||||
break writeLoop
|
||||
// Wrap context with timeout to send a ping.
|
||||
pingctx, cncl := context.WithTimeout(ctx, ping)
|
||||
|
||||
case msg := <-stream.Messages:
|
||||
// Received a new message from the processor.
|
||||
l.Tracef("writing message to websocket: %+v", msg)
|
||||
if err := wsConn.WriteJSON(msg); err != nil {
|
||||
l.Debugf("error writing json to websocket: %v", err)
|
||||
break writeLoop
|
||||
}
|
||||
// Block on receipt of msg.
|
||||
msg, ok := stream.Recv(pingctx)
|
||||
|
||||
// Reset pinger on successful send, since
|
||||
// we know the connection is still there.
|
||||
pinger.Reset(m.dTicker)
|
||||
// Check if cancel because ping.
|
||||
pinged := (pingctx.Err() != nil)
|
||||
cncl()
|
||||
|
||||
case <-pinger.C:
|
||||
// Time to send a keep-alive "ping".
|
||||
l.Trace("writing ping control message to websocket")
|
||||
switch {
|
||||
case !ok && pinged:
|
||||
// The ping context timed out!
|
||||
l.Trace("writing websocket ping")
|
||||
|
||||
// Wrapped context time-out, send a keep-alive "ping".
|
||||
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
|
||||
l.Debugf("error writing ping to websocket: %v", err)
|
||||
break writeLoop
|
||||
l.Debugf("error writing websocket ping: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
case !ok:
|
||||
// Stream was
|
||||
// closed.
|
||||
return
|
||||
}
|
||||
|
||||
l.Trace("writing websocket message: %+v", msg)
|
||||
|
||||
// Received a new message from the processor.
|
||||
if err := wsConn.WriteJSON(msg); err != nil {
|
||||
l.Debugf("error writing websocket message: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
l.Debug("finished writing to websocket connection")
|
||||
l.Debug("finished websocket write")
|
||||
}
|
||||
|
|
|
@ -18,38 +18,16 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Delete streams the delete of the given statusID to *ALL* open streams.
|
||||
func (p *Processor) Delete(statusID string) error {
|
||||
errs := []string{}
|
||||
|
||||
// get all account IDs with open streams
|
||||
accountIDs := []string{}
|
||||
p.streamMap.Range(func(k interface{}, _ interface{}) bool {
|
||||
key, ok := k.(string)
|
||||
if !ok {
|
||||
panic("streamMap key was not a string (account id)")
|
||||
}
|
||||
|
||||
accountIDs = append(accountIDs, key)
|
||||
return true
|
||||
func (p *Processor) Delete(ctx context.Context, statusID string) {
|
||||
p.streams.PostAll(ctx, stream.Message{
|
||||
Payload: statusID,
|
||||
Event: stream.EventTypeDelete,
|
||||
Stream: stream.AllStatusTimelines,
|
||||
})
|
||||
|
||||
// stream the delete to every account
|
||||
for _, accountID := range accountIDs {
|
||||
if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) != 0 {
|
||||
return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -18,20 +18,29 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Notify streams the given notification to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error {
|
||||
bytes, err := json.Marshal(n)
|
||||
func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
|
||||
b, err := json.Marshal(notif)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling notification to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeNotification,
|
||||
Stream: []string{
|
||||
stream.TimelineNotifications,
|
||||
stream.TimelineHome,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
|
@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
|
|||
Account: followAccountAPIModel,
|
||||
}
|
||||
|
||||
err = suite.streamProcessor.Notify(notification, account)
|
||||
suite.NoError(err)
|
||||
suite.streamProcessor.Notify(context.Background(), account, notification)
|
||||
|
||||
msg, ok := openStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
msg := <-openStream.Messages
|
||||
dst := new(bytes.Buffer)
|
||||
err = json.Indent(dst, []byte(msg.Payload), "", " ")
|
||||
suite.NoError(err)
|
||||
|
|
|
@ -19,13 +19,10 @@
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-kv"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
|
|||
{"streamType", streamType},
|
||||
}...)
|
||||
l.Debug("received open stream request")
|
||||
|
||||
var (
|
||||
streamID string
|
||||
err error
|
||||
)
|
||||
|
||||
// Each stream needs a unique ID so we know to close it.
|
||||
streamID, err = id.NewRandomULID()
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
|
||||
}
|
||||
|
||||
// Each stream can be subscibed to multiple types.
|
||||
// Record them in a set, and include the initial one
|
||||
// if it was given to us.
|
||||
streamTypes := map[string]any{}
|
||||
if streamType != "" {
|
||||
streamTypes[streamType] = true
|
||||
}
|
||||
|
||||
newStream := &stream.Stream{
|
||||
ID: streamID,
|
||||
StreamTypes: streamTypes,
|
||||
Messages: make(chan *stream.Message, 100),
|
||||
Hangup: make(chan interface{}, 1),
|
||||
Connected: true,
|
||||
}
|
||||
go p.waitToCloseStream(account, newStream)
|
||||
|
||||
v, ok := p.streamMap.Load(account.ID)
|
||||
if ok {
|
||||
// There is an entry in the streamMap
|
||||
// for this account. Parse it out.
|
||||
streamsForAccount, ok := v.(*stream.StreamsForAccount)
|
||||
if !ok {
|
||||
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
|
||||
}
|
||||
|
||||
// Append new stream to existing entry.
|
||||
streamsForAccount.Lock()
|
||||
streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
|
||||
streamsForAccount.Unlock()
|
||||
} else {
|
||||
// There is no entry in the streamMap for
|
||||
// this account yet. Create one and store it.
|
||||
p.streamMap.Store(account.ID, &stream.StreamsForAccount{
|
||||
Streams: []*stream.Stream{
|
||||
newStream,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return newStream, nil
|
||||
}
|
||||
|
||||
// waitToCloseStream waits until the hangup channel is closed for the given stream.
|
||||
// It then iterates through the map of streams stored by the processor, removes the stream from it,
|
||||
// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
|
||||
func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
|
||||
<-thisStream.Hangup // wait for a hangup message
|
||||
|
||||
// lock the stream to prevent more messages being put in it while we work
|
||||
thisStream.Lock()
|
||||
defer thisStream.Unlock()
|
||||
|
||||
// indicate the stream is no longer connected
|
||||
thisStream.Connected = false
|
||||
|
||||
// load and parse the entry for this account from the stream map
|
||||
v, ok := p.streamMap.Load(account.ID)
|
||||
if !ok || v == nil {
|
||||
return
|
||||
}
|
||||
streamsForAccount, ok := v.(*stream.StreamsForAccount)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// lock the streams for account while we remove this stream from its slice
|
||||
streamsForAccount.Lock()
|
||||
defer streamsForAccount.Unlock()
|
||||
|
||||
// put everything into modified streams *except* the stream we're removing
|
||||
modifiedStreams := []*stream.Stream{}
|
||||
for _, s := range streamsForAccount.Streams {
|
||||
if s.ID != thisStream.ID {
|
||||
modifiedStreams = append(modifiedStreams, s)
|
||||
}
|
||||
}
|
||||
streamsForAccount.Streams = modifiedStreams
|
||||
|
||||
// finally close the messages channel so no more messages can be read from it
|
||||
close(thisStream.Messages)
|
||||
return p.streams.Open(account.ID, streamType), nil
|
||||
}
|
||||
|
|
|
@ -18,21 +18,26 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// StatusUpdate streams the given edited status to any open, appropriate
|
||||
// streams belonging to the given account.
|
||||
func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
|
||||
bytes, err := json.Marshal(s)
|
||||
// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
|
||||
b, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling status to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeStatusUpdate,
|
||||
Stream: []string{streamType},
|
||||
})
|
||||
}
|
||||
|
|
|
@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
|
|||
apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
|
||||
suite.NoError(err)
|
||||
|
||||
err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome})
|
||||
suite.NoError(err)
|
||||
suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
|
||||
|
||||
msg, ok := openStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
msg := <-openStream.Messages
|
||||
dst := new(bytes.Buffer)
|
||||
err = json.Indent(dst, []byte(msg.Payload), "", " ")
|
||||
suite.NoError(err)
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
|
@ -28,53 +26,13 @@
|
|||
type Processor struct {
|
||||
state *state.State
|
||||
oauthServer oauth.Server
|
||||
streamMap *sync.Map
|
||||
streams stream.Streams
|
||||
}
|
||||
|
||||
func New(state *state.State, oauthServer oauth.Server) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
oauthServer: oauthServer,
|
||||
streamMap: &sync.Map{},
|
||||
streams: stream.Streams{},
|
||||
}
|
||||
}
|
||||
|
||||
// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
|
||||
func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
|
||||
// Load all streams open for this account.
|
||||
v, ok := p.streamMap.Load(accountID)
|
||||
if !ok {
|
||||
return nil // No entry = nothing to stream.
|
||||
}
|
||||
streamsForAccount := v.(*stream.StreamsForAccount)
|
||||
|
||||
streamsForAccount.Lock()
|
||||
defer streamsForAccount.Unlock()
|
||||
|
||||
for _, s := range streamsForAccount.Streams {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if !s.Connected {
|
||||
continue
|
||||
}
|
||||
|
||||
typeLoop:
|
||||
for _, streamType := range streamTypes {
|
||||
if _, found := s.StreamTypes[streamType]; found {
|
||||
s.Messages <- &stream.Message{
|
||||
Stream: []string{streamType},
|
||||
Event: string(event),
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
// Break out to the outer loop,
|
||||
// to avoid sending duplicates of
|
||||
// the same event to the same stream.
|
||||
break typeLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -18,20 +18,26 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Update streams the given update to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
|
||||
bytes, err := json.Marshal(s)
|
||||
func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
|
||||
b, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling status to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeUpdate,
|
||||
Stream: []string{streamType},
|
||||
})
|
||||
}
|
||||
|
|
|
@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(
|
|||
expectPayload string,
|
||||
expectEventType string,
|
||||
) {
|
||||
var msg *stream.Message
|
||||
streamLoop:
|
||||
for {
|
||||
select {
|
||||
case msg = <-str.Messages:
|
||||
break streamLoop // Got it.
|
||||
case <-time.After(5 * time.Second):
|
||||
break streamLoop // Didn't get it.
|
||||
}
|
||||
|
||||
// Set a 5s timeout on context.
|
||||
ctx := context.Background()
|
||||
ctx, cncl := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cncl()
|
||||
|
||||
msg, ok := str.Recv(ctx)
|
||||
|
||||
if expectMessage && !ok {
|
||||
suite.FailNow("expected a message but message was not received")
|
||||
}
|
||||
|
||||
if expectMessage && msg == nil {
|
||||
suite.FailNow("expected a message but message was nil")
|
||||
}
|
||||
|
||||
if !expectMessage && msg != nil {
|
||||
suite.FailNow("expected no message but message was not nil")
|
||||
if !expectMessage && ok {
|
||||
suite.FailNow("expected no message but message was received")
|
||||
}
|
||||
|
||||
if expectPayload != "" && msg.Payload != expectPayload {
|
||||
|
|
|
@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
|
|||
suite.Equal(replyingStatus.ID, notif.StatusID)
|
||||
suite.False(*notif.Read)
|
||||
|
||||
// the notification should be streamed
|
||||
var msg *stream.Message
|
||||
select {
|
||||
case msg = <-wssStream.Messages:
|
||||
// fine
|
||||
case <-time.After(5 * time.Second):
|
||||
suite.FailNow("no message from wssStream")
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
|
||||
msg, ok := wssStream.Recv(ctx)
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(stream.EventTypeNotification, msg.Event)
|
||||
suite.NotEmpty(msg.Payload)
|
||||
|
@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
|
|||
suite.Equal(fave.StatusID, notif.StatusID)
|
||||
suite.False(*notif.Read)
|
||||
|
||||
// 2. a notification should be streamed
|
||||
var msg *stream.Message
|
||||
select {
|
||||
case msg = <-wssStream.Messages:
|
||||
// fine
|
||||
case <-time.After(5 * time.Second):
|
||||
suite.FailNow("no message from wssStream")
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
|
||||
msg, ok := wssStream.Recv(ctx)
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(stream.EventTypeNotification, msg.Event)
|
||||
suite.NotEmpty(msg.Payload)
|
||||
suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream)
|
||||
|
@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
|
|||
suite.False(*notif.Read)
|
||||
|
||||
// 2. no notification should be streamed to the account that received the fave message, because they weren't the target
|
||||
suite.Empty(wssStream.Messages)
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
|
||||
_, ok := wssStream.Recv(ctx)
|
||||
suite.False(ok)
|
||||
}
|
||||
|
||||
func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
|
||||
|
@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
|
|||
})
|
||||
suite.NoError(err)
|
||||
|
||||
// a notification should be streamed
|
||||
var msg *stream.Message
|
||||
select {
|
||||
case msg = <-wssStream.Messages:
|
||||
// fine
|
||||
case <-time.After(5 * time.Second):
|
||||
suite.FailNow("no message from wssStream")
|
||||
}
|
||||
ctx, _ = context.WithTimeout(ctx, time.Second*5)
|
||||
msg, ok := wssStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(stream.EventTypeNotification, msg.Event)
|
||||
suite.NotEmpty(msg.Payload)
|
||||
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
|
||||
|
@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
|
|||
suite.Equal(originAccount.ID, notif.Account.ID)
|
||||
|
||||
// no messages should have been sent out, since we didn't need to federate an accept
|
||||
suite.Empty(suite.httpClient.SentMessages)
|
||||
suite.Empty(&suite.httpClient.SentMessages)
|
||||
}
|
||||
|
||||
func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
|
||||
|
@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
|
|||
suite.Equal(originAccount.URI, accept.To)
|
||||
suite.Equal("Accept", accept.Type)
|
||||
|
||||
// a notification should be streamed
|
||||
var msg *stream.Message
|
||||
select {
|
||||
case msg = <-wssStream.Messages:
|
||||
// fine
|
||||
case <-time.After(5 * time.Second):
|
||||
suite.FailNow("no message from wssStream")
|
||||
}
|
||||
ctx, _ = context.WithTimeout(ctx, time.Second*5)
|
||||
msg, ok := wssStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(stream.EventTypeNotification, msg.Event)
|
||||
suite.NotEmpty(msg.Payload)
|
||||
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
|
||||
|
|
|
@ -394,10 +394,7 @@ func (s *surface) notify(
|
|||
if err != nil {
|
||||
return gtserror.Newf("error converting notification to api representation: %w", err)
|
||||
}
|
||||
|
||||
if err := s.stream.Notify(apiNotif, targetAccount); err != nil {
|
||||
return gtserror.Newf("error streaming notification to account: %w", err)
|
||||
}
|
||||
s.stream.Notify(ctx, targetAccount, apiNotif)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -348,11 +348,7 @@ func (s *surface) timelineStatus(
|
|||
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil {
|
||||
err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
|
||||
return true, err
|
||||
}
|
||||
s.stream.Update(ctx, account, apiStatus, streamType)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string
|
|||
if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.stream.Delete(statusID)
|
||||
s.stream.Delete(ctx, statusID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// invalidateStatusFromTimelines does cache invalidation on the given status by
|
||||
|
@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(
|
|||
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil {
|
||||
err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.stream.StatusUpdate(ctx, account, apiStatus, streamType)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -17,36 +17,65 @@
|
|||
|
||||
package stream
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
// EventTypeNotification -- a user should be shown a notification
|
||||
EventTypeNotification string = "notification"
|
||||
// EventTypeUpdate -- a user should be shown an update in their timeline
|
||||
EventTypeUpdate string = "update"
|
||||
// EventTypeDelete -- something should be deleted from a user
|
||||
EventTypeDelete string = "delete"
|
||||
// EventTypeStatusUpdate -- something in the user's timeline has been edited
|
||||
// (yes this is a confusing name, blame Mastodon)
|
||||
EventTypeStatusUpdate string = "status.update"
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
// TimelineLocal -- public statuses from the LOCAL timeline.
|
||||
TimelineLocal string = "public:local"
|
||||
// TimelinePublic -- public statuses, including federated ones.
|
||||
TimelinePublic string = "public"
|
||||
// TimelineHome -- statuses for a user's Home timeline.
|
||||
TimelineHome string = "user"
|
||||
// TimelineNotifications -- notification events.
|
||||
TimelineNotifications string = "user:notification"
|
||||
// TimelineDirect -- statuses sent to a user directly.
|
||||
TimelineDirect string = "direct"
|
||||
// TimelineList -- statuses for a user's list timeline.
|
||||
TimelineList string = "list"
|
||||
// EventTypeNotification -- a user
|
||||
// should be shown a notification.
|
||||
EventTypeNotification = "notification"
|
||||
|
||||
// EventTypeUpdate -- a user should
|
||||
// be shown an update in their timeline.
|
||||
EventTypeUpdate = "update"
|
||||
|
||||
// EventTypeDelete -- something
|
||||
// should be deleted from a user.
|
||||
EventTypeDelete = "delete"
|
||||
|
||||
// EventTypeStatusUpdate -- something in the
|
||||
// user's timeline has been edited (yes this
|
||||
// is a confusing name, blame Mastodon ...).
|
||||
EventTypeStatusUpdate = "status.update"
|
||||
)
|
||||
|
||||
// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes.
|
||||
const (
|
||||
// TimelineLocal:
|
||||
// All public posts originating from this
|
||||
// server. Analogous to the local timeline.
|
||||
TimelineLocal = "public:local"
|
||||
|
||||
// TimelinePublic:
|
||||
// All public posts known to the server.
|
||||
// Analogous to the federated timeline.
|
||||
TimelinePublic = "public"
|
||||
|
||||
// TimelineHome:
|
||||
// Events related to the current user, such
|
||||
// as home feed updates and notifications.
|
||||
TimelineHome = "user"
|
||||
|
||||
// TimelineNotifications:
|
||||
// Notifications for the current user.
|
||||
TimelineNotifications = "user:notification"
|
||||
|
||||
// TimelineDirect:
|
||||
// Updates to direct conversations.
|
||||
TimelineDirect = "direct"
|
||||
|
||||
// TimelineList:
|
||||
// Updates to a specific list.
|
||||
TimelineList = "list"
|
||||
)
|
||||
|
||||
// AllStatusTimelines contains all Timelines
|
||||
// that a status could conceivably be delivered
|
||||
// to, useful for sending out status deletes.
|
||||
var AllStatusTimelines = []string{
|
||||
TimelineLocal,
|
||||
TimelinePublic,
|
||||
|
@ -55,38 +84,298 @@
|
|||
TimelineList,
|
||||
}
|
||||
|
||||
// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time.
|
||||
// TODO: put a limit on this
|
||||
type StreamsForAccount struct {
|
||||
// The currently held streams for this account
|
||||
Streams []*Stream
|
||||
// Mutex to lock/unlock when modifying the slice of streams.
|
||||
sync.Mutex
|
||||
type Streams struct {
|
||||
streams map[string][]*Stream
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Stream represents one open stream for a client.
|
||||
// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream.
|
||||
func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
|
||||
if len(streamTypes) == 0 {
|
||||
panic("no stream types given")
|
||||
}
|
||||
|
||||
// Prep new Stream.
|
||||
str := new(Stream)
|
||||
str.done = make(chan struct{})
|
||||
str.msgCh = make(chan Message, 50) // TODO: make configurable
|
||||
for _, streamType := range streamTypes {
|
||||
str.Subscribe(streamType)
|
||||
}
|
||||
|
||||
// TODO: add configurable
|
||||
// max streams per account.
|
||||
|
||||
// Acquire lock.
|
||||
s.mutex.Lock()
|
||||
|
||||
if s.streams == nil {
|
||||
// Main stream-map needs allocating.
|
||||
s.streams = make(map[string][]*Stream)
|
||||
}
|
||||
|
||||
// Add new stream for account.
|
||||
strs := s.streams[accountID]
|
||||
strs = append(strs, str)
|
||||
s.streams[accountID] = strs
|
||||
|
||||
// Register close callback
|
||||
// to remove stream from our
|
||||
// internal map for this account.
|
||||
str.close = func() {
|
||||
s.mutex.Lock()
|
||||
strs := s.streams[accountID]
|
||||
strs = slices.DeleteFunc(strs, func(s *Stream) bool {
|
||||
return s == str // remove 'str' ptr
|
||||
})
|
||||
s.streams[accountID] = strs
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Done with lock.
|
||||
s.mutex.Unlock()
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// Post will post the given message to all streams of given account ID matching type.
|
||||
func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
|
||||
var deferred []func() bool
|
||||
|
||||
// Acquire lock.
|
||||
s.mutex.Lock()
|
||||
|
||||
// Iterate all streams stored for account.
|
||||
for _, str := range s.streams[accountID] {
|
||||
|
||||
// Check whether stream supports any of our message targets.
|
||||
if stype := str.getStreamType(msg.Stream...); stype != "" {
|
||||
|
||||
// Rescope var
|
||||
// to prevent
|
||||
// ptr reuse.
|
||||
stream := str
|
||||
|
||||
// Use a message copy to *only*
|
||||
// include the supported stream.
|
||||
msgCopy := Message{
|
||||
Stream: []string{stype},
|
||||
Event: msg.Event,
|
||||
Payload: msg.Payload,
|
||||
}
|
||||
|
||||
// Send message to supported stream
|
||||
// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
|
||||
// This prevents deadlocks between each
|
||||
// msg channel and main Streams{} mutex.
|
||||
deferred = append(deferred, func() bool {
|
||||
return stream.send(ctx, msgCopy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Done with lock.
|
||||
s.mutex.Unlock()
|
||||
|
||||
var ok bool
|
||||
|
||||
// Execute deferred outside lock.
|
||||
for _, deferfn := range deferred {
|
||||
v := deferfn()
|
||||
ok = ok && v
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// PostAll will post the given message to all streams with matching types.
|
||||
func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
|
||||
var deferred []func() bool
|
||||
|
||||
// Acquire lock.
|
||||
s.mutex.Lock()
|
||||
|
||||
// Iterate ALL stored streams.
|
||||
for _, strs := range s.streams {
|
||||
for _, str := range strs {
|
||||
|
||||
// Check whether stream supports any of our message targets.
|
||||
if stype := str.getStreamType(msg.Stream...); stype != "" {
|
||||
|
||||
// Rescope var
|
||||
// to prevent
|
||||
// ptr reuse.
|
||||
stream := str
|
||||
|
||||
// Use a message copy to *only*
|
||||
// include the supported stream.
|
||||
msgCopy := Message{
|
||||
Stream: []string{stype},
|
||||
Event: msg.Event,
|
||||
Payload: msg.Payload,
|
||||
}
|
||||
|
||||
// Send message to supported stream
|
||||
// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
|
||||
// This prevents deadlocks between each
|
||||
// msg channel and main Streams{} mutex.
|
||||
deferred = append(deferred, func() bool {
|
||||
return stream.send(ctx, msgCopy)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Done with lock.
|
||||
s.mutex.Unlock()
|
||||
|
||||
var ok bool
|
||||
|
||||
// Execute deferred outside lock.
|
||||
for _, deferfn := range deferred {
|
||||
v := deferfn()
|
||||
ok = ok && v
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// Stream represents one
|
||||
// open stream for a client.
|
||||
type Stream struct {
|
||||
// ID of this stream, generated during creation.
|
||||
ID string
|
||||
// A set of types subscribed to by this stream: user/public/etc.
|
||||
// It's a map to ensure no duplicates; the value is ignored.
|
||||
StreamTypes map[string]any
|
||||
// Channel of messages for the client to read from
|
||||
Messages chan *Message
|
||||
// Channel to close when the client drops away
|
||||
Hangup chan interface{}
|
||||
// Only put messages in the stream when Connected
|
||||
Connected bool
|
||||
// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
|
||||
sync.Mutex
|
||||
|
||||
// atomically updated ptr to a read-only copy
|
||||
// of supported stream types in a hashmap. this
|
||||
// gets updated via CAS operations in .cas().
|
||||
types atomic.Pointer[map[string]struct{}]
|
||||
|
||||
// protects stream close.
|
||||
done chan struct{}
|
||||
|
||||
// inbound msg ch.
|
||||
msgCh chan Message
|
||||
|
||||
// close hook to remove
|
||||
// stream from Streams{}.
|
||||
close func()
|
||||
}
|
||||
|
||||
// Message represents one streamed message.
|
||||
// Subscribe will add given type to given types this stream supports.
|
||||
func (s *Stream) Subscribe(streamType string) {
|
||||
s.cas(func(m map[string]struct{}) bool {
|
||||
if _, ok := m[streamType]; ok {
|
||||
return false
|
||||
}
|
||||
m[streamType] = struct{}{}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Unsubscribe will remove given type (if found) from types this stream supports.
|
||||
func (s *Stream) Unsubscribe(streamType string) {
|
||||
s.cas(func(m map[string]struct{}) bool {
|
||||
if _, ok := m[streamType]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(m, streamType)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// getStreamType returns the first stream type in given list that stream supports.
|
||||
func (s *Stream) getStreamType(streamTypes ...string) string {
|
||||
if ptr := s.types.Load(); ptr != nil {
|
||||
for _, streamType := range streamTypes {
|
||||
if _, ok := (*ptr)[streamType]; ok {
|
||||
return streamType
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// send will block on posting a new Message{}, returning early with
|
||||
// a false value if provided context is canceled, or stream closed.
|
||||
func (s *Stream) send(ctx context.Context, msg Message) bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
return false
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case s.msgCh <- msg:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Recv will block on receiving Message{}, returning early with a
|
||||
// false value if provided context is canceled, or stream closed.
|
||||
func (s *Stream) Recv(ctx context.Context) (Message, bool) {
|
||||
select {
|
||||
case <-s.done:
|
||||
return Message{}, false
|
||||
case <-ctx.Done():
|
||||
return Message{}, false
|
||||
case msg := <-s.msgCh:
|
||||
return msg, true
|
||||
}
|
||||
}
|
||||
|
||||
// Close will close the underlying context, finally
|
||||
// removing it from the parent Streams per-account-map.
|
||||
func (s *Stream) Close() {
|
||||
select {
|
||||
case <-s.done:
|
||||
default:
|
||||
close(s.done)
|
||||
s.close()
|
||||
}
|
||||
}
|
||||
|
||||
// cas will perform a Compare And Swap operation on s.types using modifier func.
|
||||
func (s *Stream) cas(fn func(map[string]struct{}) bool) {
|
||||
if fn == nil {
|
||||
panic("nil function")
|
||||
}
|
||||
for {
|
||||
var m map[string]struct{}
|
||||
|
||||
// Get current value.
|
||||
ptr := s.types.Load()
|
||||
|
||||
if ptr == nil {
|
||||
// Allocate new types map.
|
||||
m = make(map[string]struct{})
|
||||
} else {
|
||||
// Clone r-only map.
|
||||
m = maps.Clone(*ptr)
|
||||
}
|
||||
|
||||
// Apply
|
||||
// changes.
|
||||
if !fn(m) {
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to Compare And Swap ptr.
|
||||
if s.types.CompareAndSwap(ptr, &m) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message represents
|
||||
// one streamed message.
|
||||
type Message struct {
|
||||
// All the stream types this message should be delivered to.
|
||||
|
||||
// All the stream types this
|
||||
// message should be delivered to.
|
||||
Stream []string `json:"stream"`
|
||||
// The event type of the message (update/delete/notification etc)
|
||||
|
||||
// The event type of the message
|
||||
// (update/delete/notification etc)
|
||||
Event string `json:"event"`
|
||||
// The actual payload of the message. In case of an update or notification, this will be a JSON string.
|
||||
|
||||
// The actual payload of the message. In case of an
|
||||
// update or notification, this will be a JSON string.
|
||||
Payload string `json:"payload"`
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue