From 47b78714b8afe220642d57b83afd3ec63007f34a Mon Sep 17 00:00:00 2001 From: comp500 Date: Wed, 6 Mar 2019 21:35:07 +0000 Subject: [PATCH] proxy: Change headers using regex (#2144) * Add upstream header replacements (TODO: tests, docs) * Add tests, fix a few bugs * Add more tests and comments * Refactor header_upstream to use a fallthrough; return regex errors --- caddyhttp/proxy/proxy.go | 29 ++++++-- caddyhttp/proxy/proxy_test.go | 48 +++++++++---- caddyhttp/proxy/upstream.go | 111 +++++++++++++++++++++---------- caddyhttp/proxy/upstream_test.go | 55 +++++++++++++++ 4 files changed, 190 insertions(+), 53 deletions(-) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index fdd642dc8..fe10ea614 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -92,8 +92,10 @@ type UpstreamHost struct { // This is an int32 so that we can use atomic operations to do concurrent // reads & writes to this value. The default value of 0 indicates that it // is healthy and any non-zero value indicates unhealthy. - Unhealthy int32 - HealthCheckResult atomic.Value + Unhealthy int32 + HealthCheckResult atomic.Value + UpstreamHeaderReplacements headerReplacements + DownstreamHeaderReplacements headerReplacements } // Down checks whether the upstream host is down or not. @@ -220,7 +222,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // set headers for request going upstream if host.UpstreamHeaders != nil { // modify headers for request that will be sent to the upstream host - mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer) + mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer, host.UpstreamHeaderReplacements) if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 { outreq.Host = hostHeaders[len(hostHeaders)-1] } @@ -230,7 +232,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // headers coming back downstream var downHeaderUpdateFn respUpdateFn if host.DownstreamHeaders != nil { - downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) + downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer, host.DownstreamHeaderReplacements) } // Before we retry the request we have to make sure @@ -376,13 +378,13 @@ func createUpstreamRequest(rw http.ResponseWriter, r *http.Request) (*http.Reque return outreq, cancel } -func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { +func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer, replacements headerReplacements) respUpdateFn { return func(resp *http.Response) { - mutateHeadersByRules(resp.Header, rules, replacer) + mutateHeadersByRules(resp.Header, rules, replacer, replacements) } } -func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) { +func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer, replacements headerReplacements) { for ruleField, ruleValues := range rules { if strings.HasPrefix(ruleField, "+") { for _, ruleValue := range ruleValues { @@ -400,6 +402,19 @@ func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) } } } + + for ruleField, ruleValues := range replacements { + for _, ruleValue := range ruleValues { + // Replace variables in replacement string + replacement := repl.Replace(ruleValue.to) + original := headers.Get(ruleField) + if len(replacement) > 0 && len(original) > 0 { + // Replace matches in original string with replacement string + replaced := ruleValue.regexp.ReplaceAllString(original, replacement) + headers.Set(ruleField, replaced) + } + } + } } const CustomStatusContextCancelled = 499 diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 213647af2..07ee4df6e 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -31,6 +31,7 @@ import ( "path" "path/filepath" "reflect" + "regexp" "runtime" "strings" "sync" @@ -724,6 +725,14 @@ func TestUpstreamHeadersUpdate(t *testing.T) { "Clear-Me": {""}, "Host": {"{>Host}"}, } + regex1, _ := regexp.Compile("was originally") + regex2, _ := regexp.Compile("this") + regex3, _ := regexp.Compile("bad") + upstream.host.UpstreamHeaderReplacements = headerReplacements{ + "Regex-Me": {headerReplacement{regex1, "am now"}, headerReplacement{regex2, "that"}}, + "Regexreplace-Me": {headerReplacement{regex3, "{hostname}"}}, + } + // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails @@ -740,18 +749,22 @@ func TestUpstreamHeadersUpdate(t *testing.T) { r.Header.Add("Remove-Me", "Remove-Value") r.Header.Add("Replace-Me", "Replace-Value") r.Header.Add("Host", expectHost) + r.Header.Add("Regex-Me", "I was originally this") + r.Header.Add("Regexreplace-Me", "The host is bad") p.ServeHTTP(w, r) replacer := httpserver.NewReplacer(r, nil, "") for headerKey, expect := range map[string][]string{ - "Merge-Me": {"Initial", "Merge-Value"}, - "Add-Me": {"Add-Value"}, - "Add-Empty": nil, - "Remove-Me": nil, - "Replace-Me": {replacer.Replace("{hostname}")}, - "Clear-Me": nil, + "Merge-Me": {"Initial", "Merge-Value"}, + "Add-Me": {"Add-Value"}, + "Add-Empty": nil, + "Remove-Me": nil, + "Replace-Me": {replacer.Replace("{hostname}")}, + "Clear-Me": nil, + "Regex-Me": {"I am now that"}, + "Regexreplace-Me": {"The host is " + replacer.Replace("{hostname}")}, } { if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) { t.Errorf("Upstream request does not contain expected %v header: expect %v, but got %v", @@ -775,6 +788,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) { w.Header().Add("Replace-Me", "Replace-Value") w.Header().Add("Content-Type", "text/html") w.Header().Add("Overwrite-Me", "Overwrite-Value") + w.Header().Add("Regex-Me", "I was originally this") + w.Header().Add("Regexreplace-Me", "The host is bad") w.Write([]byte("Hello, client")) })) defer backend.Close() @@ -786,6 +801,13 @@ func TestDownstreamHeadersUpdate(t *testing.T) { "-Remove-Me": {""}, "Replace-Me": {"{hostname}"}, } + regex1, _ := regexp.Compile("was originally") + regex2, _ := regexp.Compile("this") + regex3, _ := regexp.Compile("bad") + upstream.host.DownstreamHeaderReplacements = headerReplacements{ + "Regex-Me": {headerReplacement{regex1, "am now"}, headerReplacement{regex2, "that"}}, + "Regexreplace-Me": {headerReplacement{regex3, "{hostname}"}}, + } // set up proxy p := &Proxy{ Next: httpserver.EmptyNext, // prevents panic in some cases when test fails @@ -806,12 +828,14 @@ func TestDownstreamHeadersUpdate(t *testing.T) { actualHeaders := w.Header() for headerKey, expect := range map[string][]string{ - "Merge-Me": {"Initial", "Merge-Value"}, - "Add-Me": {"Add-Value"}, - "Remove-Me": nil, - "Replace-Me": {replacer.Replace("{hostname}")}, - "Content-Type": {"text/css"}, - "Overwrite-Me": {"Overwrite-Value"}, + "Merge-Me": {"Initial", "Merge-Value"}, + "Add-Me": {"Add-Value"}, + "Remove-Me": nil, + "Replace-Me": {replacer.Replace("{hostname}")}, + "Content-Type": {"text/css"}, + "Overwrite-Me": {"Overwrite-Value"}, + "Regex-Me": {"I am now that"}, + "Regexreplace-Me": {"The host is " + replacer.Replace("{hostname}")}, } { if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) { t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 8dac144ef..e318c5e88 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -23,8 +23,10 @@ import ( "io/ioutil" "net" "net/http" + "net/textproto" "net/url" "path" + "regexp" "strconv" "strings" "sync" @@ -65,18 +67,39 @@ type staticUpstream struct { Port string ContentString string } - WithoutPathPrefix string - IgnoredSubPaths []string - insecureSkipVerify bool - MaxFails int32 - resolver srvResolver - CaCertPool *x509.CertPool + WithoutPathPrefix string + IgnoredSubPaths []string + insecureSkipVerify bool + MaxFails int32 + resolver srvResolver + CaCertPool *x509.CertPool + upstreamHeaderReplacements headerReplacements + downstreamHeaderReplacements headerReplacements } type srvResolver interface { LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) } +// headerReplacement stores a compiled regex matcher and a string replacer, for replacement rules +type headerReplacement struct { + regexp *regexp.Regexp + to string +} + +// headerReplacements stores a mapping of canonical MIME header to headerReplacement +// Implements a subset of http.Header functions, to allow convenient addition and deletion of rules +type headerReplacements map[string][]headerReplacement + +func (h headerReplacements) Add(key string, value headerReplacement) { + key = textproto.CanonicalMIMEHeaderKey(key) + h[key] = append(h[key], value) +} + +func (h headerReplacements) Del(key string) { + delete(h, textproto.CanonicalMIMEHeaderKey(key)) +} + // NewStaticUpstreams parses the configuration input and sets up // static upstreams for the proxy middleware. The host string parameter, // if not empty, is used for setting the upstream Host header for the @@ -86,18 +109,20 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) for c.Next() { upstream := &staticUpstream{ - from: "", - stop: make(chan struct{}), - upstreamHeaders: make(http.Header), - downstreamHeaders: make(http.Header), - Hosts: nil, - Policy: &Random{}, - MaxFails: 1, - TryInterval: 250 * time.Millisecond, - MaxConns: 0, - KeepAlive: http.DefaultMaxIdleConnsPerHost, - Timeout: 30 * time.Second, - resolver: net.DefaultResolver, + from: "", + stop: make(chan struct{}), + upstreamHeaders: make(http.Header), + downstreamHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + MaxFails: 1, + TryInterval: 250 * time.Millisecond, + MaxConns: 0, + KeepAlive: http.DefaultMaxIdleConnsPerHost, + Timeout: 30 * time.Second, + resolver: net.DefaultResolver, + upstreamHeaderReplacements: make(headerReplacements), + downstreamHeaderReplacements: make(headerReplacements), } if !c.Args(&upstream.from) { @@ -220,9 +245,11 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { return false } }(u), - WithoutPathPrefix: u.WithoutPathPrefix, - MaxConns: u.MaxConns, - HealthCheckResult: atomic.Value{}, + WithoutPathPrefix: u.WithoutPathPrefix, + MaxConns: u.MaxConns, + HealthCheckResult: atomic.Value{}, + UpstreamHeaderReplacements: u.upstreamHeaderReplacements, + DownstreamHeaderReplacements: u.downstreamHeaderReplacements, } baseURL, err := url.Parse(uh.Name) @@ -302,6 +329,8 @@ func parseUpstream(u string) ([]string, error) { } func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { + var isUpstream bool + switch c.Val() { case "policy": if !c.NextArg() { @@ -431,23 +460,37 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { } u.HealthCheck.ContentString = c.Val() case "header_upstream": - var header, value string - if !c.Args(&header, &value) { - // When removing a header, the value can be optional. - if !strings.HasPrefix(header, "-") { - return c.ArgErr() - } - } - u.upstreamHeaders.Add(header, value) + isUpstream = true + fallthrough case "header_downstream": - var header, value string - if !c.Args(&header, &value) { - // When removing a header, the value can be optional. - if !strings.HasPrefix(header, "-") { + var header, value, replaced string + if c.Args(&header, &value, &replaced) { + // Don't allow - or + in replacements + if strings.HasPrefix(header, "-") || strings.HasPrefix(header, "+") { return c.ArgErr() } + r, err := regexp.Compile(value) + if err != nil { + return err + } + if isUpstream { + u.upstreamHeaderReplacements.Add(header, headerReplacement{r, replaced}) + } else { + u.downstreamHeaderReplacements.Add(header, headerReplacement{r, replaced}) + } + } else { + if len(value) == 0 { + // When removing a header, the value can be optional. + if !strings.HasPrefix(header, "-") { + return c.ArgErr() + } + } + if isUpstream { + u.upstreamHeaders.Add(header, value) + } else { + u.downstreamHeaders.Add(header, value) + } } - u.downstreamHeaders.Add(header, value) case "transparent": // Note: X-Forwarded-For header is always being appended for proxy connections // See implementation of createUpstreamRequest in proxy.go diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index 2e10eb092..d9d566de1 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -386,6 +386,61 @@ func TestParseBlockTransparent(t *testing.T) { } } +func TestParseBlockRegex(t *testing.T) { + // tests for regex replacement of headers + r, _ := http.NewRequest("GET", "/", nil) + tests := []struct { + config string + }{ + // Test #1: transparent preset with replacement of Host + {"proxy / localhost:8080 {\n transparent \nheader_upstream Host (.*) NewHost \n}"}, + + // Test #2: transparent preset with replacement of another param + {"proxy / localhost:8080 {\n transparent \nheader_upstream X-Test Tester \nheader_upstream X-Test Test Host \n}"}, + + // Test #3: transparent preset with multiple params + {"proxy / localhost:8080 {\n transparent \nheader_upstream X-Test Tester \nheader_upstream X-Test Test Host \nheader_upstream X-Test er ing \n}"}, + } + + for i, test := range tests { + upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err != nil { + t.Errorf("Expected no error. Got: %s", err.Error()) + } + for _, upstream := range upstreams { + headers := upstream.Select(r).UpstreamHeaderReplacements + + switch i { + case 0: + if host, ok := headers["Host"]; !ok || host[0].to != "NewHost" { + t.Errorf("Test %d: Incorrect Host replacement: %v", i+1, host[0]) + } + case 1: + if v, ok := headers["X-Test"]; !ok { + t.Errorf("Test %d: Incorrect X-Test replacement", i+1) + } else { + if v[0].to != "Host" { + t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[0]) + } + } + case 2: + if v, ok := headers["X-Test"]; !ok { + t.Errorf("Test %d: Incorrect X-Test replacement", i+1) + } else { + if v[0].to != "Host" { + t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[0]) + } + if v[1].to != "ing" { + t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[1]) + } + } + default: + t.Error("Testing error") + } + } + } +} + func TestHealthSetUp(t *testing.T) { // tests for insecure skip verify tests := []struct {