[chore] better dns validation (#3644)

* add seperate PunifyValidate() function for properly validating domain names when converting to punycode

* rename function, strip port from domain validation
This commit is contained in:
kim 2025-01-14 14:23:18 +00:00 committed by GitHub
parent b95498b8c2
commit e77c7e16b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 203 additions and 173 deletions

View file

@ -28,6 +28,7 @@
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -51,6 +52,7 @@ func (suite *InstancePatchTestSuite) instancePatch(fieldName string, fileName st
ctx := suite.newContext(recorder, http.MethodPatch, instance.InstanceInformationPathV1, requestBody.Bytes(), w.FormDataContentType(), true) ctx := suite.newContext(recorder, http.MethodPatch, instance.InstanceInformationPathV1, requestBody.Bytes(), w.FormDataContentType(), true)
suite.instanceModule.InstanceUpdatePATCHHandler(ctx) suite.instanceModule.InstanceUpdatePATCHHandler(ctx)
middleware.Logger(false)(ctx)
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()

View file

@ -137,8 +137,9 @@ func(account *gtsmodel.Account) error {
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) { func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) {
if domain != "" { if domain != "" {
// Normalize the domain as punycode
var err error var err error
// Normalize the domain as punycode
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -36,12 +36,12 @@ type domainDB struct {
state *state.State state *state.State
} }
func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error { func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) {
// Normalize the domain as punycode // Normalize the domain as punycode, note the extra
var err error // validation step for domain name write operations.
allow.Domain, err = util.Punify(allow.Domain) allow.Domain, err = util.PunifySafely(allow.Domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err)
} }
// Attempt to store domain allow in DB // Attempt to store domain allow in DB
@ -58,10 +58,10 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
} }
func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) { func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return nil, err return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Check for easy case, domain referencing *us* // Check for easy case, domain referencing *us*
@ -111,12 +111,12 @@ func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel
return &allow, nil return &allow, nil
} }
func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow, columns ...string) error { func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow, columns ...string) (err error) {
// Normalize the domain as punycode // Normalize the domain as punycode, note the extra
var err error // validation step for domain name write operations.
allow.Domain, err = util.Punify(allow.Domain) allow.Domain, err = util.PunifySafely(allow.Domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", allow.Domain, err)
} }
// Ensure updated_at is set. // Ensure updated_at is set.
@ -142,10 +142,10 @@ func (d *domainDB) UpdateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
} }
func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Attempt to delete domain allow // Attempt to delete domain allow
@ -163,11 +163,13 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
} }
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
// Normalize the domain as punycode
var err error var err error
block.Domain, err = util.Punify(block.Domain)
// Normalize the domain as punycode, note the extra
// validation step for domain name write operations.
block.Domain, err = util.PunifySafely(block.Domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", block.Domain, err)
} }
// Attempt to store domain block in DB // Attempt to store domain block in DB
@ -184,10 +186,10 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
} }
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) { func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return nil, err return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Check for easy case, domain referencing *us* // Check for easy case, domain referencing *us*
@ -238,11 +240,13 @@ func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel
} }
func (d *domainDB) UpdateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock, columns ...string) error { func (d *domainDB) UpdateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock, columns ...string) error {
// Normalize the domain as punycode
var err error var err error
block.Domain, err = util.Punify(block.Domain)
// Normalize the domain as punycode, note the extra
// validation step for domain name write operations.
block.Domain, err = util.PunifySafely(block.Domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", block.Domain, err)
} }
// Ensure updated_at is set. // Ensure updated_at is set.
@ -268,10 +272,10 @@ func (d *domainDB) UpdateDomainBlock(ctx context.Context, block *gtsmodel.Domain
} }
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error { func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Attempt to delete domain block // Attempt to delete domain block
@ -289,10 +293,10 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
} }
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) { func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return false, err return false, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Domain referencing *us* cannot be blocked. // Domain referencing *us* cannot be blocked.

View file

@ -168,7 +168,7 @@ func (d *domainDB) GetDomainPermissionDrafts(
if domain != "" { if domain != "" {
var err error var err error
// Normalize domain as punycode. // Normalize domain as punycode for lookup.
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
@ -234,22 +234,23 @@ func (d *domainDB) GetDomainPermissionDrafts(
func (d *domainDB) PutDomainPermissionDraft( func (d *domainDB) PutDomainPermissionDraft(
ctx context.Context, ctx context.Context,
permDraft *gtsmodel.DomainPermissionDraft, draft *gtsmodel.DomainPermissionDraft,
) error { ) error {
var err error var err error
// Normalize the domain as punycode // Normalize the domain as punycode, note the extra
permDraft.Domain, err = util.Punify(permDraft.Domain) // validation step for domain name write operations.
draft.Domain, err = util.PunifySafely(draft.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", permDraft.Domain, err) return gtserror.Newf("error punifying domain %s: %w", draft.Domain, err)
} }
return d.state.Caches.DB.DomainPermissionDraft.Store( return d.state.Caches.DB.DomainPermissionDraft.Store(
permDraft, draft,
func() error { func() error {
_, err := d.db. _, err := d.db.
NewInsert(). NewInsert().
Model(permDraft). Model(draft).
Exec(ctx) Exec(ctx)
return err return err
}, },

View file

@ -37,11 +37,13 @@ func (d *domainDB) PutDomainPermissionExclude(
ctx context.Context, ctx context.Context,
exclude *gtsmodel.DomainPermissionExclude, exclude *gtsmodel.DomainPermissionExclude,
) error { ) error {
// Normalize the domain as punycode
var err error var err error
exclude.Domain, err = util.Punify(exclude.Domain)
// Normalize the domain as punycode, note the extra
// validation step for domain name write operations.
exclude.Domain, err = util.PunifySafely(exclude.Domain)
if err != nil { if err != nil {
return err return gtserror.Newf("error punifying domain %s: %w", exclude.Domain, err)
} }
// Attempt to store domain perm exclude in DB // Attempt to store domain perm exclude in DB
@ -58,10 +60,10 @@ func (d *domainDB) PutDomainPermissionExclude(
} }
func (d *domainDB) IsDomainPermissionExcluded(ctx context.Context, domain string) (bool, error) { func (d *domainDB) IsDomainPermissionExcluded(ctx context.Context, domain string) (bool, error) {
// Normalize the domain as punycode // Normalize domain as punycode for lookup.
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
return false, err return false, gtserror.Newf("error punifying domain %s: %w", domain, err)
} }
// Func to scan list of all // Func to scan list of all
@ -177,7 +179,7 @@ func (d *domainDB) GetDomainPermissionExcludes(
if domain != "" { if domain != "" {
var err error var err error
// Normalize domain as punycode. // Normalize domain as punycode for lookup.
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)

View file

@ -158,8 +158,9 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
} }
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) { func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) {
// Normalize the domain as punycode
var err error var err error
// Normalize the domain as punycode
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
@ -265,8 +266,9 @@ func (i *instanceDB) PopulateInstance(ctx context.Context, instance *gtsmodel.In
func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error { func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error {
var err error var err error
// Normalize the domain as punycode // Normalize the domain as punycode, note the extra
instance.Domain, err = util.Punify(instance.Domain) // validation step for domain name write operations.
instance.Domain, err = util.PunifySafely(instance.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
} }
@ -279,9 +281,11 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
} }
func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error { func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error {
// Normalize the domain as punycode
var err error var err error
instance.Domain, err = util.Punify(instance.Domain)
// Normalize the domain as punycode, note the extra
// validation step for domain name write operations.
instance.Domain, err = util.PunifySafely(instance.Domain)
if err != nil { if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
} }
@ -349,8 +353,9 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
limit = 0 limit = 0
} }
// Normalize the domain as punycode.
var err error var err error
// Normalize the domain as punycode
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
if err != nil { if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)

View file

@ -23,7 +23,6 @@
"encoding/csv" "encoding/csv"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"slices" "slices"
"strconv" "strconv"
@ -32,7 +31,6 @@
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
"github.com/miekg/dns"
"github.com/superseriousbusiness/gotosocial/internal/admin" "github.com/superseriousbusiness/gotosocial/internal/admin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
@ -629,7 +627,7 @@ func permsFromCSV(
// Normalize + validate domain. // Normalize + validate domain.
domainRaw := record[*domainI] domainRaw := record[*domainI]
domain, err := validateDomain(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue
@ -702,7 +700,7 @@ func permsFromJSON(
// Normalize + validate domain. // Normalize + validate domain.
domainRaw := apiPerm.Domain.Domain domainRaw := apiPerm.Domain.Domain
domain, err := validateDomain(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue
@ -757,8 +755,8 @@ func permsFromPlain(
perms := make([]gtsmodel.DomainPermission, 0, len(domains)) perms := make([]gtsmodel.DomainPermission, 0, len(domains))
for _, domainRaw := range domains { for _, domainRaw := range domains {
// Normalize + validate domain. // Normalize + validate domain as ASCII.
domain, err := validateDomain(domainRaw) domain, err := util.PunifySafely(domainRaw)
if err != nil { if err != nil {
l.Warnf("skipping invalid domain %s: %+v", domainRaw, err) l.Warnf("skipping invalid domain %s: %+v", domainRaw, err)
continue continue
@ -781,30 +779,6 @@ func permsFromPlain(
return perms, nil return perms, nil
} }
func validateDomain(domain string) (string, error) {
// Basic validation.
if _, ok := dns.IsDomainName(domain); !ok {
err := fmt.Errorf("invalid domain name")
return "", err
}
// Convert to punycode.
domain, err := util.Punify(domain)
if err != nil {
err := fmt.Errorf("could not punify domain: %w", err)
return "", err
}
// Check for invalid characters
// after the punification process.
if strings.ContainsAny(domain, "*, \n") {
err := fmt.Errorf("invalid char(s) in domain")
return "", err
}
return domain, nil
}
func (s *Subscriptions) existingCovered( func (s *Subscriptions) existingCovered(
ctx context.Context, ctx context.Context,
permType gtsmodel.DomainPermissionType, permType gtsmodel.DomainPermissionType,

138
internal/util/domain.go Normal file
View file

@ -0,0 +1,138 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package util
import (
"net/url"
"strings"
"golang.org/x/net/idna"
)
var (
// IDNA (Internationalized Domain Names for Applications)
// profiles for fast punycode conv and full verification.
punifyProfile = *idna.Punycode
verifyProfile = *idna.Lookup
)
// PunifySafely validates the provided domain name,
// and converts unicode chars to ASCII, i.e. punified form.
func PunifySafely(domain string) (string, error) {
if i := strings.LastIndexByte(domain, ':'); i >= 0 {
// If there is a port included in domain, we
// strip it as colon is invalid in a hostname.
domain, port := domain[:i], domain[i:]
domain, err := verifyProfile.ToASCII(domain)
if err != nil {
return "", err
}
// Then rebuild with port after.
domain = strings.ToLower(domain)
return domain + port, nil
} else { //nolint:revive
// Otherwise we just punify domain as-is.
domain, err := verifyProfile.ToASCII(domain)
return strings.ToLower(domain), err
}
}
// Punify is a faster form of PunifySafely() without validation.
func Punify(domain string) (string, error) {
domain, err := punifyProfile.ToASCII(domain)
return strings.ToLower(domain), err
}
// DePunify converts any punycode-encoded unicode characters
// in domain name back to their origin unicode. Please note
// that this performs minimal validation of domain name.
func DePunify(domain string) (string, error) {
domain = strings.ToLower(domain)
return punifyProfile.ToUnicode(domain)
}
// URIMatches returns true if the expected URI matches
// any of the given URIs, taking account of punycode.
func URIMatches(expect *url.URL, uris ...*url.URL) (ok bool, err error) {
// Create new URL to hold
// punified URI information.
punyURI := new(url.URL)
*punyURI = *expect
// Set punified expected URL host.
punyURI.Host, err = Punify(expect.Host)
if err != nil {
return false, err
}
// Calculate expected URI string.
expectStr := punyURI.String()
// Use punyURI to iteratively
// store each punified URI info
// and generate punified URI
// strings to check against.
for _, uri := range uris {
*punyURI = *uri
punyURI.Host, err = Punify(uri.Host)
if err != nil {
return false, err
}
// Check for a match against expect.
if expectStr == punyURI.String() {
return true, nil
}
}
// Didn't match.
return false, nil
}
// PunifyURI returns a new copy of URI with the 'host'
// part converted to punycode with PunifySafely().
// For simple comparisons prefer the faster URIMatches().
func PunifyURI(in *url.URL) (*url.URL, error) {
punyHost, err := PunifySafely(in.Host)
if err != nil {
return nil, err
}
out := new(url.URL)
*out = *in
out.Host = punyHost
return out, nil
}
// PunifyURIToStr returns given URI serialized with the
// 'host' part converted to punycode with PunifySafely().
// For simple comparisons prefer the faster URIMatches().
func PunifyURIToStr(in *url.URL) (string, error) {
punyHost, err := PunifySafely(in.Host)
if err != nil {
return "", err
}
oldHost := in.Host
in.Host = punyHost
str := in.String()
in.Host = oldHost
return str, nil
}

View file

@ -1,97 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package util
import (
"net/url"
"strings"
"golang.org/x/net/idna"
)
// Punify converts the given domain to lowercase
// then to punycode (for international domain names).
//
// Returns the resulting domain or an error if the
// punycode conversion fails.
func Punify(domain string) (string, error) {
domain = strings.ToLower(domain)
return idna.ToASCII(domain)
}
// DePunify converts the given punycode string
// to its original unicode representation (lowercased).
// Noop if the domain is (already) not puny.
//
// Returns an error if conversion fails.
func DePunify(domain string) (string, error) {
out, err := idna.ToUnicode(domain)
return strings.ToLower(out), err
}
// URIMatches returns true if the expected URI matches
// any of the given URIs, taking account of punycode.
func URIMatches(expect *url.URL, uris ...*url.URL) (bool, error) {
// Normalize expect to punycode.
expectStr, err := PunifyURIToStr(expect)
if err != nil {
return false, err
}
for _, uri := range uris {
uriStr, err := PunifyURIToStr(uri)
if err != nil {
return false, err
}
if uriStr == expectStr {
// Looks good.
return true, nil
}
}
// Didn't match.
return false, nil
}
// PunifyURI returns a copy of the given URI
// with the 'host' part converted to punycode.
func PunifyURI(in *url.URL) (*url.URL, error) {
punyHost, err := Punify(in.Host)
if err != nil {
return nil, err
}
out := new(url.URL)
*out = *in
out.Host = punyHost
return out, nil
}
// PunifyURIToStr returns given URI serialized
// with the 'host' part converted to punycode.
func PunifyURIToStr(in *url.URL) (string, error) {
punyHost, err := Punify(in.Host)
if err != nil {
return "", err
}
oldHost := in.Host
in.Host = punyHost
str := in.String()
in.Host = oldHost
return str, nil
}