From 4a4b80450a7a65b141057440d3d4427200487e34 Mon Sep 17 00:00:00 2001 From: Nimi Wariboko Jr Date: Sat, 2 May 2015 22:20:36 -0700 Subject: [PATCH] Upgrade proxy middleware. Add support for: multiple backends, load balancing, health checks, and pluggable backends --- middleware/proxy/policy.go | 91 +++++++++++++ middleware/proxy/policy_test.go | 57 ++++++++ middleware/proxy/proxy.go | 210 +++++++++++++++++++++-------- middleware/proxy/reverseproxy.go | 215 ++++++++++++++++++++++++++++++ middleware/proxy/upstream.go | 203 ++++++++++++++++++++++++++++ middleware/proxy/upstream_test.go | 43 ++++++ 6 files changed, 763 insertions(+), 56 deletions(-) create mode 100644 middleware/proxy/policy.go create mode 100644 middleware/proxy/policy_test.go create mode 100644 middleware/proxy/reverseproxy.go create mode 100644 middleware/proxy/upstream.go create mode 100644 middleware/proxy/upstream_test.go diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go new file mode 100644 index 000000000..806055e28 --- /dev/null +++ b/middleware/proxy/policy.go @@ -0,0 +1,91 @@ +package proxy + +import ( + "math/rand" + "sync/atomic" +) + +type HostPool []*UpstreamHost + +// Policy decides how a host will be selected from a pool. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +// The random policy randomly selects an up host from the pool. +type Random struct{} + +func (r *Random) Select(pool HostPool) *UpstreamHost { + // Rather than generating a random index, walk the pool and pick at + // random among the up hosts, so that a down host is never selected. + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if host.Down() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// The least_conn policy selects a host with the least connections. +// If multiple hosts have the least amount of connections, one is randomly +// chosen. 
+type LeastConn struct{} + +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if host.Down() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// The round_robin policy selects a host based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is down, just ffwd to up host + for i := uint32(1); host.Down() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + if host.Down() { + return nil + } + return host +} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go new file mode 100644 index 000000000..11269a4f2 --- /dev/null +++ b/middleware/proxy/policy_test.go @@ -0,0 +1,57 @@ +package proxy + +import ( + "testing" +) + +func testPool() HostPool { + pool := []*UpstreamHost{ + &UpstreamHost{ + Name: "http://google.com", // this should resolve (healthcheck test) + }, + &UpstreamHost{ + Name: "http://shouldnot.resolve", // this shouldn't + }, + &UpstreamHost{ + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h 
!= pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + // mark host as down + pool[0].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[1] { + t.Error("Expected third round robin host to be first host in the pool.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 5c1c56fd3..cc3b068da 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -2,51 +2,168 @@ package proxy import ( - "net/http" - "net/http/httputil" - "net/url" - "strings" - + "errors" "github.com/mholt/caddy/middleware" + "net" + "net/http" + "net/url" + "regexp" + "strings" + "sync/atomic" + "time" ) +var errUnreachable = errors.New("Unreachable backend") + // Proxy represents a middleware instance that can proxy requests. type Proxy struct { - Next middleware.Handler - Rules []Rule + Next middleware.Handler + Upstreams []Upstream +} + +// An upstream manages a pool of proxy upstream hosts. Select should return a +// suitable upstream host, or nil if no such hosts are available. +type Upstream interface { + // The path this upstream host should be routed on + From() string + // Selects an upstream host to be routed to. 
+ Select() *UpstreamHost +} + +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// An UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Name string + ReverseProxy *ReverseProxy + Conns int64 + Fails int32 + FailTimeout time.Duration + Unhealthy bool + ExtraHeaders http.Header + CheckDown UpstreamHostDownFunc +} + +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + return uh.Unhealthy || uh.Fails > 0 + } + return uh.CheckDown(uh) +} + +//https://github.com/mgutz/str +var tRe = regexp.MustCompile(`([\-\[\]()*\s])`) +var tRe2 = regexp.MustCompile(`\$`) +var openDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("{{", "\\$1"), "\\$") +var closDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("}}", "\\$1"), "\\$") +var templateDelim = regexp.MustCompile(openDelim + `(.+?)` + closDelim) + +type requestVars struct { + Host string + RemoteIp string + Scheme string + Upstream string + UpstreamHost string +} + +func templateWithDelimiters(s string, vars requestVars) string { + matches := templateDelim.FindAllStringSubmatch(s, -1) + for _, submatches := range matches { + match := submatches[0] + key := submatches[1] + found := true + repl := "" + switch key { + case "http_host": + repl = vars.Host + case "remote_addr": + repl = vars.RemoteIp + case "scheme": + repl = vars.Scheme + case "upstream": + repl = vars.Upstream + case "upstream_host": + repl = vars.UpstreamHost + default: + found = false + } + if found { + s = strings.Replace(s, match, repl, -1) + } + } + return s } // ServeHTTP satisfies the middleware.Handler interface. func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, rule := range p.Rules { - if middleware.Path(r.URL.Path).Matches(rule.From) { - var base string - - if strings.HasPrefix(rule.To, "http") { // includes https - // destination includes a scheme! 
no need to guess - base = rule.To - } else { - // no scheme specified; assume same as request - var scheme string - if r.TLS == nil { - scheme = "http" - } else { - scheme = "https" + for _, upstream := range p.Upstreams { + if middleware.Path(r.URL.Path).Matches(upstream.From()) { + vars := requestVars{ + Host: r.Host, + Scheme: "http", + } + if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + vars.RemoteIp = clientIP + } + if fFor := r.Header.Get("X-Forwarded-For"); fFor != "" { + vars.RemoteIp = fFor + } + if r.TLS != nil { + vars.Scheme = "https" + } + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + start := time.Now() + for time.Now().Sub(start) < (60 * time.Second) { + host := upstream.Select() + if host == nil { + return http.StatusBadGateway, errUnreachable } - base = scheme + "://" + rule.To - } + proxy := host.ReverseProxy + vars.Upstream = host.Name + r.Host = host.Name - baseUrl, err := url.Parse(base) - if err != nil { - return http.StatusInternalServerError, err - } - r.Host = baseUrl.Host + if baseUrl, err := url.Parse(host.Name); err == nil { + vars.UpstreamHost = baseUrl.Host + if proxy == nil { + proxy = NewSingleHostReverseProxy(baseUrl) + } + } else if proxy == nil { + return http.StatusInternalServerError, err + } + var extraHeaders http.Header + if host.ExtraHeaders != nil { + extraHeaders = make(http.Header) + for header, values := range host.ExtraHeaders { + for _, value := range values { + extraHeaders.Add(header, + templateWithDelimiters(value, vars)) + if header == "Host" { + r.Host = templateWithDelimiters(value, vars) + } + } + } + } - // TODO: Construct this before; not during every request, if possible - proxy := httputil.NewSingleHostReverseProxy(baseUrl) - proxy.ServeHTTP(w, r) - return 0, nil + atomic.AddInt64(&host.Conns, 1) + backendErr := proxy.ServeHTTP(w, r, extraHeaders) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 
0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + return http.StatusBadGateway, errUnreachable } } @@ -55,30 +172,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // New creates a new instance of proxy middleware. func New(c middleware.Controller) (middleware.Middleware, error) { - rules, err := parse(c) - if err != nil { + if upstreams, err := newStaticUpstreams(c); err == nil { + return func(next middleware.Handler) middleware.Handler { + return Proxy{Next: next, Upstreams: upstreams} + }, nil + } else { return nil, err } - - return func(next middleware.Handler) middleware.Handler { - return Proxy{Next: next, Rules: rules} - }, nil -} - -func parse(c middleware.Controller) ([]Rule, error) { - var rules []Rule - - for c.Next() { - var rule Rule - if !c.Args(&rule.From, &rule.To) { - return rules, c.ArgErr() - } - rules = append(rules, rule) - } - - return rules, nil -} - -type Rule struct { - From, To string } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go new file mode 100644 index 000000000..027f2266c --- /dev/null +++ b/middleware/proxy/reverseproxy.go @@ -0,0 +1,215 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package proxy + +import ( + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. 
+type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. 
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + outreq := new(http.Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. 
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + if extraHeaders != nil { + for k, v := range extraHeaders { + outreq.Header[k] = v + } + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + return err + } + defer res.Body.Close() + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) + return nil +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + + io.Copy(dst, src) +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + lk sync.Mutex // protects Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.lk.Lock() + defer m.lk.Unlock() + return m.dst.Write(p) +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return + case <-t.C: + m.lk.Lock() + m.dst.Flush() + m.lk.Unlock() + } + } +} + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go new file mode 100644 index 000000000..a01002090 --- /dev/null +++ b/middleware/proxy/upstream.go @@ -0,0 +1,203 @@ +package proxy + +import ( + "github.com/mholt/caddy/middleware" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +type staticUpstream struct { + from string + Hosts HostPool + Policy Policy + + FailTimeout time.Duration + MaxFails int32 + HealthCheck 
struct { + Path string + Interval time.Duration + } +} + +func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) { + var upstreams []Upstream + + for c.Next() { + upstream := &staticUpstream{ + from: "", + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + var proxyHeaders http.Header + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + switch c.Val() { + case "policy": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + switch c.Val() { + case "random": + upstream.Policy = &Random{} + case "round_robin": + upstream.Policy = &RoundRobin{} + case "least_conn": + upstream.Policy = &LeastConn{} + default: + return upstreams, c.ArgErr() + } + case "fail_timeout": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + if dur, err := time.ParseDuration(c.Val()); err == nil { + upstream.FailTimeout = dur + } else { + return upstreams, err + } + case "max_fails": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + if n, err := strconv.Atoi(c.Val()); err == nil { + upstream.MaxFails = int32(n) + } else { + return upstreams, err + } + case "health_check": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + upstream.HealthCheck.Path = c.Val() + upstream.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + if dur, err := time.ParseDuration(c.Val()); err == nil { + upstream.HealthCheck.Interval = dur + } else { + return upstreams, err + } + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return upstreams, c.ArgErr() + } + if proxyHeaders == nil { + proxyHeaders = make(map[string][]string) + } + proxyHeaders.Add(header, value) + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + if !strings.HasPrefix(host, "http") { + host = "http://" + host + } + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + 
FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + } + if baseUrl, err := url.Parse(uh.Name); err == nil { + uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl) + } else { + return upstreams, err + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.healthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostUrl := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostUrl); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) healthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. 
+ } + } +} + +func (u *staticUpstream) From() string { + return u.from +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } else { + return u.Policy.Select(pool) + } +} diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go new file mode 100644 index 000000000..6be3f6cea --- /dev/null +++ b/middleware/proxy/upstream_test.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestHealthCheck(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool(), + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.healthCheck() + if upstream.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !upstream.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool()[:3], + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.Hosts[0].Unhealthy = true + upstream.Hosts[1].Unhealthy = true + upstream.Hosts[2].Unhealthy = true + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + upstream.Hosts[2].Unhealthy = false + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } +}