diff --git a/config/setup/rewrite.go b/config/setup/rewrite.go index b86be3036..4829f55cf 100644 --- a/config/setup/rewrite.go +++ b/config/setup/rewrite.go @@ -19,22 +19,60 @@ func Rewrite(c *Controller) (middleware.Middleware, error) { func rewriteParse(c *Controller) ([]rewrite.Rule, error) { var rewrites []rewrite.Rule + var regexps []rewrite.Rule for c.Next() { var rule rewrite.Rule + var err error - if !c.NextArg() { + args := c.RemainingArgs() + + switch len(args) { + case 2: + if args[0] != "regexp" { + rule = rewrite.NewSimpleRule(args[0], args[1]) + rewrites = append(rewrites, rule) + continue + } + + var base = args[1] + var pattern, to string + var ext []string + + for c.NextBlock() { + switch c.Val() { + case "pattern": + if !c.NextArg() { + return rewrites, c.ArgErr() + } + pattern = c.Val() + case "to": + if !c.NextArg() { + return rewrites, c.ArgErr() + } + to = c.Val() + case "ext": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return rewrites, c.ArgErr() + } + ext = args1 + default: + return rewrites, c.ArgErr() + } + } + if pattern == "" || to == "" { + return rewrites, c.ArgErr() + } + if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil { + return rewrites, err + } + rewrites = append(regexps, rule) + default: return rewrites, c.ArgErr() } - rule.From = c.Val() - if !c.NextArg() { - return rewrites, c.ArgErr() - } - rule.To = c.Val() - - rewrites = append(rewrites, rule) } - return rewrites, nil + return append(rewrites, regexps...), nil } diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 870438dc2..0567290c6 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -5,7 +5,13 @@ package rewrite import ( "net/http" + "fmt" "github.com/mholt/caddy/middleware" + "net/url" + "path" + "path/filepath" + "regexp" + "strings" ) // Rewrite is middleware to rewrite request locations internally before being handled. @@ -17,8 +23,7 @@ type Rewrite struct { // ServeHTTP implements the middleware.Handler interface. func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range rw.Rules { - if r.URL.Path == rule.From { - r.URL.Path = rule.To + if ok := rule.Rewrite(r); ok { break } } @@ -26,6 +31,121 @@ func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) } // A Rule describes an internal location rewrite rule. -type Rule struct { - From, To string +type Rule interface { + Rewrite(*http.Request) bool +} + +type SimpleRule [2]string + +func NewSimpleRule(from, to string) SimpleRule { + return SimpleRule{from, to} +} + +func (s SimpleRule) Rewrite(r *http.Request) bool { + if s[0] == r.URL.Path { + r.URL.Path = s[1] + return true + } + return false +} + +type RegexpRule struct { + base, to string + ext []string + *regexp.Regexp +} + +func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { + r, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + + // validate extensions + for _, v := range ext { + if len(v) < 2 || (len(v) < 3 && v[0] == '!') { + // check if it is no extension + if v != "/" && v != "!/" { + return nil, fmt.Errorf("Invalid extension %v", v) + } + } + } + + return &RegexpRule{ + base, + to, + ext, + r, + }, nil +} + +var regexpVars [2]string = [2]string{ + "$path", + "$query", +} + +func (r *RegexpRule) Rewrite(req *http.Request) bool { + rPath := req.URL.Path + if strings.Index(rPath, r.base) != 0 { + return false + } + if !r.matchExt(rPath) { + return false + } + if !r.MatchString(req.URL.Path) { + return false + } + + to := r.to + + // check variables + for _, v := range regexpVars { + if strings.Contains(r.to, v) { + switch v { + case regexpVars[0]: + to = strings.Replace(to, v, req.URL.Path[1:], -1) + case regexpVars[1]: + to = strings.Replace(to, v, req.URL.RawQuery, -1) + } + } + } + + url, err := url.Parse(to) + if err != nil { + return false + } + + req.URL.Path = url.Path + req.URL.RawQuery = url.RawQuery + + return true +} + +func (r *RegexpRule) matchExt(rPath string) bool { + f := filepath.Base(rPath) + ext := path.Ext(f) + if ext == "" { + ext = "/" + } + mustUse := false + for _, v := range r.ext { + use := true + if v[0] == '!' { + use = false + v = v[1:] + } + + if use { + mustUse = true + } + + if ext == v { + return use + } + } + if mustUse { + return false + } + + return true }