diff --git a/config/setup/rewrite.go b/config/setup/rewrite.go index b86be3036..b510a237b 100644 --- a/config/setup/rewrite.go +++ b/config/setup/rewrite.go @@ -18,23 +18,62 @@ func Rewrite(c *Controller) (middleware.Middleware, error) { } func rewriteParse(c *Controller) ([]rewrite.Rule, error) { - var rewrites []rewrite.Rule + var simpleRules []rewrite.Rule + var regexpRules []rewrite.Rule for c.Next() { var rule rewrite.Rule + var err error + var base = "/" + var pattern, to string + var ext []string - if !c.NextArg() { - return rewrites, c.ArgErr() + args := c.RemainingArgs() + + switch len(args) { + case 2: + rule = rewrite.NewSimpleRule(args[0], args[1]) + simpleRules = append(simpleRules, rule) + case 1: + base = args[0] + fallthrough + case 0: + for c.NextBlock() { + switch c.Val() { + case "r", "regexp": + if !c.NextArg() { + return nil, c.ArgErr() + } + pattern = c.Val() + case "to": + if !c.NextArg() { + return nil, c.ArgErr() + } + to = c.Val() + case "ext": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return nil, c.ArgErr() + } + ext = args1 + default: + return nil, c.ArgErr() + } + } + // ensure pattern and to are specified + if pattern == "" || to == "" { + return nil, c.ArgErr() + } + if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil { + return nil, err + } + regexpRules = append(regexpRules, rule) + default: + return nil, c.ArgErr() } - rule.From = c.Val() - if !c.NextArg() { - return rewrites, c.ArgErr() - } - rule.To = c.Val() - - rewrites = append(rewrites, rule) } - return rewrites, nil + // put simple rules in front to avoid regexp computation for them + return append(simpleRules, regexpRules...), nil } diff --git a/config/setup/rewrite_test.go b/config/setup/rewrite_test.go index 17a0e97b5..9ff294ef0 100644 --- a/config/setup/rewrite_test.go +++ b/config/setup/rewrite_test.go @@ -3,7 +3,9 @@ package setup import ( "testing" + "fmt" "github.com/mholt/caddy/middleware/rewrite" + "regexp" ) func TestRewrite(t *testing.T) { @@ -33,27 +35,27 @@ func TestRewrite(t *testing.T) { } func TestRewriteParse(t *testing.T) { - tests := []struct { + simpleTests := []struct { input string shouldErr bool expected []rewrite.Rule }{ {`rewrite /from /to`, false, []rewrite.Rule{ - {From: "/from", To: "/to"}, + rewrite.SimpleRule{"/from", "/to"}, }}, {`rewrite /from /to rewrite a b`, false, []rewrite.Rule{ - {From: "/from", To: "/to"}, - {From: "a", To: "b"}, + rewrite.SimpleRule{"/from", "/to"}, + rewrite.SimpleRule{"a", "b"}, }}, {`rewrite a`, true, []rewrite.Rule{}}, {`rewrite`, true, []rewrite.Rule{}}, {`rewrite a b c`, true, []rewrite.Rule{ - {From: "a", To: "b"}, + rewrite.SimpleRule{"a", "b"}, }}, } - for i, test := range tests { + for i, test := range simpleTests { c := newTestController(test.input) actual, err := rewriteParse(c) @@ -61,6 +63,8 @@ func TestRewriteParse(t *testing.T) { t.Errorf("Test %d didn't error, but it should have", i) } else if err != nil && !test.shouldErr { t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue } if len(actual) != len(test.expected) { @@ -68,8 +72,9 @@ func TestRewriteParse(t *testing.T) { i, len(test.expected), len(actual)) } - for j, expectedRule := range test.expected { - actualRule := actual[j] + for j, e := range test.expected { + actualRule := actual[j].(rewrite.SimpleRule) + expectedRule := e.(rewrite.SimpleRule) if actualRule.From != expectedRule.From { t.Errorf("Test %d, rule %d: Expected From=%s, got %s", @@ -82,4 +87,98 @@ func TestRewriteParse(t *testing.T) { } } } + + regexpTests := []struct { + input string + shouldErr bool + expected []rewrite.Rule + }{ + {`rewrite { + r .* + to /to + }`, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile(".*")}, + }}, + {`rewrite { + regexp .* + to /to + ext / html txt + }`, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/", "/to", []string{"/", "html", "txt"}, regexp.MustCompile(".*")}, + }}, + {`rewrite /path { + r rr + to /dest + } + rewrite / { + regexp [a-z]+ + to /to + } + `, false, []rewrite.Rule{ + &rewrite.RegexpRule{"/path", "/dest", nil, regexp.MustCompile("rr")}, + &rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile("[a-z]+")}, + }}, + {`rewrite { + to /to + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite { + r .* + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite { + + }`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + {`rewrite /`, true, []rewrite.Rule{ + &rewrite.RegexpRule{}, + }}, + } + + for i, test := range regexpTests { + c := newTestController(test.input) + actual, err := rewriteParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(*rewrite.RegexpRule) + expectedRule := e.(*rewrite.RegexpRule) + + if actualRule.Base != expectedRule.Base { + t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", + i, j, expectedRule.Base, actualRule.Base) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + + if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { + t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", + i, j, expectedRule.To, actualRule.To) + } + + if actualRule.String() != expectedRule.String() { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, expectedRule.String(), actualRule.String()) + } + } + } + } diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 870438dc2..aefb23295 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -5,7 +5,13 @@ package rewrite import ( "net/http" + "fmt" "github.com/mholt/caddy/middleware" + "net/url" + "path" + "path/filepath" + "regexp" + "strings" ) // Rewrite is middleware to rewrite request locations internally before being handled. @@ -17,15 +23,171 @@ type Rewrite struct { // ServeHTTP implements the middleware.Handler interface. func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range rw.Rules { - if r.URL.Path == rule.From { - r.URL.Path = rule.To + if ok := rule.Rewrite(r); ok { break } } return rw.Next.ServeHTTP(w, r) } -// A Rule describes an internal location rewrite rule. -type Rule struct { +// Rule describes an internal location rewrite rule. +type Rule interface { + // Rewrite rewrites the internal location of the current request. + Rewrite(*http.Request) bool +} + +// SimpleRule is a simple rewrite rule. +type SimpleRule struct { From, To string } + +// NewSimpleRule creates a new Simple Rule +func NewSimpleRule(from, to string) SimpleRule { + return SimpleRule{from, to} +} + +// Rewrite rewrites the internal location of the current request. +func (s SimpleRule) Rewrite(r *http.Request) bool { + if s.From == r.URL.Path { + r.URL.Path = s.To + return true + } + return false +} + +// RegexpRule is a rewrite rule based on a regular expression +type RegexpRule struct { + // Path base. Request to this path and subpaths will be rewritten + Base string + + // Path to rewrite to + To string + + // Extensions to filter by + Exts []string + + *regexp.Regexp +} + +// NewRegexpRule creates a new RegexpRule. It returns an error if regexp +// pattern (pattern) or extensions (ext) are invalid. +func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { + r, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + + // validate extensions + for _, v := range ext { + if len(v) < 2 || (len(v) < 3 && v[0] == '!') { + // check if no extension is specified + if v != "/" && v != "!/" { + return nil, fmt.Errorf("Invalid extension %v", v) + } + } + } + + return &RegexpRule{ + base, + to, + ext, + r, + }, nil +} + +// regexpVars are variables that can be used for To (rewrite destination path). +var regexpVars []string = []string{ + "{path}", + "{query}", + "{file}", + "{dir}", + "{frag}", +} + +// Rewrite rewrites the internal location of the current request. +func (r *RegexpRule) Rewrite(req *http.Request) bool { + rPath := req.URL.Path + + // validate base + if !middleware.Path(rPath).Matches(r.Base) { + return false + } + + // validate extensions + if !r.matchExt(rPath) { + return false + } + + // validate regexp + if !r.MatchString(rPath[len(r.Base):]) { + return false + } + + to := r.To + + // check variables + for _, v := range regexpVars { + if strings.Contains(r.To, v) { + switch v { + case "{path}": + to = strings.Replace(to, v, req.URL.Path[1:], -1) + case "{query}": + to = strings.Replace(to, v, req.URL.RawQuery, -1) + case "{frag}": + to = strings.Replace(to, v, req.URL.Fragment, -1) + case "{file}": + _, file := path.Split(req.URL.Path) + to = strings.Replace(to, v, file, -1) + case "{dir}": + dir, _ := path.Split(req.URL.Path) + to = path.Clean(strings.Replace(to, v, dir, -1)) + } + } + } + + // validate resulting path + url, err := url.Parse(to) + if err != nil { + return false + } + + // perform rewrite + req.URL.Path = url.Path + if url.RawQuery != "" { + // overwrite query string if present + req.URL.RawQuery = url.RawQuery + } + return true +} + +// matchExt matches rPath against registered file extensions. +// Returns true if a match is found and false otherwise. +func (r *RegexpRule) matchExt(rPath string) bool { + f := filepath.Base(rPath) + ext := path.Ext(f) + if ext == "" { + ext = "/" + } + + mustUse := false + for _, v := range r.Exts { + use := true + if v[0] == '!' { + use = false + v = v[1:] + } + + if use { + mustUse = true + } + + if ext == v { + return use + } + } + + if mustUse { + return false + } + return true +} diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index e9793ac13..684e6b213 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -7,16 +7,41 @@ import ( "testing" "github.com/mholt/caddy/middleware" + "strings" ) func TestRewrite(t *testing.T) { rw := Rewrite{ Next: middleware.HandlerFunc(urlPrinter), Rules: []Rule{ - {From: "/from", To: "/to"}, - {From: "/a", To: "/b"}, + NewSimpleRule("/from", "/to"), + NewSimpleRule("/a", "/b"), }, } + + regexpRules := [][]string{ + []string{"/reg/", ".*", "/to", ""}, + []string{"/r/", "[a-z]+", "/toaz", "!.html|"}, + []string{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, + []string{"/ab/", "ab", "/ab?{query}", ".txt|"}, + []string{"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, + []string{"/abc/", "ab", "/abc/{file}", ".html|"}, + []string{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"}, + []string{"/abcde/", "ab", "/a#{frag}", ".html|"}, + } + + for _, regexpRule := range regexpRules { + var ext []string + if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { + ext = s[:len(s)-1] + } + rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) + if err != nil { + t.Fatal(err) + } + rw.Rules = append(rw.Rules, rule) + } + tests := []struct { from string expectedTo string @@ -29,6 +54,28 @@ func TestRewrite(t *testing.T) { {"/asdf?foo=bar", "/asdf?foo=bar"}, {"/foo#bar", "/foo#bar"}, {"/a#foo", "/b#foo"}, + {"/reg/foo", "/to"}, + {"/re", "/re"}, + {"/r/", "/r/"}, + {"/r/123", "/r/123"}, + {"/r/a123", "/toaz"}, + {"/r/abcz", "/toaz"}, + {"/r/z", "/toaz"}, + {"/r/z.html", "/r/z.html"}, + {"/r/z.js", "/toaz"}, + {"/url/asAB", "/to/url/asAB"}, + {"/url/aBsAB", "/url/aBsAB"}, + {"/url/a00sAB", "/to/url/a00sAB"}, + {"/url/a0z0sAB", "/to/url/a0z0sAB"}, + {"/ab/aa", "/ab/aa"}, + {"/ab/ab", "/ab/ab"}, + {"/ab/ab.txt", "/ab"}, + {"/ab/ab.txt?name=name", "/ab?name=name"}, + {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, + {"/abc/ab.html", "/abc/ab.html"}, + {"/abcd/abcd.html", "/a/abcd/abcd.html"}, + {"/abcde/abcde.html", "/a"}, + {"/abcde/abcde.html#1234", "/a#1234"}, } for i, test := range tests {