diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go index 50bbd43bf..864d85428 100644 --- a/internal/cache/domain/domain.go +++ b/internal/cache/domain/domain.go @@ -19,151 +19,206 @@ import ( "fmt" - "time" + "strings" + "sync/atomic" + "unsafe" - "codeberg.org/gruf/go-cache/v3/ttl" - "github.com/miekg/dns" + "golang.org/x/exp/slices" ) // BlockCache provides a means of caching domain blocks in memory to reduce load // on an underlying storage mechanism, e.g. a database. // -// It consists of a TTL primary cache that stores calculated domain string to block results, -// that on cache miss is filled by calculating block status by iterating over a list of all of -// the domain blocks stored in memory. This reduces CPU usage required by not need needing to -// iterate through a possible 100-1000s long block list, while saving memory by having a primary -// cache of limited size that evicts stale entries. The raw list of all domain blocks should in -// most cases be negligible when it comes to memory usage. -// // The in-memory block list is kept up-to-date by means of a passed loader function during every // call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to -// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate -// the cache, e.g. when a domain block is added / deleted from the database. It will drop the current -// list of domain blocks and clear all entries from the primary cache. +// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to +// invalidate the cache, e.g. when a domain block is added / deleted from the database. type BlockCache struct { - pcache *ttl.Cache[string, bool] // primary cache of domains -> block results - blocks []block // raw list of all domain blocks, nil => not loaded. -} - -// New returns a new initialized BlockCache instance with given primary cache capacity and TTL. -func New(pcap int, pttl time.Duration) *BlockCache { - c := new(BlockCache) - c.pcache = new(ttl.Cache[string, bool]) - c.pcache.Init(0, pcap, pttl) - return c -} - -// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started. -func (b *BlockCache) Start(pfreq time.Duration) bool { - return b.pcache.Start(pfreq) -} - -// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped. -func (b *BlockCache) Stop() bool { - return b.pcache.Stop() + // atomically updated ptr value to the + // current domain block cache radix trie. + rootptr unsafe.Pointer } // IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it. -// NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock. func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) { - var blocked bool + // Load the current root pointer value. + ptr := atomic.LoadPointer(&b.rootptr) - // Acquire cache lock - b.pcache.Lock() - defer b.pcache.Unlock() - - // Check primary cache for result - entry, ok := b.pcache.Cache.Get(domain) - if ok { - return entry.Value, nil - } - - if b.blocks == nil { - // Cache is not hydrated + if ptr == nil { + // Cache is not hydrated. // - // Load domains from callback + // Load domains from callback. domains, err := load() if err != nil { return false, fmt.Errorf("error reloading cache: %w", err) } - // Drop all domain blocks and recreate - b.blocks = make([]block, len(domains)) + // Allocate new radix trie + // node to store matches. + root := new(root) - for i, domain := range domains { - // Store pre-split labels for each domain block - b.blocks[i].labels = dns.SplitDomainName(domain) + // Add each domain to the trie. + for _, domain := range domains { + root.Add(domain) } + + // Sort the trie. + root.Sort() + + // Store the new node ptr. + ptr = unsafe.Pointer(root) + atomic.StorePointer(&b.rootptr, ptr) } - // Split domain into it separate labels - labels := dns.SplitDomainName(domain) - - // Compare this to our stored blocks - for _, block := range b.blocks { - if block.Blocks(labels) { - blocked = true - break - } - } - - // Store block result in primary cache - b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{ - Key: domain, - Value: blocked, - Expiry: time.Now().Add(b.pcache.TTL), - }) - - return blocked, nil + // Look for a match in the trie node. + return (*root)(ptr).Match(domain), nil } -// Clear will drop the currently loaded domain list, and clear the primary cache. -// This will trigger a reload on next call to .IsBlocked(). +// Clear will drop the currently loaded domain list, +// triggering a reload on next call to .IsBlocked(). func (b *BlockCache) Clear() { - // Drop all blocks. - b.pcache.Lock() - b.blocks = nil - b.pcache.Unlock() - - // Clear needs to be done _outside_ of - // lock, as also acquires a mutex lock. - b.pcache.Clear() + atomic.StorePointer(&b.rootptr, nil) } -// block represents a domain block, and stores the -// deconstructed labels of a singular domain block. -// e.g. []string{"gts", "superseriousbusiness", "org"}. -type block struct { - labels []string +// root is the root node in the domain +// block cache radix trie. this is the +// singular access point to the trie. +type root struct{ root node } + +// Add will add the given domain to the radix trie. +func (r *root) Add(domain string) { + r.root.add(strings.Split(domain, ".")) } -// Blocks checks whether the separated domain labels of an -// incoming domain matches the stored (receiving struct) block. -func (b block) Blocks(labels []string) bool { - // Calculate length difference - d := len(labels) - len(b.labels) - if d < 0 { +// Match will return whether the given domain matches +// an existing stored domain block in this radix trie. +func (r *root) Match(domain string) bool { + return r.root.match(strings.Split(domain, ".")) +} + +// Sort will sort the entire radix trie ensuring that +// child nodes are stored in alphabetical order. This +// MUST be done to finalize the block cache in order +// to speed up the binary search of node child parts. +func (r *root) Sort() { + r.root.sort() +} + +type node struct { + part string + child []*node +} + +func (n *node) add(parts []string) { + if len(parts) == 0 { + panic("invalid domain") + } + + for { + // Pop next domain part. + i := len(parts) - 1 + part := parts[i] + parts = parts[:i] + + var nn *node + + // Look for existing child node + // that matches next domain part. + for _, child := range n.child { + if child.part == part { + nn = child + break + } + } + + if nn == nil { + // Alloc new child node. + nn = &node{part: part} + n.child = append(n.child, nn) + } + + if len(parts) == 0 { + // Drop all children here as + // this is a higher-level block + // than that we previously had. + nn.child = nil + return + } + + // Re-iter with + // child node. + n = nn + } +} + +func (n *node) match(parts []string) bool { + if len(parts) == 0 { + // Invalid domain. return false } - // Iterate backwards through domain block's - // labels, omparing against the incoming domain's. - // - // So for the following input: - // labels = []string{"mail", "google", "com"} - // b.labels = []string{"google", "com"} - // - // These would be matched in reverse order along - // the entirety of the block object's labels: - // "com" => match - // "google" => match - // - // And so would reach the end and return true. - for i := len(b.labels) - 1; i >= 0; i-- { - if b.labels[i] != labels[i+d] { + for { + // Pop next domain part. + i := len(parts) - 1 + part := parts[i] + parts = parts[:i] + + // Look for existing child + // that matches next part. + nn := n.getChild(part) + + if nn == nil { + // No match :( return false } + + if len(nn.child) == 0 { + // It's a match! + return true + } + + // Re-iter with + // child node. + n = nn + } +} + +// getChild fetches child node with given domain part string +// using a binary search. THIS ASSUMES CHILDREN ARE SORTED. +func (n *node) getChild(part string) *node { + i, j := 0, len(n.child) + + for i < j { + // avoid overflow when computing h + h := int(uint(i+j) >> 1) + // i ≤ h < j + + if n.child[h].part < part { + // preserves: + // n.child[i-1].part != part + i = h + 1 + } else { + // preserves: + // n.child[h].part == part + j = h + } } - return true + if i >= len(n.child) || n.child[i].part != part { + return nil // no match + } + + return n.child[i] +} + +func (n *node) sort() { + // Sort this node's slice of child nodes. + slices.SortFunc(n.child, func(i, j *node) bool { + return i.part < j.part + }) + + // Sort each child node's children. + for _, child := range n.child { + child.sort() + } } diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go index d8c3205a6..b5937978c 100644 --- a/internal/cache/domain/domain_test.go +++ b/internal/cache/domain/domain_test.go @@ -20,13 +20,12 @@ import ( "errors" "testing" - "time" "github.com/superseriousbusiness/gotosocial/internal/cache/domain" ) func TestBlockCache(t *testing.T) { - c := domain.New(100, time.Second) + c := new(domain.BlockCache) blocks := []string{ "google.com", diff --git a/internal/cache/gts.go b/internal/cache/gts.go index a96bc3608..1032a5611 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -72,12 +72,6 @@ func (c *GTSCaches) Init() { func (c *GTSCaches) Start() { tryStart(c.account, config.GetCacheGTSAccountSweepFreq()) tryStart(c.block, config.GetCacheGTSBlockSweepFreq()) - tryUntil("starting domain block cache", 5, func() bool { - if sweep := config.GetCacheGTSDomainBlockSweepFreq(); sweep > 0 { - return c.domainBlock.Start(sweep) - } - return true - }) tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq()) tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) @@ -102,7 +96,6 @@ func (c *GTSCaches) Start() { func (c *GTSCaches) Stop() { tryStop(c.account, config.GetCacheGTSAccountSweepFreq()) tryStop(c.block, config.GetCacheGTSBlockSweepFreq()) - tryUntil("stopping domain block cache", 5, c.domainBlock.Stop) tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq()) tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) @@ -233,10 +226,7 @@ func (c *GTSCaches) initBlock() { } func (c *GTSCaches) initDomainBlock() { - c.domainBlock = domain.New( - config.GetCacheGTSDomainBlockMaxSize(), - config.GetCacheGTSDomainBlockTTL(), - ) + c.domainBlock = new(domain.BlockCache) } func (c *GTSCaches) initEmoji() {