diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go
index bc56e21f0..5f45ecd3f 100644
--- a/cmd/gotosocial/action/server/server.go
+++ b/cmd/gotosocial/action/server/server.go
@@ -110,7 +110,7 @@
oauthServer := oauth.New(ctx, dbService)
typeConverter := typeutils.NewConverter(dbService)
federatingDB := federatingdb.New(&state, typeConverter)
- transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
+ transportController := transport.NewController(&state, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
// decide whether to create a noop email sender (won't send emails) or a real one
diff --git a/example/config.yaml b/example/config.yaml
index bdd3c4cc2..e662159f6 100644
--- a/example/config.yaml
+++ b/example/config.yaml
@@ -287,6 +287,10 @@ cache:
user-ttl: "5m"
user-sweep-freq: "30s"
+ webfinger-max-size": 250
+ webfinger-ttl: "24h"
+ webfinger-sweep-freq": "15m"
+
######################
##### WEB CONFIG #####
######################
diff --git a/internal/api/model/well-known.go b/internal/api/model/well-known.go
index bf61a6085..f3481ad72 100644
--- a/internal/api/model/well-known.go
+++ b/internal/api/model/well-known.go
@@ -18,6 +18,8 @@
package model
+import "encoding/xml"
+
// WellKnownResponse represents the response to either a webfinger request for an 'acct' resource, or a request to nodeinfo.
// For example, it would be returned from https://example.org/.well-known/webfinger?resource=acct:some_username@example.org
//
@@ -32,12 +34,12 @@ type WellKnownResponse struct {
// Link represents one 'link' in a slice of links returned from a lookup request.
//
-// See https://webfinger.net/
+// See https://webfinger.net/ and https://www.rfc-editor.org/rfc/rfc6415.html#section-3.1
type Link struct {
- Rel string `json:"rel"`
- Type string `json:"type,omitempty"`
- Href string `json:"href,omitempty"`
- Template string `json:"template,omitempty"`
+ Rel string `json:"rel" xml:"rel,attr"`
+ Type string `json:"type,omitempty" xml:"type,attr,omitempty"`
+ Href string `json:"href,omitempty" xml:"href,attr,omitempty"`
+ Template string `json:"template,omitempty" xml:"template,attr,omitempty"`
}
// Nodeinfo represents a version 2.1 or version 2.0 nodeinfo schema.
@@ -87,3 +89,13 @@ type NodeInfoUsage struct {
type NodeInfoUsers struct {
Total int `json:"total"`
}
+
+// HostMeta represents a hostmeta document.
+// See: https://www.rfc-editor.org/rfc/rfc6415.html#section-3
+//
+// swagger:model hostmeta
+type HostMeta struct {
+ XMLName xml.Name `xml:"XRD"`
+ XMLNS string `xml:"xmlns,attr"`
+ Link []Link `xml:"Link"`
+}
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
index 253dc47b2..568ffb478 100644
--- a/internal/cache/gts.go
+++ b/internal/cache/gts.go
@@ -20,6 +20,7 @@
import (
"codeberg.org/gruf/go-cache/v3/result"
+ "codeberg.org/gruf/go-cache/v3/ttl"
"github.com/superseriousbusiness/gotosocial/internal/cache/domain"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -71,6 +72,9 @@ type GTSCaches interface {
// User provides access to the gtsmodel User database cache.
User() *result.Cache[*gtsmodel.User]
+
+ // Webfinger
+ Webfinger() *ttl.Cache[string, string]
}
// NewGTS returns a new default implementation of GTSCaches.
@@ -91,6 +95,7 @@ type gtsCaches struct {
status *result.Cache[*gtsmodel.Status]
tombstone *result.Cache[*gtsmodel.Tombstone]
user *result.Cache[*gtsmodel.User]
+ webfinger *ttl.Cache[string, string]
}
func (c *gtsCaches) Init() {
@@ -106,6 +111,7 @@ func (c *gtsCaches) Init() {
c.initStatus()
c.initTombstone()
c.initUser()
+ c.initWebfinger()
}
func (c *gtsCaches) Start() {
@@ -145,6 +151,9 @@ func (c *gtsCaches) Start() {
tryUntil("starting gtsmodel.User cache", 5, func() bool {
return c.user.Start(config.GetCacheGTSUserSweepFreq())
})
+ tryUntil("starting gtsmodel.Webfinger cache", 5, func() bool {
+ return c.webfinger.Start(config.GetCacheGTSWebfingerSweepFreq())
+ })
}
func (c *gtsCaches) Stop() {
@@ -160,6 +169,7 @@ func (c *gtsCaches) Stop() {
tryUntil("stopping gtsmodel.Status cache", 5, c.status.Stop)
tryUntil("stopping gtsmodel.Tombstone cache", 5, c.tombstone.Stop)
tryUntil("stopping gtsmodel.User cache", 5, c.user.Stop)
+ tryUntil("stopping gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
}
func (c *gtsCaches) Account() *result.Cache[*gtsmodel.Account] {
@@ -210,6 +220,10 @@ func (c *gtsCaches) User() *result.Cache[*gtsmodel.User] {
return c.user
}
+func (c *gtsCaches) Webfinger() *ttl.Cache[string, string] {
+ return c.webfinger
+}
+
func (c *gtsCaches) initAccount() {
c.account = result.New([]result.Lookup{
{Name: "ID"},
@@ -355,3 +369,10 @@ func (c *gtsCaches) initUser() {
}, config.GetCacheGTSUserMaxSize())
c.user.SetTTL(config.GetCacheGTSUserTTL(), true)
}
+
+func (c *gtsCaches) initWebfinger() {
+ c.webfinger = ttl.New[string, string](
+ 0,
+ config.GetCacheGTSWebfingerMaxSize(),
+ config.GetCacheGTSWebfingerTTL())
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index a7a36eebf..f7a59d760 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -207,6 +207,10 @@ type GTSCacheConfiguration struct {
UserMaxSize int `name:"user-max-size"`
UserTTL time.Duration `name:"user-ttl"`
UserSweepFreq time.Duration `name:"user-sweep-freq"`
+
+ WebfingerMaxSize int `name:"webfinger-max-size"`
+ WebfingerTTL time.Duration `name:"webfinger-ttl"`
+ WebfingerSweepFreq time.Duration `name:"webfinger-sweep-freq"`
}
// MarshalMap will marshal current Configuration into a map structure (useful for JSON/TOML/YAML).
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index 7d2427ee7..418858827 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -166,6 +166,10 @@
UserMaxSize: 100,
UserTTL: time.Minute * 5,
UserSweepFreq: time.Second * 30,
+
+ WebfingerMaxSize: 250,
+ WebfingerTTL: time.Hour * 24,
+ WebfingerSweepFreq: time.Minute * 15,
},
},
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index b021ed617..14fa72b24 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -3003,6 +3003,81 @@ func GetCacheGTSUserSweepFreq() time.Duration { return global.GetCacheGTSUserSwe
// SetCacheGTSUserSweepFreq safely sets the value for global configuration 'Cache.GTS.UserSweepFreq' field
func SetCacheGTSUserSweepFreq(v time.Duration) { global.SetCacheGTSUserSweepFreq(v) }
+// GetCacheGTSWebfingerMaxSize safely fetches the Configuration value for state's 'Cache.GTS.WebfingerMaxSize' field
+func (st *ConfigState) GetCacheGTSWebfingerMaxSize() (v int) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.WebfingerMaxSize
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSWebfingerMaxSize safely sets the Configuration value for state's 'Cache.GTS.WebfingerMaxSize' field
+func (st *ConfigState) SetCacheGTSWebfingerMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.WebfingerMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSWebfingerMaxSizeFlag returns the flag name for the 'Cache.GTS.WebfingerMaxSize' field
+func CacheGTSWebfingerMaxSizeFlag() string { return "cache-gts-webfinger-max-size" }
+
+// GetCacheGTSWebfingerMaxSize safely fetches the value for global configuration 'Cache.GTS.WebfingerMaxSize' field
+func GetCacheGTSWebfingerMaxSize() int { return global.GetCacheGTSWebfingerMaxSize() }
+
+// SetCacheGTSWebfingerMaxSize safely sets the value for global configuration 'Cache.GTS.WebfingerMaxSize' field
+func SetCacheGTSWebfingerMaxSize(v int) { global.SetCacheGTSWebfingerMaxSize(v) }
+
+// GetCacheGTSWebfingerTTL safely fetches the Configuration value for state's 'Cache.GTS.WebfingerTTL' field
+func (st *ConfigState) GetCacheGTSWebfingerTTL() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.WebfingerTTL
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSWebfingerTTL safely sets the Configuration value for state's 'Cache.GTS.WebfingerTTL' field
+func (st *ConfigState) SetCacheGTSWebfingerTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.WebfingerTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSWebfingerTTLFlag returns the flag name for the 'Cache.GTS.WebfingerTTL' field
+func CacheGTSWebfingerTTLFlag() string { return "cache-gts-webfinger-ttl" }
+
+// GetCacheGTSWebfingerTTL safely fetches the value for global configuration 'Cache.GTS.WebfingerTTL' field
+func GetCacheGTSWebfingerTTL() time.Duration { return global.GetCacheGTSWebfingerTTL() }
+
+// SetCacheGTSWebfingerTTL safely sets the value for global configuration 'Cache.GTS.WebfingerTTL' field
+func SetCacheGTSWebfingerTTL(v time.Duration) { global.SetCacheGTSWebfingerTTL(v) }
+
+// GetCacheGTSWebfingerSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.WebfingerSweepFreq' field
+func (st *ConfigState) GetCacheGTSWebfingerSweepFreq() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.WebfingerSweepFreq
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSWebfingerSweepFreq safely sets the Configuration value for state's 'Cache.GTS.WebfingerSweepFreq' field
+func (st *ConfigState) SetCacheGTSWebfingerSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.WebfingerSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSWebfingerSweepFreqFlag returns the flag name for the 'Cache.GTS.WebfingerSweepFreq' field
+func CacheGTSWebfingerSweepFreqFlag() string { return "cache-gts-webfinger-sweep-freq" }
+
+// GetCacheGTSWebfingerSweepFreq safely fetches the value for global configuration 'Cache.GTS.WebfingerSweepFreq' field
+func GetCacheGTSWebfingerSweepFreq() time.Duration { return global.GetCacheGTSWebfingerSweepFreq() }
+
+// SetCacheGTSWebfingerSweepFreq safely sets the value for global configuration 'Cache.GTS.WebfingerSweepFreq' field
+func SetCacheGTSWebfingerSweepFreq(v time.Duration) { global.SetCacheGTSWebfingerSweepFreq(v) }
+
// GetAdminAccountUsername safely fetches the Configuration value for state's 'AdminAccountUsername' field
func (st *ConfigState) GetAdminAccountUsername() (v string) {
st.mutex.Lock()
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index abcccfe1e..d23ae0b68 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -32,9 +32,9 @@
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
// Controller generates transports for use in making federation requests to other servers.
@@ -47,7 +47,7 @@ type Controller interface {
}
type controller struct {
- db db.DB
+ state *state.State
fedDB federatingdb.DB
clock pub.Clock
client pub.HttpClient
@@ -57,14 +57,14 @@ type controller struct {
}
// NewController returns an implementation of the Controller interface for creating new transports
-func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
+func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
applicationName := config.GetApplicationName()
host := config.GetHost()
proto := config.GetProtocol()
version := config.GetSoftwareVersion()
c := &controller{
- db: db,
+ state: state,
fedDB: federatingDB,
clock: clock,
client: client,
@@ -138,7 +138,7 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
u = username
}
- ourAccount, err := c.db.GetAccountByUsernameDomain(ctx, u, "")
+ ourAccount, err := c.state.DB.GetAccountByUsernameDomain(ctx, u, "")
if err != nil {
return nil, fmt.Errorf("error getting account %s from db: %s", username, err)
}
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index 4e6594df4..6631ff8f1 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -20,29 +20,61 @@
import (
"context"
+ "encoding/xml"
"fmt"
"io"
"net/http"
+ "net/url"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
-func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
- // Prepare URL string
- urlStr := "https://" +
- targetDomain +
- "/.well-known/webfinger?resource=acct:" +
- targetUsername + "@" + targetDomain
+// webfingerURLFor returns the URL to try a webfinger request against, as
+// well as if the URL was retrieved from cache. When the URL is retrieved
+// from cache we don't have to try and do host-meta discovery
+func (t *transport) webfingerURLFor(targetDomain string) (string, bool) {
+ url := "https://" + targetDomain + "/.well-known/webfinger"
- // Generate new GET request from URL string
- req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
+ wc := t.controller.state.Caches.GTS.Webfinger()
+ // We're doing the manual locking/unlocking here to be able to
+ // safely call Cache.Get instead of Get, as the latter updates the
+ // item expiry which we don't want to do here
+ wc.Lock()
+ item, ok := wc.Cache.Get(targetDomain)
+ wc.Unlock()
+
+ if ok {
+ url = item.Value
+ }
+
+ return url, ok
+}
+
+func prepWebfingerReq(ctx context.Context, loc, domain, username string) (*http.Request, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, loc, nil)
if err != nil {
return nil, err
}
+
+ value := url.QueryEscape("acct:" + username + "@" + domain)
+ req.URL.RawQuery = "resource=" + value
+
req.Header.Add("Accept", string(apiutil.AppJSON))
req.Header.Add("Accept", "application/jrd+json")
req.Header.Set("Host", req.URL.Host)
+ return req, nil
+}
+
+func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
+ // Generate new GET request
+ url, cached := t.webfingerURLFor(targetDomain)
+ req, err := prepWebfingerReq(ctx, url, targetDomain, targetUsername)
+ if err != nil {
+ return nil, err
+ }
+
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
@@ -50,10 +82,117 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
}
defer rsp.Body.Close()
- // Check for an expected status code
- if rsp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed: %s", urlStr, rsp.Status)
+ // Check if the request succeeded so we can bail out early
+ if rsp.StatusCode == http.StatusOK {
+ if cached {
+ // If we got a success on a cached URL, i.e one set by us later on when
+ // a host-meta based webfinger request succeeded, set it again here to
+ // renew the TTL
+ t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, url)
+ }
+ return io.ReadAll(rsp.Body)
}
+ // From here on out, we're handling different failure scenarios and
+ // deciding whether we should do a host-meta based fallback or not
+
+ if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached {
+ // In case we got a 5xx, bail out irrespective of if the value
+ // was cached or not. The target may be broken or be signalling
+ // us to back-off.
+ //
+ // If it's any error but the URL was cached, bail out too
+ return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status)
+ }
+
+ // So far we've failed to get a successful response from the expected
+ // webfinger endpoint. Lets try and discover the webfinger endpoint
+ // through /.well-known/host-meta
+ host, err := t.webfingerFromHostMeta(ctx, targetDomain)
+ if err != nil {
+ return nil, fmt.Errorf("failed to discover webfinger URL fallback for: %s through host-meta: %w", targetDomain, err)
+ }
+
+ // Check if the original and host-meta URL are the same. If they
+ // are there's no sense in us trying the request again as it just
+ // failed
+ if host == url {
+ return nil, fmt.Errorf("webfinger discovery on %s returned endpoint we already tried: %s", targetDomain, host)
+ }
+
+ // Now that we have a different URL for the webfinger
+ // endpoint, try the request against that endpoint instead
+ req, err = prepWebfingerReq(ctx, host, targetDomain, targetUsername)
+ if err != nil {
+ return nil, err
+ }
+
+ // Perform the HTTP request
+ rsp, err = t.GET(req)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+
+ if rsp.StatusCode != http.StatusOK {
+ // We've reached the end of the line here, both the original request
+ // and our attempt to resolve it through the fallback have failed
+ return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status)
+ }
+
+ // Set the URL in cache here, since host-meta told us this should be the
+ // valid one, it's different from the default and our request to it did
+ // not fail in any manner
+ t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host)
+
return io.ReadAll(rsp.Body)
}
+
+func (t *transport) webfingerFromHostMeta(ctx context.Context, targetDomain string) (string, error) {
+ // Build the request for the host-meta endpoint
+ hmurl := "https://" + targetDomain + "/.well-known/host-meta"
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, hmurl, nil)
+ if err != nil {
+ return "", err
+ }
+
+ // We're doing XML
+ req.Header.Add("Accept", string(apiutil.AppXML))
+ req.Header.Add("Accept", "application/xrd+xml")
+ req.Header.Set("Host", req.URL.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
+ if err != nil {
+ return "", err
+ }
+ defer rsp.Body.Close()
+
+ // Doesn't look like host-meta is working for this instance
+ if rsp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("GET request for %s failed: %s", req.URL.String(), rsp.Status)
+ }
+
+ e := xml.NewDecoder(rsp.Body)
+ var hm apimodel.HostMeta
+ if err := e.Decode(&hm); err != nil {
+ // We got something, but it's not a host-meta document we understand
+ return "", fmt.Errorf("failed to decode host-meta response for %s at %s: %w", targetDomain, req.URL.String(), err)
+ }
+
+ for _, link := range hm.Link {
+ // Based on what we currently understand, there should not be more than one
+ // of these with Rel="lrdd" in a host-meta document
+ if link.Rel == "lrdd" {
+ u, err := url.Parse(link.Template)
+ if err != nil {
+ return "", fmt.Errorf("lrdd link is not a valid url: %w", err)
+ }
+ // Get rid of the query template, we only want the scheme://host/path part
+ u.RawQuery = ""
+ urlStr := u.String()
+ return urlStr, nil
+ }
+ }
+ return "", fmt.Errorf("no webfinger URL found")
+}
diff --git a/internal/transport/finger_test.go b/internal/transport/finger_test.go
new file mode 100644
index 000000000..d207785dc
--- /dev/null
+++ b/internal/transport/finger_test.go
@@ -0,0 +1,118 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package transport_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+)
+
+type FingerTestSuite struct {
+ TransportTestSuite
+}
+
+func (suite *FingerTestSuite) TestFinger() {
+ wc := suite.state.Caches.GTS.Webfinger()
+ suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
+
+ _, err := suite.transport.Finger(context.TODO(), "brand_new_person", "unknown-instance.com")
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Equal(0, wc.Len(), "expect webfinger cache to be empty for normal webfinger request")
+}
+
+func (suite *FingerTestSuite) TestFingerWithHostMeta() {
+ wc := suite.state.Caches.GTS.Webfinger()
+ suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
+
+ _, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com")
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Equal(1, wc.Len(), "expect webfinger cache to hold one entry")
+ suite.True(wc.Has("misconfigured-instance.com"), "expect webfinger cache to have entry for misconfigured-instance.com")
+}
+
+func (suite *FingerTestSuite) TestFingerWithHostMetaCacheStrategy() {
+ wc := suite.state.Caches.GTS.Webfinger()
+ suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
+
+ _, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com")
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Equal(1, wc.Len(), "expect webfinger cache to hold one entry")
+ wc.Lock()
+ suite.True(wc.Cache.Has("misconfigured-instance.com"), "expect webfinger cache to have entry for misconfigured-instance.com")
+ ent, _ := wc.Cache.Get("misconfigured-instance.com")
+ wc.Unlock()
+
+ initialTime := ent.Expiry
+
+ // finger them again
+ _, err = suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com")
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // there should still only be 1 cache entry
+ suite.Equal(1, wc.Len(), "expect webfinger cache to hold one entry")
+ wc.Lock()
+ suite.True(wc.Cache.Has("misconfigured-instance.com"), "expect webfinger cache to have entry for misconfigured-instance.com")
+ rep, _ := wc.Cache.Get("misconfigured-instance.com")
+ wc.Unlock()
+
+ repeatTime := rep.Expiry
+
+ // the TTL of the entry should have extended because we did a second
+ // successful finger
+ suite.NotEqual(initialTime, repeatTime, "expected webfinger cache entry to have different expiry times")
+ if repeatTime.Before(initialTime) {
+ suite.FailNow("expected webfinger cache entry to not be a time traveller")
+ }
+
+ // finger a non-existing user on that same instance which will return an error
+ _, err = suite.transport.Finger(context.TODO(), "invalid", "misconfigured-instance.com")
+ if err == nil {
+ suite.FailNow("expected request for invalid user to fail")
+ }
+
+ // there should still only be 1 cache entry, because we don't evict from cache on failure
+ suite.Equal(1, wc.Len(), "expect webfinger cache to hold one entry")
+ wc.Lock()
+ suite.True(wc.Cache.Has("misconfigured-instance.com"), "expect webfinger cache to have entry for misconfigured-instance.com")
+ last, _ := wc.Cache.Get("misconfigured-instance.com")
+ wc.Unlock()
+
+ lastTime := last.Expiry
+
+ // The TTL of the previous and new entry should be the same since
+ // a failed request must not extend the entry TTL
+ suite.Equal(repeatTime, lastTime)
+}
+
+func TestFingerTestSuite(t *testing.T) {
+ suite.Run(t, &FingerTestSuite{})
+}
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
new file mode 100644
index 000000000..5ee597e45
--- /dev/null
+++ b/internal/transport/transport_test.go
@@ -0,0 +1,101 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package transport_test
+
+import (
+ "context"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/email"
+ "github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/media"
+ "github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/transport"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type TransportTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ db db.DB
+ storage *storage.Driver
+ mediaManager media.Manager
+ federator federation.Federator
+ processor *processing.Processor
+ emailSender email.Sender
+ sentEmails map[string]string
+ state state.State
+
+ // standard suite models
+ testTokens map[string]*gtsmodel.Token
+ testClients map[string]*gtsmodel.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+
+ transport transport.Transport
+}
+
+func (suite *TransportTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+}
+
+func (suite *TransportTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
+ testrig.InitTestConfig()
+ testrig.InitTestLog()
+
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
+ suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media")), suite.mediaManager)
+ suite.sentEmails = make(map[string]string)
+ suite.emailSender = testrig.NewEmailSender("../../web/template/", suite.sentEmails)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
+
+ testrig.StandardDBSetup(suite.db, nil)
+ testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
+
+ ts, err := suite.federator.TransportController().NewTransportForUsername(context.TODO(), "")
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.transport = ts
+
+ suite.NoError(suite.processor.Start())
+}
+
+func (suite *TransportTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+ testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
+}
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 8d795b6b3..9f16e026c 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -52,7 +52,10 @@ EXPECT=$(cat <<"EOF"
"tombstone-ttl": 300000000000,
"user-max-size": 100,
"user-sweep-freq": 30000000000,
- "user-ttl": 300000000000
+ "user-ttl": 300000000000,
+ "webfinger-max-size": 250,
+ "webfinger-sweep-freq": 900000000000,
+ "webfinger-ttl": 86400000000000
}
},
"config-path": "internal/config/testdata/test.yaml",
diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go
index 9657205f6..aeb9d4dfa 100644
--- a/testrig/transportcontroller.go
+++ b/testrig/transportcontroller.go
@@ -21,6 +21,7 @@
import (
"bytes"
"encoding/json"
+ "encoding/xml"
"io"
"net/http"
"strings"
@@ -52,7 +53,7 @@
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
// basis.
func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller {
- return transport.NewController(state.DB, NewTestFederatingDB(state), &federation.Clock{}, client)
+ return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client)
}
type MockHTTPClient struct {
@@ -121,6 +122,10 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseContentLength = len(responseBytes)
} else if strings.Contains(req.URL.String(), ".well-known/webfinger") {
responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req)
+ } else if strings.Contains(req.URL.String(), ".weird-webfinger-location/webfinger") {
+ responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req)
+ } else if strings.Contains(req.URL.String(), ".well-known/host-meta") {
+ responseCode, responseBytes, responseContentType, responseContentLength = HostMetaResponse(req)
} else if note, ok := mockHTTPClient.TestRemoteStatuses[req.URL.String()]; ok {
// the request is for a note that we have stored
noteI, err := streams.Serialize(note)
@@ -221,11 +226,47 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
return m.do(req)
}
+func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) {
+ var hm *apimodel.HostMeta
+
+ if req.URL.String() == "https://misconfigured-instance.com/.well-known/host-meta" {
+ hm = &apimodel.HostMeta{
+ XMLNS: "http://docs.oasis-open.org/ns/xri/xrd-1.0",
+ Link: []apimodel.Link{
+ {
+ Rel: "lrdd",
+ Type: "application/xrd+xml",
+ Template: "https://misconfigured-instance.com/.weird-webfinger-location/webfinger?resource={uri}",
+ },
+ },
+ }
+ }
+
+ if hm == nil {
+ log.Debugf(nil, "hostmeta response not available for %s", req.URL)
+ responseCode = http.StatusNotFound
+ responseBytes = []byte(``)
+ responseContentType = "application/xml"
+ responseContentLength = len(responseBytes)
+ return
+ }
+
+ hmXML, err := xml.Marshal(hm)
+ if err != nil {
+ panic(err)
+ }
+ responseCode = http.StatusOK
+ responseBytes = hmXML
+ responseContentType = "application/xml"
+ responseContentLength = len(hmXML)
+ return
+}
+
func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) {
var wfr *apimodel.WellKnownResponse
switch req.URL.String() {
- case "https://unknown-instance.com/.well-known/webfinger?resource=acct:some_group@unknown-instance.com":
+ case "https://unknown-instance.com/.well-known/webfinger?resource=acct%3Asome_group%40unknown-instance.com":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:some_group@unknown-instance.com",
Links: []apimodel.Link{
@@ -236,7 +277,7 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
- case "https://owncast.example.org/.well-known/webfinger?resource=acct:rgh@owncast.example.org":
+ case "https://owncast.example.org/.well-known/webfinger?resource=acct%3Argh%40owncast.example.org":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:rgh@example.org",
Links: []apimodel.Link{
@@ -247,7 +288,7 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
- case "https://unknown-instance.com/.well-known/webfinger?resource=acct:brand_new_person@unknown-instance.com":
+ case "https://unknown-instance.com/.well-known/webfinger?resource=acct%3Abrand_new_person%40unknown-instance.com":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:brand_new_person@unknown-instance.com",
Links: []apimodel.Link{
@@ -258,7 +299,7 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
- case "https://turnip.farm/.well-known/webfinger?resource=acct:turniplover6969@turnip.farm":
+ case "https://turnip.farm/.well-known/webfinger?resource=acct%3Aturniplover6969%40turnip.farm":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:turniplover6969@turnip.farm",
Links: []apimodel.Link{
@@ -269,7 +310,7 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
- case "https://fossbros-anonymous.io/.well-known/webfinger?resource=acct:foss_satan@fossbros-anonymous.io":
+ case "https://fossbros-anonymous.io/.well-known/webfinger?resource=acct%3Afoss_satan%40fossbros-anonymous.io":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:foss_satan@fossbros-anonymous.io",
Links: []apimodel.Link{
@@ -280,7 +321,7 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
- case "https://example.org/.well-known/webfinger?resource=acct:Some_User@example.org":
+ case "https://example.org/.well-known/webfinger?resource=acct%3ASome_User%40example.org":
wfr = &apimodel.WellKnownResponse{
Subject: "acct:Some_User@example.org",
Links: []apimodel.Link{
@@ -291,6 +332,17 @@ func WebfingerResponse(req *http.Request) (responseCode int, responseBytes []byt
},
},
}
+ case "https://misconfigured-instance.com/.weird-webfinger-location/webfinger?resource=acct%3Asomeone%40misconfigured-instance.com":
+ wfr = &apimodel.WellKnownResponse{
+ Subject: "acct:someone@misconfigured-instance.com",
+ Links: []apimodel.Link{
+ {
+ Rel: "self",
+ Type: applicationActivityJSON,
+ Href: "https://misconfigured-instance.com/users/someone",
+ },
+ },
+ }
}
if wfr == nil {