mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-22 16:46:53 +01:00
admin: Enforce and refactor origin checking
Using URLs seems a little cleaner and more correct cf: https://caddy.community/t/protect-admin-endpoint/15114 (This used to work. Something must have changed recently.)
This commit is contained in:
parent
1d0425b26f
commit
40b54434f3
1 changed files with 57 additions and 26 deletions
83
admin.go
83
admin.go
|
@ -42,6 +42,7 @@ import (
|
||||||
"github.com/caddyserver/certmagic"
|
"github.com/caddyserver/certmagic"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminConfig configures Caddy's API endpoint, which is used
|
// AdminConfig configures Caddy's API endpoint, which is used
|
||||||
|
@ -192,6 +193,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
|
||||||
} else {
|
} else {
|
||||||
muxWrap.enforceHost = !addr.isWildcardInterface()
|
muxWrap.enforceHost = !addr.isWildcardInterface()
|
||||||
muxWrap.allowedOrigins = admin.allowedOrigins(addr)
|
muxWrap.allowedOrigins = admin.allowedOrigins(addr)
|
||||||
|
muxWrap.enforceOrigin = admin.EnforceOrigin
|
||||||
}
|
}
|
||||||
|
|
||||||
addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) {
|
addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) {
|
||||||
|
@ -252,7 +254,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
|
||||||
// will be used as the default origin. If admin.Origins is
|
// will be used as the default origin. If admin.Origins is
|
||||||
// empty, no origins will be allowed, effectively bricking the
|
// empty, no origins will be allowed, effectively bricking the
|
||||||
// endpoint for non-unix-socket endpoints, but whatever.
|
// endpoint for non-unix-socket endpoints, but whatever.
|
||||||
func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
|
func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL {
|
||||||
uniqueOrigins := make(map[string]struct{})
|
uniqueOrigins := make(map[string]struct{})
|
||||||
for _, o := range admin.Origins {
|
for _, o := range admin.Origins {
|
||||||
uniqueOrigins[o] = struct{}{}
|
uniqueOrigins[o] = struct{}{}
|
||||||
|
@ -276,8 +278,23 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
|
||||||
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
|
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
allowed := make([]string, 0, len(uniqueOrigins))
|
allowed := make([]*url.URL, 0, len(uniqueOrigins))
|
||||||
for origin := range uniqueOrigins {
|
for originStr := range uniqueOrigins {
|
||||||
|
var origin *url.URL
|
||||||
|
if strings.Contains(originStr, "://") {
|
||||||
|
var err error
|
||||||
|
origin, err = url.Parse(originStr)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
origin.Path = ""
|
||||||
|
origin.RawPath = ""
|
||||||
|
origin.Fragment = ""
|
||||||
|
origin.RawFragment = ""
|
||||||
|
origin.RawQuery = ""
|
||||||
|
} else {
|
||||||
|
origin = &url.URL{Host: originStr}
|
||||||
|
}
|
||||||
allowed = append(allowed, origin)
|
allowed = append(allowed, origin)
|
||||||
}
|
}
|
||||||
return allowed
|
return allowed
|
||||||
|
@ -358,7 +375,7 @@ func replaceLocalAdminServer(cfg *Config) error {
|
||||||
adminLogger.Info("admin endpoint started",
|
adminLogger.Info("admin endpoint started",
|
||||||
zap.String("address", addr.String()),
|
zap.String("address", addr.String()),
|
||||||
zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
|
zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
|
||||||
zap.Strings("origins", handler.allowedOrigins))
|
zap.Array("origins", loggableURLArray(handler.allowedOrigins)))
|
||||||
|
|
||||||
if !handler.enforceHost {
|
if !handler.enforceHost {
|
||||||
adminLogger.Warn("admin endpoint on open interface; host checking disabled",
|
adminLogger.Warn("admin endpoint on open interface; host checking disabled",
|
||||||
|
@ -650,10 +667,10 @@ type AdminRoute struct {
|
||||||
type adminHandler struct {
|
type adminHandler struct {
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
|
|
||||||
// security for local/plaintext) endpoint, on by default
|
// security for local/plaintext endpoint
|
||||||
enforceOrigin bool
|
enforceOrigin bool
|
||||||
enforceHost bool
|
enforceHost bool
|
||||||
allowedOrigins []string
|
allowedOrigins []*url.URL
|
||||||
|
|
||||||
// security for remote/encrypted endpoint
|
// security for remote/encrypted endpoint
|
||||||
remoteControl *RemoteAdmin
|
remoteControl *RemoteAdmin
|
||||||
|
@ -779,8 +796,8 @@ func (h adminHandler) handleError(w http.ResponseWriter, r *http.Request, err er
|
||||||
// rebinding attacks.
|
// rebinding attacks.
|
||||||
func (h adminHandler) checkHost(r *http.Request) error {
|
func (h adminHandler) checkHost(r *http.Request) error {
|
||||||
var allowed bool
|
var allowed bool
|
||||||
for _, allowedHost := range h.allowedOrigins {
|
for _, allowedOrigin := range h.allowedOrigins {
|
||||||
if r.Host == allowedHost {
|
if r.Host == allowedOrigin.Host {
|
||||||
allowed = true
|
allowed = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -799,43 +816,45 @@ func (h adminHandler) checkHost(r *http.Request) error {
|
||||||
// sites from issuing requests to our listener. It
|
// sites from issuing requests to our listener. It
|
||||||
// returns the origin that was obtained from r.
|
// returns the origin that was obtained from r.
|
||||||
func (h adminHandler) checkOrigin(r *http.Request) (string, error) {
|
func (h adminHandler) checkOrigin(r *http.Request) (string, error) {
|
||||||
origin := h.getOriginHost(r)
|
originStr, origin := h.getOrigin(r)
|
||||||
if origin == "" {
|
if origin == nil {
|
||||||
return origin, APIError{
|
return "", APIError{
|
||||||
HTTPStatus: http.StatusForbidden,
|
HTTPStatus: http.StatusForbidden,
|
||||||
Err: fmt.Errorf("missing required Origin header"),
|
Err: fmt.Errorf("required Origin header is missing or invalid"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !h.originAllowed(origin) {
|
if !h.originAllowed(origin) {
|
||||||
return origin, APIError{
|
return "", APIError{
|
||||||
HTTPStatus: http.StatusForbidden,
|
HTTPStatus: http.StatusForbidden,
|
||||||
Err: fmt.Errorf("client is not allowed to access from origin %s", origin),
|
Err: fmt.Errorf("client is not allowed to access from origin '%s'", originStr),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return origin, nil
|
return origin.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h adminHandler) getOriginHost(r *http.Request) string {
|
func (h adminHandler) getOrigin(r *http.Request) (string, *url.URL) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
if origin == "" {
|
if origin == "" {
|
||||||
origin = r.Header.Get("Referer")
|
origin = r.Header.Get("Referer")
|
||||||
}
|
}
|
||||||
originURL, err := url.Parse(origin)
|
originURL, err := url.Parse(origin)
|
||||||
if err == nil && originURL.Host != "" {
|
if err != nil {
|
||||||
origin = originURL.Host
|
return origin, nil
|
||||||
}
|
}
|
||||||
return origin
|
originURL.Path = ""
|
||||||
|
originURL.RawPath = ""
|
||||||
|
originURL.Fragment = ""
|
||||||
|
originURL.RawFragment = ""
|
||||||
|
originURL.RawQuery = ""
|
||||||
|
return origin, originURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h adminHandler) originAllowed(origin string) bool {
|
func (h adminHandler) originAllowed(origin *url.URL) bool {
|
||||||
for _, allowedOrigin := range h.allowedOrigins {
|
for _, allowedOrigin := range h.allowedOrigins {
|
||||||
originCopy := origin
|
if allowedOrigin.Scheme != "" && origin.Scheme != allowedOrigin.Scheme {
|
||||||
if !strings.Contains(allowedOrigin, "://") {
|
continue
|
||||||
// no scheme specified, so allow both
|
|
||||||
originCopy = strings.TrimPrefix(originCopy, "http://")
|
|
||||||
originCopy = strings.TrimPrefix(originCopy, "https://")
|
|
||||||
}
|
}
|
||||||
if originCopy == allowedOrigin {
|
if origin.Host == allowedOrigin.Host {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1189,6 +1208,18 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
|
||||||
return x509.ParseCertificate(derBytes)
|
return x509.ParseCertificate(derBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type loggableURLArray []*url.URL
|
||||||
|
|
||||||
|
func (ua loggableURLArray) MarshalLogArray(enc zapcore.ArrayEncoder) error {
|
||||||
|
if ua == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, u := range ua {
|
||||||
|
enc.AppendString(u.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// DefaultAdminListen is the address for the local admin
|
// DefaultAdminListen is the address for the local admin
|
||||||
// listener, if none is specified at startup.
|
// listener, if none is specified at startup.
|
||||||
|
|
Loading…
Reference in a new issue