diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 88c682a75..1f34e3447 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -149,60 +149,78 @@ // '400': // description: bad request func (m *Module) StreamGETHandler(c *gin.Context) { + var ( + account *gtsmodel.Account + errWithCode gtserror.WithCode + ) - // First we check for a query param provided access token + // Try query param access token. token := c.Query(AccessTokenQueryKey) if token == "" { - // Else we check the HTTP header provided token + // Try fallback HTTP header provided token. token = c.GetHeader(AccessTokenHeader) } - var account *gtsmodel.Account if token != "" { - // Check the explicit token - var errWithCode gtserror.WithCode + // Token was provided, use it to authorize stream. account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } } else { - // If no explicit token was provided, try regular oauth - auth, errStr := oauth.Authed(c, true, true, true, true) - if errStr != nil { - err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) - apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) - return - } - account = auth.Account + // No explicit token was provided: + // try regular oauth as a last resort. + account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + return nil, gtserror.NewErrorUnauthorized(err, err.Error()) + } + + return authed.Account, nil + }() } - // Get the initial stream type, if there is one. - // By appending other query params to the streamType, - // we can allow for streaming for specific list IDs - // or hashtags. + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + // Get the initial requested stream type, if there is one. streamType := c.Query(StreamQueryKey) + + // By appending other query params to the streamType, we + // can allow streaming for specific list IDs or hashtags. + // The streamType in this case will end up looking like + // `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`. if list := c.Query(StreamListKey); list != "" { streamType += ":" + list } else if tag := c.Query(StreamTagKey); tag != "" { streamType += ":" + tag } - stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) + // Open a stream with the processor; this lets processor + // functions pass messages into a channel, which we can + // then read from and put into a websockets connection. + stream, errWithCode := m.processor.Stream().Open( + c.Request.Context(), + account, + streamType, + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - l := log.WithContext(c.Request.Context()). + l := log. + WithContext(c.Request.Context()). WithFields(kv.Fields{ - {"account", account.Username}, + {"username", account.Username}, {"streamID", stream.ID}, - {"streamType", streamType}, }...) - // Upgrade the incoming HTTP request, which hijacks the underlying - // connection and reuses it for the websocket (non-http) protocol. + // Upgrade the incoming HTTP request. This hijacks the + // underlying connection and reuses it for the websocket + // (non-http) protocol. + // + // If the upgrade fails, then Upgrade replies to the client + // with an HTTP error response. wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) if err != nil { l.Errorf("error upgrading websocket connection: %v", err) @@ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) { return } + l.Info("opened websocket connection") + + // We perform the main websocket rw loops in a separate + // goroutine in order to let the upgrade handler return. + // This prevents the upgrade handler from holding open any + // throttle / rate-limit request tokens which could become + // problematic on instances with multiple users. + go m.handleWSConn(account.Username, wsConn, stream) +} + +// handleWSConn handles a two-way websocket streaming connection. +// It will both read messages from the connection, and push messages +// into the connection. If any errors are encountered while reading +// or writing (including expected errors like clients leaving), the +// connection will be closed. +func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { + // Create new context for the lifetime of this connection. + ctx, cancel := context.WithCancel(context.Background()) + + l := log. + WithContext(ctx). + WithFields(kv.Fields{ + {"username", username}, + {"streamID", stream.ID}, + }...) + + // Create ticker to send keepalive pings + pinger := time.NewTicker(m.dTicker) + + // Read messages coming from the Websocket client connection into the server. go func() { - // We perform the main websocket send loop in a separate - // goroutine in order to let the upgrade handler return. - // This prevents the upgrade handler from holding open any - // throttle / rate-limit request tokens which could become - // problematic on instances with multiple users. - l.Info("opened websocket connection") - defer l.Info("closed websocket connection") + defer cancel() + m.readFromWSConn(ctx, username, wsConn, stream) + }() - // Create new context for lifetime of the connection - ctx, cncl := context.WithCancel(context.Background()) + // Write messages coming from the processor into the Websocket client connection. + go func() { + defer cancel() + m.writeToWSConn(ctx, username, wsConn, stream, pinger) + }() - // Create ticker to send alive pings - pinger := time.NewTicker(m.dTicker) + // Wait for either the read or write functions to close, to indicate + // that the client has left, or something else has gone wrong. + <-ctx.Done() - defer func() { - // Signal done - cncl() + // Tidy up underlying websocket connection. + if err := wsConn.Close(); err != nil { + l.Errorf("error closing websocket connection: %v", err) + } - // Close websocket conn - _ = wsConn.Close() + // Close processor channel so the processor knows + // not to send any more messages to this stream. + close(stream.Hangup) - // Close processor stream - close(stream.Hangup) + // Stop ping ticker (tiny resource saving). + pinger.Stop() - // Stop ping ticker - pinger.Stop() - }() + l.Info("closed websocket connection") +} - go func() { - // Signal done - defer cncl() +// readFromWSConn reads control messages coming in from the given +// websockets connection, and modifies the subscription StreamTypes +// of the given stream accordingly after acquiring a lock on it. +// +// This is a blocking function; will return only on read error or +// if the given context is canceled. +func (m *Module) readFromWSConn( + ctx context.Context, + username string, + wsConn *websocket.Conn, + stream *streampkg.Stream, +) { + l := log. + WithContext(ctx). + WithFields(kv.Fields{ + {"username", username}, + {"streamID", stream.ID}, + }...) - for { - // We have to listen for received websocket messages in - // order to trigger the underlying wsConn.PingHandler(). - // - // Read JSON objects from the client and act on them - var msg map[string]string - err := wsConn.ReadJSON(&msg) - if err != nil { - if ctx.Err() == nil { - // Only log error if the connection was not closed - // by us. Uncanceled context indicates this is the case. - l.Errorf("error reading from websocket: %v", err) - } - return - } - l.Tracef("received message from websocket: %v", msg) +readLoop: + for { + select { + case <-ctx.Done(): + // Connection closed. + break readLoop - // If the message contains 'stream' and 'type' fields, we can - // update the set of timelines that are subscribed for events. - updateType, ok := msg["type"] - if !ok { - l.Warn("'type' field not provided") - continue + default: + // Read JSON objects from the client and act on them. + var msg map[string]string + if err := wsConn.ReadJSON(&msg); err != nil { + // Only log an error if something weird happened. + // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 + if websocket.IsUnexpectedCloseError(err, []int{ + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + }...) { + l.Errorf("error reading from websocket: %v", err) } - updateStream, ok := msg["stream"] - if !ok { - l.Warn("'stream' field not provided") - continue - } - - // Ignore if the updateStreamType is unknown (or missing), - // so a bad client can't cause extra memory allocations - if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { - l.Warnf("unknown 'stream' field: %v", msg) - continue - } - - updateList, ok := msg["list"] - if ok { - updateStream += ":" + updateList - } - - switch updateType { - case "subscribe": - stream.Lock() - stream.StreamTypes[updateStream] = true - stream.Unlock() - case "unsubscribe": - stream.Lock() - delete(stream.StreamTypes, updateStream) - stream.Unlock() - default: - l.Warnf("invalid 'type' field: %v", msg) - } + // The connection is gone; no + // further streaming possible. + break readLoop } - }() - for { - select { - // Connection closed - case <-ctx.Done(): - return + // Messages *from* the WS connection are infrequent + // and usually interesting, so log this at info. + l.Infof("received message from websocket: %v", msg) - // Received next stream message - case msg := <-stream.Messages: - l.Tracef("sending message to websocket: %+v", msg) - if err := wsConn.WriteJSON(msg); err != nil { - l.Debugf("error writing json to websocket: %v", err) - return - } + // If the message contains 'stream' and 'type' fields, we can + // update the set of timelines that are subscribed for events. + updateType, ok := msg["type"] + if !ok { + l.Warn("'type' field not provided") + continue + } - // Reset on each successful send. - pinger.Reset(m.dTicker) + updateStream, ok := msg["stream"] + if !ok { + l.Warn("'stream' field not provided") + continue + } - // Send keep-alive "ping" - case <-pinger.C: - l.Trace("pinging websocket ...") - if err := wsConn.WriteMessage( - websocket.PingMessage, - []byte{}, - ); err != nil { - l.Debugf("error writing ping to websocket: %v", err) - return - } + // Ignore if the updateStreamType is unknown (or missing), + // so a bad client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { + l.Warnf("unknown 'stream' field: %v", msg) + continue + } + + updateList, ok := msg["list"] + if ok { + updateStream += ":" + updateList + } + + switch updateType { + case "subscribe": + stream.Lock() + stream.StreamTypes[updateStream] = true + stream.Unlock() + case "unsubscribe": + stream.Lock() + delete(stream.StreamTypes, updateStream) + stream.Unlock() + default: + l.Warnf("invalid 'type' field: %v", msg) } } - }() + } + + l.Debug("finished reading from websocket connection") +} + +// writeToWSConn receives messages coming from the processor via the +// given stream, and writes them into the given websockets connection. +// This function also handles sending ping messages into the websockets +// connection to keep it alive when no other activity occurs. +// +// This is a blocking function; will return only on write error or +// if the given context is canceled. +func (m *Module) writeToWSConn( + ctx context.Context, + username string, + wsConn *websocket.Conn, + stream *streampkg.Stream, + pinger *time.Ticker, +) { + l := log. + WithContext(ctx). + WithFields(kv.Fields{ + {"username", username}, + {"streamID", stream.ID}, + }...) + +writeLoop: + for { + select { + case <-ctx.Done(): + // Connection closed. + break writeLoop + + case msg := <-stream.Messages: + // Received a new message from the processor. + l.Tracef("writing message to websocket: %+v", msg) + if err := wsConn.WriteJSON(msg); err != nil { + l.Debugf("error writing json to websocket: %v", err) + break writeLoop + } + + // Reset pinger on successful send, since + // we know the connection is still there. + pinger.Reset(m.dTicker) + + case <-pinger.C: + // Time to send a keep-alive "ping". + l.Trace("writing ping control message to websocket") + if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { + l.Debugf("error writing ping to websocket: %v", err) + break writeLoop + } + } + } + + l.Debug("finished writing to websocket connection") } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index edddeab73..303e16cd3 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -42,15 +42,18 @@ type Module struct { } func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module { + // We expect CORS requests for websockets, + // (via eg., semaphore.social) so be lenient. + // TODO: make this customizable? + checkOrigin := func(r *http.Request) bool { return true } + return &Module{ processor: processor, dTicker: dTicker, wsUpgrade: websocket.Upgrader{ - ReadBufferSize: wsBuf, // we don't expect reads + ReadBufferSize: wsBuf, WriteBufferSize: wsBuf, - - // we expect cors requests (via eg., semaphore.social) so be lenient - CheckOrigin: func(r *http.Request) bool { return true }, + CheckOrigin: checkOrigin, }, } }