diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 444157c1b..067f87392 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -20,14 +20,16 @@ import ( "context" - "errors" - "fmt" "time" "codeberg.org/gruf/go-kv" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" + "golang.org/x/exp/slices" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -134,32 +136,37 @@ // '400': // description: bad request func (m *Module) StreamGETHandler(c *gin.Context) { - streamType := c.Query(StreamQueryKey) - if streamType == "" { - err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - - var token string // First we check for a query param provided access token - if token = c.Query(AccessTokenQueryKey); token == "" { + token := c.Query(AccessTokenQueryKey) + if token == "" { // Else we check the HTTP header provided token - if token = c.GetHeader(AccessTokenHeader); token == "" { - const errStr = "no access token provided" - err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) + token = c.GetHeader(AccessTokenHeader) + } + + var account *gtsmodel.Account + if token != "" { + // Check the explicit token + var errWithCode gtserror.WithCode + account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + } else { + // If no explicit token was provided, try regular oauth + auth, errStr := oauth.Authed(c, true, true, true, true) + if errStr != nil { + err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) return } + account = auth.Account } - account, errWithCode := m.processor.Stream().Authorize(c.Request.Context(), token) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } - + // Get the initial stream type, if there is one. + // streamType will be an empty string if one wasn't supplied. Open() will deal with this + streamType := c.Query(StreamQueryKey) stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -219,8 +226,9 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // We have to listen for received websocket messages in // order to trigger the underlying wsConn.PingHandler(). // - // So we wait on received messages but only act on errors. - _, _, err := wsConn.ReadMessage() + // Read JSON objects from the client and act on them + var msg map[string]string + err := wsConn.ReadJSON(&msg) if err != nil { if ctx.Err() == nil { // Only log error if the connection was not closed @@ -229,6 +237,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } return } + l.Tracef("received message from websocket: %v", msg) + + // If the message contains 'stream' and 'type' fields, we can + // update the set of timelines that are subscribed for events. + // everything else is ignored. + action := msg["type"] + streamType := msg["stream"] + + // Ignore if the streamType is unknown (or missing), so a bad + // client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, streamType) { + l.Warnf("Unknown 'stream' field: %v", msg) + continue + } + + switch action { + case "subscribe": + stream.Lock() + stream.Timelines[streamType] = true + stream.Unlock() + case "unsubscribe": + stream.Lock() + delete(stream.Timelines, streamType) + stream.Unlock() + default: + l.Warnf("Invalid 'type' field: %v", msg) + } } }() diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go index 10d01a767..823efa182 100644 --- a/internal/processing/stream/open.go +++ b/internal/processing/stream/open.go @@ -45,9 +45,17 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) } + // Each stream can be subscibed to multiple timelines. + // Record them in a set, and include the initial one + // if it was given to us + timelines := map[string]bool{} + if streamTimeline != "" { + timelines[streamTimeline] = true + } + thisStream := &stream.Stream{ ID: streamID, - Timeline: streamTimeline, + Timelines: timelines, Messages: make(chan *stream.Message, 100), Hangup: make(chan interface{}, 1), Connected: true, diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index a10ab2474..f3a9e92f3 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -63,12 +63,15 @@ func (p *Processor) toAccount(payload string, event string, timelines []string, } for _, t := range timelines { - if s.Timeline == string(t) { + if _, found := s.Timelines[t]; found { s.Messages <- &stream.Message{ Stream: []string{string(t)}, Event: string(event), Payload: payload, } + // break out to the outer loop, to avoid sending duplicates + // of the same event to the same stream + break } } } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index ba27c213f..a23a5500a 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -63,8 +63,10 @@ type StreamsForAccount struct { type Stream struct { // ID of this stream, generated during creation. ID string - // Timeline of this stream: user/public/etc - Timeline string + // A set of timelines of this stream: user/public/etc + // a matching key means the timeline is subscribed. The value + // is ignored + Timelines map[string]bool // Channel of messages for the client to read from Messages chan *Message // Channel to close when the client drops away