diff --git a/caddy/parse/import_glob0.txt b/caddy/parse/import_glob0.txt new file mode 100644 index 000000000..e610b5e7c --- /dev/null +++ b/caddy/parse/import_glob0.txt @@ -0,0 +1,6 @@ +glob0.host0 { + dir2 arg1 +} + +glob0.host1 { +} diff --git a/caddy/parse/import_glob1.txt b/caddy/parse/import_glob1.txt new file mode 100644 index 000000000..111eb044d --- /dev/null +++ b/caddy/parse/import_glob1.txt @@ -0,0 +1,4 @@ +glob1.host0 { + dir1 + dir2 arg1 +} diff --git a/caddy/parse/import_glob2.txt b/caddy/parse/import_glob2.txt new file mode 100644 index 000000000..c09f784ec --- /dev/null +++ b/caddy/parse/import_glob2.txt @@ -0,0 +1,3 @@ +glob2.host0 { + dir2 arg1 +} diff --git a/caddy/parse/parsing.go b/caddy/parse/parsing.go index 03d9d800a..6ef908b0b 100644 --- a/caddy/parse/parsing.go +++ b/caddy/parse/parsing.go @@ -176,19 +176,52 @@ func (p *parser) directives() error { } // doImport swaps out the import directive and its argument -// (a total of 2 tokens) with the tokens in the file specified. -// When the function returns, the cursor is on the token before -// where the import directive was. In other words, call Next() -// to access the first token that was imported. +// (a total of 2 tokens) with the tokens in the specified file +// or globbing pattern. When the function returns, the cursor +// is on the token before where the import directive was. In +// other words, call Next() to access the first token that was +// imported. func (p *parser) doImport() error { if !p.NextArg() { return p.ArgErr() } - importFile := p.Val() + importPattern := p.Val() if p.NextArg() { - return p.Err("Import allows only one file to import") + return p.Err("Import allows only one expression, either file or glob pattern") } + matches, err := filepath.Glob(importPattern) + if err != nil { + return p.Errf("Failed to use import pattern %s - %s", importPattern, err.Error()) + } + + if len(matches) == 0 { + return p.Errf("No files matching the import pattern %s", importPattern) + } + + // Splice out the import directive and its argument (2 tokens total) + // and insert the imported tokens in their place. + tokensBefore := p.tokens[:p.cursor-1] + tokensAfter := p.tokens[p.cursor+1:] + // cursor was advanced one position to read filename; rewind it + p.cursor-- + + p.tokens = tokensBefore + + for _, importFile := range matches { + if err := p.doSingleImport(importFile); err != nil { + return err + } + } + + p.tokens = append(p.tokens, append(tokensAfter)...) + + return nil +} + +// doSingleImport lexes the individual files matching the +// globbing pattern from of the import directive. +func (p *parser) doSingleImport(importFile string) error { file, err := os.Open(importFile) if err != nil { return p.Errf("Could not import %s - %v", importFile, err) @@ -203,10 +236,7 @@ func (p *parser) doImport() error { // Splice out the import directive and its argument (2 tokens total) // and insert the imported tokens in their place. - tokensBefore := p.tokens[:p.cursor-1] - tokensAfter := p.tokens[p.cursor+1:] - p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...) - p.cursor-- // cursor was advanced one position to read the filename; rewind it + p.tokens = append(p.tokens, append(importedTokens)...) return nil } diff --git a/caddy/parse/parsing_test.go b/caddy/parse/parsing_test.go index 97c86808a..bda6b29bc 100644 --- a/caddy/parse/parsing_test.go +++ b/caddy/parse/parsing_test.go @@ -329,6 +329,13 @@ func TestParseAll(t *testing.T) { []address{{"host1.com", "http"}, {"host2.com", "http"}}, []address{{"host3.com", "https"}, {"host4.com", "https"}}, }}, + + {`import import_glob*.txt`, false, [][]address{ + []address{{"glob0.host0", ""}}, + []address{{"glob0.host1", ""}}, + []address{{"glob1.host0", ""}}, + []address{{"glob2.host0", ""}}, + }}, } { p := testParser(test.input) blocks, err := p.parseAll() diff --git a/caddy/setup/rewrite.go b/caddy/setup/rewrite.go index b510a237b..4c84cb5fd 100644 --- a/caddy/setup/rewrite.go +++ b/caddy/setup/rewrite.go @@ -1,6 +1,9 @@ package setup import ( + "net/http" + "strings" + "github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware/rewrite" ) @@ -13,7 +16,11 @@ func Rewrite(c *Controller) (middleware.Middleware, error) { } return func(next middleware.Handler) middleware.Handler { - return rewrite.Rewrite{Next: next, Rules: rewrites} + return rewrite.Rewrite{ + Next: next, + FileSys: http.Dir(c.Root), + Rules: rewrites, + } }, nil } @@ -30,6 +37,8 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { args := c.RemainingArgs() + var ifs []rewrite.If + switch len(args) { case 2: rule = rewrite.NewSimpleRule(args[0], args[1]) @@ -46,25 +55,36 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { } pattern = c.Val() case "to": - if !c.NextArg() { + args1 := c.RemainingArgs() + if len(args1) == 0 { return nil, c.ArgErr() } - to = c.Val() + to = strings.Join(args1, " ") case "ext": args1 := c.RemainingArgs() if len(args1) == 0 { return nil, c.ArgErr() } ext = args1 + case "if": + args1 := c.RemainingArgs() + if len(args1) != 3 { + return nil, c.ArgErr() + } + ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2]) + if err != nil { + return nil, err + } + ifs = append(ifs, ifCond) default: return nil, c.ArgErr() } } - // ensure pattern and to are specified - if pattern == "" || to == "" { + // ensure to is specified + if to == "" { return nil, c.ArgErr() } - if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil { + if rule, err = rewrite.NewComplexRule(base, pattern, to, ext, ifs); err != nil { return nil, err } regexpRules = append(regexpRules, rule) diff --git a/caddy/setup/rewrite_test.go b/caddy/setup/rewrite_test.go index f54266788..c43818b2d 100644 --- a/caddy/setup/rewrite_test.go +++ b/caddy/setup/rewrite_test.go @@ -1,10 +1,9 @@ package setup import ( - "testing" - "fmt" "regexp" + "testing" "github.com/mholt/caddy/middleware/rewrite" ) @@ -96,16 +95,16 @@ func TestRewriteParse(t *testing.T) { }{ {`rewrite { r .* - to /to + to /to /index.php? }`, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*")}, + &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")}, }}, {`rewrite { regexp .* to /to ext / html txt }`, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, + &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, }}, {`rewrite /path { r rr @@ -113,29 +112,30 @@ func TestRewriteParse(t *testing.T) { } rewrite / { regexp [a-z]+ - to /to + to /to /to2 } `, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, - &rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile("[a-z]+")}, - }}, - {`rewrite { - to /to - }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, + &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")}, }}, {`rewrite { r .* }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, {`rewrite { }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, {`rewrite /`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, + }}, + {`rewrite { + to /to + if {path} is a + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{rewrite.If{A: "{path}", Operator: "is", B: "a"}}}, }}, } @@ -157,8 +157,8 @@ func TestRewriteParse(t *testing.T) { } for j, e := range test.expected { - actualRule := actual[j].(*rewrite.RegexpRule) - expectedRule := e.(*rewrite.RegexpRule) + actualRule := actual[j].(*rewrite.ComplexRule) + expectedRule := e.(*rewrite.ComplexRule) if actualRule.Base != expectedRule.Base { t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", @@ -175,10 +175,18 @@ func TestRewriteParse(t *testing.T) { i, j, expectedRule.To, actualRule.To) } - if actualRule.String() != expectedRule.String() { - t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", - i, j, expectedRule.String(), actualRule.String()) + if actualRule.Regexp != nil { + if actualRule.String() != expectedRule.String() { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, expectedRule.String(), actualRule.String()) + } } + + if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs)) + } + } } diff --git a/middleware/markdown/generator.go b/middleware/markdown/generator.go index d3485e918..d218f22b4 100644 --- a/middleware/markdown/generator.go +++ b/middleware/markdown/generator.go @@ -5,6 +5,8 @@ import ( "encoding/hex" "fmt" "io/ioutil" + "net/http" + "net/url" "os" "path/filepath" "strings" @@ -103,8 +105,12 @@ func generateStaticHTML(md Markdown, cfg *Config) error { reqPath = filepath.ToSlash(reqPath) reqPath = "/" + reqPath + // Create empty requests and url to cater for template values. + req, _ := http.NewRequest("", "/", nil) + urlVar, _ := url.Parse("/") + // Generate the static file - ctx := middleware.Context{Root: md.FileSys} + ctx := middleware.Context{Root: md.FileSys, Req: req, URL: urlVar} _, err = md.Process(cfg, reqPath, body, ctx) if err != nil { return err diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go new file mode 100644 index 000000000..c863af4f0 --- /dev/null +++ b/middleware/rewrite/condition.go @@ -0,0 +1,111 @@ +package rewrite + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/mholt/caddy/middleware" +) + +const ( + // Operators + Is = "is" + Not = "not" + Has = "has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" +) + +func operatorError(operator string) error { + return fmt.Errorf("Invalid operator %v", operator) +} + +func newReplacer(r *http.Request) middleware.Replacer { + return middleware.NewReplacer(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of a against pattern in b +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *http.Request) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go new file mode 100644 index 000000000..c056f8964 --- /dev/null +++ b/middleware/rewrite/condition_test.go @@ -0,0 +1,90 @@ +package rewrite + +import ( + "net/http" + "strings" + "testing" +) + +func TestConditions(t *testing.T) { + tests := []struct { + condition string + isTrue bool + }{ + {"a is b", false}, + {"a is a", true}, + {"a not b", true}, + {"a not a", false}, + {"a has a", true}, + {"a has b", false}, + {"ba has b", true}, + {"bab has b", true}, + {"bab has bb", false}, + {"bab starts_with bb", false}, + {"bab starts_with ba", true}, + {"bab starts_with bab", true}, + {"bab ends_with bb", false}, + {"bab ends_with bab", true}, + {"bab ends_with ab", true}, + {"a match *", false}, + {"a match a", true}, + {"a match .*", true}, + {"a match a.*", true}, + {"a match b.*", false}, + {"ba match b.*", true}, + {"ba match b[a-z]", true}, + {"b0 match b[a-z]", false}, + {"b0a match b[a-z]", false}, + {"b0a match b[a-z]+", false}, + {"b0a match b[a-z0-9]+", true}, + } + + for i, test := range tests { + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(nil) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } + + invalidOperators := []string{"ss", "and", "if"} + for _, op := range invalidOperators { + _, err := NewIf("a", op, "b") + if err == nil { + t.Errorf("Invalid operator %v used, expected error.", op) + } + } + + replaceTests := []struct { + url string + condition string + isTrue bool + }{ + {"/home", "{uri} match /home", true}, + {"/hom", "{uri} match /home", false}, + {"/hom", "{uri} starts_with /home", false}, + {"/hom", "{uri} starts_with /h", true}, + {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, + {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, + } + + for i, test := range replaceTests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(r) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } +} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 88944f73d..f173e7547 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -5,7 +5,6 @@ package rewrite import ( "fmt" "net/http" - "net/url" "path" "path/filepath" "regexp" @@ -16,14 +15,15 @@ import ( // Rewrite is middleware to rewrite request locations internally before being handled. type Rewrite struct { - Next middleware.Handler - Rules []Rule + Next middleware.Handler + FileSys http.FileSystem + Rules []Rule } // ServeHTTP implements the middleware.Handler interface. func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range rw.Rules { - if ok := rule.Rewrite(r); ok { + if ok := rule.Rewrite(rw.FileSys, r); ok { break } } @@ -33,7 +33,7 @@ func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) // Rule describes an internal location rewrite rule. type Rule interface { // Rewrite rewrites the internal location of the current request. - Rewrite(*http.Request) bool + Rewrite(http.FileSystem, *http.Request) bool } // SimpleRule is a simple rewrite rule. @@ -47,23 +47,20 @@ func NewSimpleRule(from, to string) SimpleRule { } // Rewrite rewrites the internal location of the current request. -func (s SimpleRule) Rewrite(r *http.Request) bool { +func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) bool { if s.From == r.URL.Path { // take note of this rewrite for internal use by fastcgi // all we need is the URI, not full URL r.Header.Set(headerFieldName, r.URL.RequestURI()) - // replace variables - to := path.Clean(middleware.NewReplacer(r, nil, "").Replace(s.To)) - - r.URL.Path = to - return true + // attempt rewrite + return To(fs, r, s.To) } return false } -// RegexpRule is a rewrite rule based on a regular expression -type RegexpRule struct { +// ComplexRule is a rewrite rule based on a regular expression +type ComplexRule struct { // Path base. Request to this path and subpaths will be rewritten Base string @@ -73,18 +70,26 @@ type RegexpRule struct { // Extensions to filter by Exts []string + // Rewrite conditions + Ifs []If + *regexp.Regexp } // NewRegexpRule creates a new RegexpRule. It returns an error if regexp // pattern (pattern) or extensions (ext) are invalid. -func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { - r, err := regexp.Compile(pattern) - if err != nil { - return nil, err +func NewComplexRule(base, pattern, to string, ext []string, ifs []If) (*ComplexRule, error) { + // validate regexp if present + var r *regexp.Regexp + if pattern != "" { + var err error + r, err = regexp.Compile(pattern) + if err != nil { + return nil, err + } } - // validate extensions + // validate extensions if present for _, v := range ext { if len(v) < 2 || (len(v) < 3 && v[0] == '!') { // check if no extension is specified @@ -94,16 +99,17 @@ func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) } } - return &RegexpRule{ - base, - to, - ext, - r, + return &ComplexRule{ + Base: base, + To: to, + Exts: ext, + Ifs: ifs, + Regexp: r, }, nil } // Rewrite rewrites the internal location of the current request. -func (r *RegexpRule) Rewrite(req *http.Request) bool { +func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) bool { rPath := req.URL.Path // validate base @@ -122,36 +128,27 @@ func (r *RegexpRule) Rewrite(req *http.Request) bool { start-- } - // validate regexp - if !r.MatchString(rPath[start:]) { - return false + // validate regexp if present + if r.Regexp != nil { + if !r.MatchString(rPath[start:]) { + return false + } } - // replace variables - to := path.Clean(middleware.NewReplacer(req, nil, "").Replace(r.To)) - - // validate resulting path - url, err := url.Parse(to) - if err != nil { - return false + // validate rewrite conditions + for _, i := range r.Ifs { + if !i.True(req) { + return false + } } - // take note of this rewrite for internal use by fastcgi - // all we need is the URI, not full URL - req.Header.Set(headerFieldName, req.URL.RequestURI()) - - // perform rewrite - req.URL.Path = url.Path - if url.RawQuery != "" { - // overwrite query string if present - req.URL.RawQuery = url.RawQuery - } - return true + // attempt rewrite + return To(fs, req, r.To) } // matchExt matches rPath against registered file extensions. // Returns true if a match is found and false otherwise. -func (r *RegexpRule) matchExt(rPath string) bool { +func (r *ComplexRule) matchExt(rPath string) bool { f := filepath.Base(rPath) ext := path.Ext(f) if ext == "" { diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index fb0470262..a538b79d8 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -4,9 +4,8 @@ import ( "fmt" "net/http" "net/http/httptest" - "testing" - "strings" + "testing" "github.com/mholt/caddy/middleware" ) @@ -19,9 +18,10 @@ func TestRewrite(t *testing.T) { NewSimpleRule("/a", "/b"), NewSimpleRule("/b", "/b{uri}"), }, + FileSys: http.Dir("."), } - regexpRules := [][]string{ + regexps := [][]string{ {"/reg/", ".*", "/to", ""}, {"/r/", "[a-z]+", "/toaz", "!.html|"}, {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, @@ -33,12 +33,12 @@ func TestRewrite(t *testing.T) { {"/ab/", `.*\.jpg`, "/ajpg", ""}, } - for _, regexpRule := range regexpRules { + for _, regexpRule := range regexps { var ext []string if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { ext = s[:len(s)-1] } - rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], ext, nil) if err != nil { t.Fatal(err) } diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty new file mode 100644 index 000000000..e69de29bb diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile new file mode 100644 index 000000000..7b4d68d70 --- /dev/null +++ b/middleware/rewrite/testdata/testfile @@ -0,0 +1 @@ +empty \ No newline at end of file diff --git a/middleware/rewrite/to.go b/middleware/rewrite/to.go new file mode 100644 index 000000000..d6c7f5210 --- /dev/null +++ b/middleware/rewrite/to.go @@ -0,0 +1,86 @@ +package rewrite + +import ( + "log" + "net/http" + "net/url" + "path" + "strings" +) + +// To attempts rewrite. It attempts to rewrite to first valid path +// or the last path if none of the paths are valid. +// Returns true if rewrite is successful and false otherwise. +func To(fs http.FileSystem, r *http.Request, to string) bool { + tos := strings.Fields(to) + replacer := newReplacer(r) + + // try each rewrite paths + t := "" + for _, v := range tos { + t = path.Clean(replacer.Replace(v)) + + // add trailing slash for directories, if present + if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") { + t += "/" + } + + // validate file + if isValidFile(fs, t) { + break + } + } + + // validate resulting path + u, err := url.Parse(t) + if err != nil { + // Let the user know we got here. Rewrite is expected but + // the resulting url is invalid. + log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err) + return false + } + + // take note of this rewrite for internal use by fastcgi + // all we need is the URI, not full URL + r.Header.Set(headerFieldName, r.URL.RequestURI()) + + // perform rewrite + r.URL.Path = u.Path + if u.RawQuery != "" { + // overwrite query string if present + r.URL.RawQuery = u.RawQuery + } + if u.Fragment != "" { + // overwrite fragment if present + r.URL.Fragment = u.Fragment + } + + return true +} + +// isValidFile checks if file exists on the filesystem. +// if file ends with `/`, it is validated as a directory. +func isValidFile(fs http.FileSystem, file string) bool { + if fs == nil { + return false + } + + f, err := fs.Open(file) + if err != nil { + return false + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return false + } + + // directory + if strings.HasSuffix(file, "/") { + return stat.IsDir() + } + + // file + return !stat.IsDir() +} diff --git a/middleware/rewrite/to_test.go b/middleware/rewrite/to_test.go new file mode 100644 index 000000000..2d8b535ac --- /dev/null +++ b/middleware/rewrite/to_test.go @@ -0,0 +1,44 @@ +package rewrite + +import ( + "net/http" + "net/url" + "testing" +) + +func TestTo(t *testing.T) { + fs := http.Dir("testdata") + tests := []struct { + url string + to string + expected string + }{ + {"/", "/somefiles", "/somefiles"}, + {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, + {"/somefiles", "/testfile /index.php{uri}", "/testfile"}, + {"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"}, + {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, + {"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"}, + {"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"}, + {"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"}, + {"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"}, + } + + uri := func(r *url.URL) string { + uri := r.Path + if r.RawQuery != "" { + uri += "?" + r.RawQuery + } + return uri + } + for i, test := range tests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + To(fs, r, test.to) + if uri(r.URL) != test.expected { + t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL)) + } + } +}