diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go index 27e02a177..f90a1e239 100644 --- a/middleware/proxy/policy.go +++ b/middleware/proxy/policy.go @@ -11,6 +11,7 @@ type HostPool []*UpstreamHost // Policy decides how a host will be selected from a pool. type Policy interface { Select(pool HostPool) *UpstreamHost + Name() string } // Random is a policy that selects up hosts from a pool at random. @@ -39,6 +40,11 @@ func (r *Random) Select(pool HostPool) *UpstreamHost { return randHost } +// Name returns the name of the policy. +func (r *Random) Name() string { + return "random" +} + // LeastConn is a policy that selects the host with the least connections. type LeastConn struct{} @@ -74,6 +80,11 @@ func (r *LeastConn) Select(pool HostPool) *UpstreamHost { return bestHost } +// Name returns the name of the policy. +func (r *LeastConn) Name() string { + return "least_conn" +} + // RoundRobin is a policy that selects hosts based on round robin ordering. type RoundRobin struct { Robin uint32 @@ -93,3 +104,8 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { } return host } + +// Name returns the name of the policy. +func (r *RoundRobin) Name() string { + return "round_robin" +} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go index 11269a4f2..4272e332a 100644 --- a/middleware/proxy/policy_test.go +++ b/middleware/proxy/policy_test.go @@ -4,6 +4,16 @@ import ( "testing" ) +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func (r *customPolicy) Name() string { + return "custom" +} + func testPool() HostPool { pool := []*UpstreamHost{ &UpstreamHost{ @@ -55,3 +65,12 @@ func TestLeastConnPolicy(t *testing.T) { t.Error("Expected least connection host to be first or second host.") } } + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index c724da78e..58325f1d4 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,6 +12,8 @@ import ( "github.com/mholt/caddy/config/parse" ) +var supportedPolicies map[string]Policy = make(map[string]Policy) + type staticUpstream struct { from string Hosts HostPool @@ -30,6 +32,10 @@ type staticUpstream struct { func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { var upstreams []Upstream + RegisterPolicy(&Random{}) + RegisterPolicy(&LeastConn{}) + RegisterPolicy(&RoundRobin{}) + for c.Next() { upstream := &staticUpstream{ from: "", @@ -53,16 +59,12 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.NextArg() { return upstreams, c.ArgErr() } - switch c.Val() { - case "random": - upstream.Policy = &Random{} - case "round_robin": - upstream.Policy = &RoundRobin{} - case "least_conn": - upstream.Policy = &LeastConn{} - default: + + policy, ok := supportedPolicies[c.Val()] + if !ok { return upstreams, c.ArgErr() } + upstream.Policy = policy case "fail_timeout": if !c.NextArg() { return upstreams, c.ArgErr() @@ -147,6 +149,11 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { return upstreams, nil } +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(policy Policy) { + supportedPolicies[policy.Name()] = policy +} + func (u *staticUpstream) From() string { return u.from } diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index 6be3f6cea..1d1cc317d 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -41,3 +41,12 @@ func TestSelect(t *testing.T) { t.Error("Expected select to not return nil") } } + +func TestRegisterPolicy(t *testing.T) { + customPolicy := &customPolicy{} + RegisterPolicy(customPolicy) + if _, ok := supportedPolicies[customPolicy.Name()]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +}