mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-24 01:26:47 +01:00
Upgrade proxy middleware. Add support for: multiple backends, load balancing, health checks, and pluggable backends
This commit is contained in:
parent
782ba32457
commit
4a4b80450a
6 changed files with 763 additions and 56 deletions
91
middleware/proxy/policy.go
Normal file
91
middleware/proxy/policy.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type HostPool []*UpstreamHost
|
||||
|
||||
// Policy decides how a host will be selected from a pool.
|
||||
type Policy interface {
|
||||
Select(pool HostPool) *UpstreamHost
|
||||
}
|
||||
|
||||
// The random policy randomly selected an up host from the pool.
|
||||
type Random struct{}
|
||||
|
||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||
// instead of just generating a random index
|
||||
// this is done to prevent selecting a down host
|
||||
var randHost *UpstreamHost
|
||||
count := 0
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if count == 1 {
|
||||
randHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
randHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
return randHost
|
||||
}
|
||||
|
||||
// The least_conn policy selects a host with the least connections.
|
||||
// If multiple hosts have the least amount of connections, one is randomly
|
||||
// chosen.
|
||||
type LeastConn struct{}
|
||||
|
||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||
var bestHost *UpstreamHost
|
||||
count := 0
|
||||
leastConn := int64(1<<63 - 1)
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
hostConns := host.Conns
|
||||
if hostConns < leastConn {
|
||||
bestHost = host
|
||||
leastConn = hostConns
|
||||
count = 1
|
||||
} else if hostConns == leastConn {
|
||||
// randomly select host among hosts with least connections
|
||||
count++
|
||||
if count == 1 {
|
||||
bestHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
bestHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestHost
|
||||
}
|
||||
|
||||
// The round_robin policy selects a host based on round robin ordering.
|
||||
type RoundRobin struct {
|
||||
Robin uint32
|
||||
}
|
||||
|
||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
|
||||
host := pool[selection]
|
||||
// if the currently selected host is down, just ffwd to up host
|
||||
for i := uint32(1); host.Down() && i < poolLen; i++ {
|
||||
host = pool[(selection+i)%poolLen]
|
||||
}
|
||||
if host.Down() {
|
||||
return nil
|
||||
}
|
||||
return host
|
||||
}
|
57
middleware/proxy/policy_test.go
Normal file
57
middleware/proxy/policy_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testPool() HostPool {
|
||||
pool := []*UpstreamHost{
|
||||
&UpstreamHost{
|
||||
Name: "http://google.com", // this should resolve (healthcheck test)
|
||||
},
|
||||
&UpstreamHost{
|
||||
Name: "http://shouldnot.resolve", // this shouldn't
|
||||
},
|
||||
&UpstreamHost{
|
||||
Name: "http://C",
|
||||
},
|
||||
}
|
||||
return HostPool(pool)
|
||||
}
|
||||
|
||||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := &RoundRobin{}
|
||||
h := rrPolicy.Select(pool)
|
||||
// First selected host is 1, because counter starts at 0
|
||||
// and increments before host is selected
|
||||
if h != pool[1] {
|
||||
t.Error("Expected first round robin host to be second host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected second round robin host to be third host in the pool.")
|
||||
}
|
||||
// mark host as down
|
||||
pool[0].Unhealthy = true
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected third round robin host to be first host in the pool.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := &LeastConn{}
|
||||
pool[0].Conns = 10
|
||||
pool[1].Conns = 10
|
||||
h := lcPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected least connection host to be third host.")
|
||||
}
|
||||
pool[2].Conns = 100
|
||||
h = lcPolicy.Select(pool)
|
||||
if h != pool[0] && h != pool[1] {
|
||||
t.Error("Expected least connection host to be first or second host.")
|
||||
}
|
||||
}
|
|
@ -2,52 +2,169 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"errors"
|
||||
"github.com/mholt/caddy/middleware"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errUnreachable = errors.New("Unreachable backend")
|
||||
|
||||
// Proxy represents a middleware instance that can proxy requests.
|
||||
type Proxy struct {
|
||||
Next middleware.Handler
|
||||
Rules []Rule
|
||||
Upstreams []Upstream
|
||||
}
|
||||
|
||||
// An upstream manages a pool of proxy upstream hosts. Select should return a
|
||||
// suitable upstream host, or nil if no such hosts are available.
|
||||
type Upstream interface {
|
||||
// The path this upstream host should be routed on
|
||||
From() string
|
||||
// Selects an upstream host to be routed to.
|
||||
Select() *UpstreamHost
|
||||
}
|
||||
|
||||
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
||||
|
||||
// An UpstreamHost represents a single proxy upstream
|
||||
type UpstreamHost struct {
|
||||
Name string
|
||||
ReverseProxy *ReverseProxy
|
||||
Conns int64
|
||||
Fails int32
|
||||
FailTimeout time.Duration
|
||||
Unhealthy bool
|
||||
ExtraHeaders http.Header
|
||||
CheckDown UpstreamHostDownFunc
|
||||
}
|
||||
|
||||
func (uh *UpstreamHost) Down() bool {
|
||||
if uh.CheckDown == nil {
|
||||
// Default settings
|
||||
return uh.Unhealthy || uh.Fails > 0
|
||||
}
|
||||
return uh.CheckDown(uh)
|
||||
}
|
||||
|
||||
//https://github.com/mgutz/str
|
||||
var tRe = regexp.MustCompile(`([\-\[\]()*\s])`)
|
||||
var tRe2 = regexp.MustCompile(`\$`)
|
||||
var openDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("{{", "\\$1"), "\\$")
|
||||
var closDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("}}", "\\$1"), "\\$")
|
||||
var templateDelim = regexp.MustCompile(openDelim + `(.+?)` + closDelim)
|
||||
|
||||
type requestVars struct {
|
||||
Host string
|
||||
RemoteIp string
|
||||
Scheme string
|
||||
Upstream string
|
||||
UpstreamHost string
|
||||
}
|
||||
|
||||
func templateWithDelimiters(s string, vars requestVars) string {
|
||||
matches := templateDelim.FindAllStringSubmatch(s, -1)
|
||||
for _, submatches := range matches {
|
||||
match := submatches[0]
|
||||
key := submatches[1]
|
||||
found := true
|
||||
repl := ""
|
||||
switch key {
|
||||
case "http_host":
|
||||
repl = vars.Host
|
||||
case "remote_addr":
|
||||
repl = vars.RemoteIp
|
||||
case "scheme":
|
||||
repl = vars.Scheme
|
||||
case "upstream":
|
||||
repl = vars.Upstream
|
||||
case "upstream_host":
|
||||
repl = vars.UpstreamHost
|
||||
default:
|
||||
found = false
|
||||
}
|
||||
if found {
|
||||
s = strings.Replace(s, match, repl, -1)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ServeHTTP satisfies the middleware.Handler interface.
|
||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
|
||||
for _, rule := range p.Rules {
|
||||
if middleware.Path(r.URL.Path).Matches(rule.From) {
|
||||
var base string
|
||||
|
||||
if strings.HasPrefix(rule.To, "http") { // includes https
|
||||
// destination includes a scheme! no need to guess
|
||||
base = rule.To
|
||||
} else {
|
||||
// no scheme specified; assume same as request
|
||||
var scheme string
|
||||
if r.TLS == nil {
|
||||
scheme = "http"
|
||||
} else {
|
||||
scheme = "https"
|
||||
for _, upstream := range p.Upstreams {
|
||||
if middleware.Path(r.URL.Path).Matches(upstream.From()) {
|
||||
vars := requestVars{
|
||||
Host: r.Host,
|
||||
Scheme: "http",
|
||||
}
|
||||
base = scheme + "://" + rule.To
|
||||
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
vars.RemoteIp = clientIP
|
||||
}
|
||||
if fFor := r.Header.Get("X-Forwarded-For"); fFor != "" {
|
||||
vars.RemoteIp = fFor
|
||||
}
|
||||
if r.TLS != nil {
|
||||
vars.Scheme = "https"
|
||||
}
|
||||
// Since Select() should give us "up" hosts, keep retrying
|
||||
// hosts until timeout (or until we get a nil host).
|
||||
start := time.Now()
|
||||
for time.Now().Sub(start) < (60 * time.Second) {
|
||||
host := upstream.Select()
|
||||
if host == nil {
|
||||
return http.StatusBadGateway, errUnreachable
|
||||
}
|
||||
proxy := host.ReverseProxy
|
||||
vars.Upstream = host.Name
|
||||
r.Host = host.Name
|
||||
|
||||
baseUrl, err := url.Parse(base)
|
||||
if err != nil {
|
||||
if baseUrl, err := url.Parse(host.Name); err == nil {
|
||||
vars.UpstreamHost = baseUrl.Host
|
||||
if proxy == nil {
|
||||
proxy = NewSingleHostReverseProxy(baseUrl)
|
||||
}
|
||||
} else if proxy == nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
r.Host = baseUrl.Host
|
||||
var extraHeaders http.Header
|
||||
if host.ExtraHeaders != nil {
|
||||
extraHeaders = make(http.Header)
|
||||
for header, values := range host.ExtraHeaders {
|
||||
for _, value := range values {
|
||||
extraHeaders.Add(header,
|
||||
templateWithDelimiters(value, vars))
|
||||
if header == "Host" {
|
||||
r.Host = templateWithDelimiters(value, vars)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Construct this before; not during every request, if possible
|
||||
proxy := httputil.NewSingleHostReverseProxy(baseUrl)
|
||||
proxy.ServeHTTP(w, r)
|
||||
atomic.AddInt64(&host.Conns, 1)
|
||||
backendErr := proxy.ServeHTTP(w, r, extraHeaders)
|
||||
atomic.AddInt64(&host.Conns, -1)
|
||||
if backendErr == nil {
|
||||
return 0, nil
|
||||
}
|
||||
timeout := host.FailTimeout
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
atomic.AddInt32(&host.Fails, 1)
|
||||
go func(host *UpstreamHost, timeout time.Duration) {
|
||||
time.Sleep(timeout)
|
||||
atomic.AddInt32(&host.Fails, -1)
|
||||
}(host, timeout)
|
||||
}
|
||||
return http.StatusBadGateway, errUnreachable
|
||||
}
|
||||
}
|
||||
|
||||
return p.Next.ServeHTTP(w, r)
|
||||
|
@ -55,30 +172,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
|
||||
// New creates a new instance of proxy middleware.
|
||||
func New(c middleware.Controller) (middleware.Middleware, error) {
|
||||
rules, err := parse(c)
|
||||
if err != nil {
|
||||
if upstreams, err := newStaticUpstreams(c); err == nil {
|
||||
return func(next middleware.Handler) middleware.Handler {
|
||||
return Proxy{Next: next, Upstreams: upstreams}
|
||||
}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(next middleware.Handler) middleware.Handler {
|
||||
return Proxy{Next: next, Rules: rules}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parse(c middleware.Controller) ([]Rule, error) {
|
||||
var rules []Rule
|
||||
|
||||
for c.Next() {
|
||||
var rule Rule
|
||||
if !c.Args(&rule.From, &rule.To) {
|
||||
return rules, c.ArgErr()
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
From, To string
|
||||
}
|
||||
|
|
215
middleware/proxy/reverseproxy.go
Normal file
215
middleware/proxy/reverseproxy.go
Normal file
|
@ -0,0 +1,215 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP reverse proxy handler
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||
// flushLoop() goroutine.
|
||||
var onExitFlushLoop func()
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
type ReverseProxy struct {
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body.
|
||||
// If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
}
|
||||
return &ReverseProxy{Director: director}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
outreq := new(http.Request)
|
||||
*outreq = *req // includes shallow copies of maps, but okay
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Proto = "HTTP/1.1"
|
||||
outreq.ProtoMajor = 1
|
||||
outreq.ProtoMinor = 1
|
||||
outreq.Close = false
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
// important is "Connection" because we want a persistent
|
||||
// connection, regardless of what the client sent to us. This
|
||||
// is modifying the same underlying map from req (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
copiedHeaders := false
|
||||
for _, h := range hopHeaders {
|
||||
if outreq.Header.Get(h) != "" {
|
||||
if !copiedHeaders {
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, req.Header)
|
||||
copiedHeaders = true
|
||||
}
|
||||
outreq.Header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
if extraHeaders != nil {
|
||||
for k, v := range extraHeaders {
|
||||
outreq.Header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
p.copyResponse(rw, res.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: p.FlushInterval,
|
||||
done: make(chan bool),
|
||||
}
|
||||
go mlw.flushLoop()
|
||||
defer mlw.stop()
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration
|
||||
|
||||
lk sync.Mutex // protects Write + Flush
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
||||
m.lk.Lock()
|
||||
defer m.lk.Unlock()
|
||||
return m.dst.Write(p)
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) flushLoop() {
|
||||
t := time.NewTicker(m.latency)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
if onExitFlushLoop != nil {
|
||||
onExitFlushLoop()
|
||||
}
|
||||
return
|
||||
case <-t.C:
|
||||
m.lk.Lock()
|
||||
m.dst.Flush()
|
||||
m.lk.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() { m.done <- true }
|
203
middleware/proxy/upstream.go
Normal file
203
middleware/proxy/upstream.go
Normal file
|
@ -0,0 +1,203 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy/middleware"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type staticUpstream struct {
|
||||
from string
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
|
||||
FailTimeout time.Duration
|
||||
MaxFails int32
|
||||
HealthCheck struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
}
|
||||
}
|
||||
|
||||
func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) {
|
||||
var upstreams []Upstream
|
||||
|
||||
for c.Next() {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: nil,
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
var proxyHeaders http.Header
|
||||
if !c.Args(&upstream.from) {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
to := c.RemainingArgs()
|
||||
if len(to) == 0 {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "policy":
|
||||
if !c.NextArg() {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
switch c.Val() {
|
||||
case "random":
|
||||
upstream.Policy = &Random{}
|
||||
case "round_robin":
|
||||
upstream.Policy = &RoundRobin{}
|
||||
case "least_conn":
|
||||
upstream.Policy = &LeastConn{}
|
||||
default:
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
case "fail_timeout":
|
||||
if !c.NextArg() {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
if dur, err := time.ParseDuration(c.Val()); err == nil {
|
||||
upstream.FailTimeout = dur
|
||||
} else {
|
||||
return upstreams, err
|
||||
}
|
||||
case "max_fails":
|
||||
if !c.NextArg() {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
if n, err := strconv.Atoi(c.Val()); err == nil {
|
||||
upstream.MaxFails = int32(n)
|
||||
} else {
|
||||
return upstreams, err
|
||||
}
|
||||
case "health_check":
|
||||
if !c.NextArg() {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
upstream.HealthCheck.Path = c.Val()
|
||||
upstream.HealthCheck.Interval = 30 * time.Second
|
||||
if c.NextArg() {
|
||||
if dur, err := time.ParseDuration(c.Val()); err == nil {
|
||||
upstream.HealthCheck.Interval = dur
|
||||
} else {
|
||||
return upstreams, err
|
||||
}
|
||||
}
|
||||
case "proxy_header":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
if proxyHeaders == nil {
|
||||
proxyHeaders = make(map[string][]string)
|
||||
}
|
||||
proxyHeaders.Add(header, value)
|
||||
}
|
||||
}
|
||||
|
||||
upstream.Hosts = make([]*UpstreamHost, len(to))
|
||||
for i, host := range to {
|
||||
if !strings.HasPrefix(host, "http") {
|
||||
host = "http://" + host
|
||||
}
|
||||
uh := &UpstreamHost{
|
||||
Name: host,
|
||||
Conns: 0,
|
||||
Fails: 0,
|
||||
FailTimeout: upstream.FailTimeout,
|
||||
Unhealthy: false,
|
||||
ExtraHeaders: proxyHeaders,
|
||||
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
|
||||
return func(uh *UpstreamHost) bool {
|
||||
if uh.Unhealthy {
|
||||
return true
|
||||
}
|
||||
if uh.Fails >= upstream.MaxFails &&
|
||||
upstream.MaxFails != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}(upstream),
|
||||
}
|
||||
if baseUrl, err := url.Parse(uh.Name); err == nil {
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl)
|
||||
} else {
|
||||
return upstreams, err
|
||||
}
|
||||
upstream.Hosts[i] = uh
|
||||
}
|
||||
|
||||
if upstream.HealthCheck.Path != "" {
|
||||
go upstream.healthCheckWorker(nil)
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
func (u *staticUpstream) healthCheck() {
|
||||
for _, host := range u.Hosts {
|
||||
hostUrl := host.Name + u.HealthCheck.Path
|
||||
if r, err := http.Get(hostUrl); err == nil {
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
r.Body.Close()
|
||||
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
|
||||
} else {
|
||||
host.Unhealthy = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) healthCheckWorker(stop chan struct{}) {
|
||||
ticker := time.NewTicker(u.HealthCheck.Interval)
|
||||
u.healthCheck()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.healthCheck()
|
||||
case <-stop:
|
||||
// TODO: the library should provide a stop channel and global
|
||||
// waitgroup to allow goroutines started by plugins a chance
|
||||
// to clean themselves up.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) From() string {
|
||||
return u.from
|
||||
}
|
||||
|
||||
func (u *staticUpstream) Select() *UpstreamHost {
|
||||
pool := u.Hosts
|
||||
if len(pool) == 1 {
|
||||
if pool[0].Down() {
|
||||
return nil
|
||||
}
|
||||
return pool[0]
|
||||
}
|
||||
allDown := true
|
||||
for _, host := range pool {
|
||||
if !host.Down() {
|
||||
allDown = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allDown {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.Policy == nil {
|
||||
return (&Random{}).Select(pool)
|
||||
} else {
|
||||
return u.Policy.Select(pool)
|
||||
}
|
||||
}
|
43
middleware/proxy/upstream_test.go
Normal file
43
middleware/proxy/upstream_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool(),
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.healthCheck()
|
||||
if upstream.Hosts[0].Down() {
|
||||
t.Error("Expected first host in testpool to not fail healthcheck.")
|
||||
}
|
||||
if !upstream.Hosts[1].Down() {
|
||||
t.Error("Expected second host in testpool to fail healthcheck.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool()[:3],
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.Hosts[0].Unhealthy = true
|
||||
upstream.Hosts[1].Unhealthy = true
|
||||
upstream.Hosts[2].Unhealthy = true
|
||||
if h := upstream.Select(); h != nil {
|
||||
t.Error("Expected select to return nil as all host are down")
|
||||
}
|
||||
upstream.Hosts[2].Unhealthy = false
|
||||
if h := upstream.Select(); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue