diff --git a/internal/api/util/negotiate.go b/internal/api/util/negotiate.go index 8e7f41134..6d68a0df3 100644 --- a/internal/api/util/negotiate.go +++ b/internal/api/util/negotiate.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" + "strings" "github.com/gin-gonic/gin" ) @@ -108,10 +109,63 @@ func NegotiateAccept(c *gin.Context, offers ...MIME) (string, error) { return strings[0], nil } - format := c.NegotiateFormat(strings...) + format := NegotiateFormat(c, strings...) if format == "" { return "", fmt.Errorf("no format can be offered for requested Accept header(s) %s; this endpoint offers %s", accepts, offers) } return format, nil } + +// This is the exact same thing as gin.Context.NegotiateFormat except it contains +// tsmethurst's fix to make it work properly with multiple accept headers. +// +// https://github.com/gin-gonic/gin/pull/3156 +func NegotiateFormat(c *gin.Context, offered ...string) string { + if len(offered) == 0 { + panic("you must provide at least one offer") + } + + if c.Accepted == nil { + for _, a := range c.Request.Header.Values("Accept") { + c.Accepted = append(c.Accepted, parseAccept(a)...) + } + } + if len(c.Accepted) == 0 { + return offered[0] + } + for _, accepted := range c.Accepted { + for _, offer := range offered { + // According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers, + // therefore we can just iterate over the string without casting it into []rune + i := 0 + for ; i < len(accepted); i++ { + if accepted[i] == '*' || offer[i] == '*' { + return offer + } + if accepted[i] != offer[i] { + break + } + } + if i == len(accepted) { + return offer + } + } + } + return "" +} + +// https://github.com/gin-gonic/gin/blob/4787b8203b79012877ac98d7806422da3a678ba2/utils.go#L103 +func parseAccept(acceptHeader string) []string { + parts := strings.Split(acceptHeader, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if i := strings.IndexByte(part, ';'); i > 0 { + part = part[:i] + } + if part = strings.TrimSpace(part); part != "" { + out = append(out, part) + } + } + return out +} diff --git a/internal/api/util/negotiate_test.go b/internal/api/util/negotiate_test.go new file mode 100644 index 000000000..a8b28b55f --- /dev/null +++ b/internal/api/util/negotiate_test.go @@ -0,0 +1,65 @@ +package util + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +type testMIMES []MIME + +func (tm testMIMES) String(t *testing.T) string { + t.Helper() + + res := tm.StringS(t) + return strings.Join(res, ",") +} + +func (tm testMIMES) StringS(t *testing.T) []string { + t.Helper() + + res := make([]string, 0, len(tm)) + for _, m := range tm { + res = append(res, string(m)) + } + return res +} + +func TestNegotiateFormat(t *testing.T) { + tests := []struct { + incoming []string + offered testMIMES + format string + }{ + {incoming: testMIMES{AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/json"}, + {incoming: testMIMES{AppJRDJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/jrd+json"}, + {incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON}, format: "application/jrd+json"}, + {incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJSON}, format: "application/json"}, + {incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{AppJSON, AppXML}, format: "application/xml"}, + {incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{TextHTML, AppXML}, format: "text/html"}, + } + + for _, tt := range tests { + name := "incoming:" + strings.Join(tt.incoming, ",") + " offered:" + tt.offered.String(t) + t.Run(name, func(t *testing.T) { + tt := tt + t.Parallel() + + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = &http.Request{ + Header: make(http.Header), + } + for _, header := range tt.incoming { + c.Request.Header.Add("accept", header) + } + + format := NegotiateFormat(c, tt.offered.StringS(t)...) + if tt.format != format { + t.Fatalf("expected format: '%s', got format: '%s'", tt.format, format) + } + }) + } +} diff --git a/internal/web/profile.go b/internal/web/profile.go index a4fddbafe..56f8e0a56 100644 --- a/internal/web/profile.go +++ b/internal/web/profile.go @@ -73,7 +73,7 @@ func (m *Module) profileGETHandler(c *gin.Context) { // if we're getting an AP request on this endpoint we // should render the account's AP representation instead - accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) + accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) { m.returnAPProfile(ctx, c, username, accept) return diff --git a/internal/web/thread.go b/internal/web/thread.go index fe57ddf1f..8d4e99bef 100644 --- a/internal/web/thread.go +++ b/internal/web/thread.go @@ -90,7 +90,7 @@ func (m *Module) threadGETHandler(c *gin.Context) { // if we're getting an AP request on this endpoint we // should render the status's AP representation instead - accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) + accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) { m.returnAPStatus(ctx, c, username, statusID, accept) return