diff --git a/caddyhttp/fastcgi/fastcgi.go b/caddyhttp/fastcgi/fastcgi.go index dad13d17d..ee466a3e8 100644 --- a/caddyhttp/fastcgi/fastcgi.go +++ b/caddyhttp/fastcgi/fastcgi.go @@ -20,6 +20,7 @@ package fastcgi import ( "context" "errors" + "fmt" "io" "net" "net/http" @@ -107,7 +108,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) } // Connect to FastCGI gateway - network, address := parseAddress(rule.Address()) + address, err := rule.Address() + if err != nil { + return http.StatusBadGateway, err + } + network, address := parseAddress(address) ctx := context.Background() if rule.ConnectTimeout > 0 { @@ -381,7 +386,7 @@ type Rule struct { type balancer interface { // Address picks an upstream address from the // underlying load balancer. - Address() string + Address() (string, error) } // roundRobin is a round robin balancer for fastcgi upstreams. @@ -393,9 +398,34 @@ type roundRobin struct { addresses []string } -func (r *roundRobin) Address() string { +func (r *roundRobin) Address() (string, error) { index := atomic.AddInt64(&r.index, 1) % int64(len(r.addresses)) - return r.addresses[index] + return r.addresses[index], nil +} + +// srvResolver is a private interface used to abstract +// the DNS resolver. It is mainly used to facilitate testing. +type srvResolver interface { + LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) +} + +// srv is a service locator for fastcgi upstreams +type srv struct { + resolver srvResolver + service string +} + +// Address looks up the service and returns the address:port +// from first result in resolved list. +// No explicit balancing is required because net.LookupSRV +// sorts the results by priority and randomizes within priority. +func (s *srv) Address() (string, error) { + _, addrs, err := s.resolver.LookupSRV(context.Background(), "", "", s.service) + if err != nil { + return "", err + } + + return fmt.Sprintf("%s:%d", strings.TrimRight(addrs[0].Target, "."), addrs[0].Port), nil } // canSplit checks if path can split into two based on rule.SplitPath. diff --git a/caddyhttp/fastcgi/fastcgi_test.go b/caddyhttp/fastcgi/fastcgi_test.go index e7c7a7ba0..e3d3f86f3 100644 --- a/caddyhttp/fastcgi/fastcgi_test.go +++ b/caddyhttp/fastcgi/fastcgi_test.go @@ -84,11 +84,15 @@ func TestRuleParseAddress(t *testing.T) { } for _, entry := range getClientTestTable { - if actualnetwork, _ := parseAddress(entry.rule.Address()); actualnetwork != entry.expectednetwork { - t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address(), actualnetwork, entry.expectednetwork) + addr, err := entry.rule.Address() + if err != nil { + t.Errorf("Unexpected error in retrieving address: %s", err.Error()) } - if _, actualaddress := parseAddress(entry.rule.Address()); actualaddress != entry.expectedaddress { - t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address(), actualaddress, entry.expectedaddress) + if actualnetwork, _ := parseAddress(addr); actualnetwork != entry.expectednetwork { + t.Errorf("Unexpected network for address string %v. Got %v, expected %v", addr, actualnetwork, entry.expectednetwork) + } + if _, actualaddress := parseAddress(addr); actualaddress != entry.expectedaddress { + t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", addr, actualaddress, entry.expectedaddress) } } } @@ -365,7 +369,10 @@ func TestBalancer(t *testing.T) { for i, test := range tests { b := address(test...) for _, host := range test { - a := b.Address() + a, err := b.Address() + if err != nil { + t.Errorf("Unexpected error in trying to retrieve address: %s", err.Error()) + } if a != host { t.Errorf("Test %d: expected %s, found %s", i, host, a) } diff --git a/caddyhttp/fastcgi/setup.go b/caddyhttp/fastcgi/setup.go index 24f0337a3..22b310ff3 100644 --- a/caddyhttp/fastcgi/setup.go +++ b/caddyhttp/fastcgi/setup.go @@ -16,8 +16,11 @@ package fastcgi import ( "errors" + "fmt" + "net" "net/http" "path/filepath" + "strings" "time" "github.com/mholt/caddy" @@ -76,8 +79,14 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { Root: absRoot, Path: args[0], } + upstreams := []string{args[1]} + srvUpstream := false + if strings.HasPrefix(upstreams[0], "srv://") { + srvUpstream = true + } + if len(args) == 3 { if err := fastcgiPreset(args[2], &rule); err != nil { return rules, err @@ -112,6 +121,10 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { rule.IndexFiles = args case "upstream": + if srvUpstream { + return rules, c.Err("additional upstreams are not supported with SRV upstream") + } + args := c.RemainingArgs() if len(args) != 1 { @@ -161,13 +174,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { } } - rule.balancer = &roundRobin{addresses: upstreams, index: -1} + if srvUpstream { + balancer, err := parseSRV(upstreams[0]) + if err != nil { + return rules, c.Err("malformed service locator string: " + err.Error()) + } + rule.balancer = balancer + } else { + rule.balancer = &roundRobin{addresses: upstreams, index: -1} + } rules = append(rules, rule) } return rules, nil } +func parseSRV(locator string) (*srv, error) { + if locator[6:] == "" { + return nil, fmt.Errorf("%s does not include the host", locator) + } + + return &srv{ + service: locator[6:], + resolver: &net.Resolver{}, + }, nil +} + // fastcgiPreset configures rule according to name. It returns an error if // name is not a recognized preset name. func fastcgiPreset(name string, rule *Rule) error { diff --git a/caddyhttp/fastcgi/setup_test.go b/caddyhttp/fastcgi/setup_test.go index f21cfa848..37f27bceb 100644 --- a/caddyhttp/fastcgi/setup_test.go +++ b/caddyhttp/fastcgi/setup_test.go @@ -15,7 +15,9 @@ package fastcgi import ( + "context" "fmt" + "net" "testing" "github.com/mholt/caddy" @@ -43,10 +45,14 @@ func TestSetup(t *testing.T) { if myHandler.Rules[0].Path != "/" { t.Errorf("Expected / as the Path") } - if myHandler.Rules[0].Address() != "127.0.0.1:9000" { - t.Errorf("Expected 127.0.0.1:9000 as the Address") + addr, err := myHandler.Rules[0].Address() + if err != nil { + t.Errorf("Unexpected error in trying to retrieve address: %s", err.Error()) } + if addr != "127.0.0.1:9000" { + t.Errorf("Expected 127.0.0.1:9000 as the Address") + } } func TestFastcgiParse(t *testing.T) { @@ -106,9 +112,19 @@ func TestFastcgiParse(t *testing.T) { i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path) } - if actualFastcgiConfig.Address() != test.expectedFastcgiConfig[j].Address() { + actualAddr, err := actualFastcgiConfig.Address() + if err != nil { + t.Errorf("Test %d unexpected error in trying to retrieve %dth actual address: %s", i, j, err.Error()) + } + + expectedAddr, err := test.expectedFastcgiConfig[j].Address() + if err != nil { + t.Errorf("Test %d unexpected error in trying to retrieve %dth expected address: %s", i, j, err.Error()) + } + + if actualAddr != expectedAddr { t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].Address(), actualFastcgiConfig.Address()) + i, j, expectedAddr, actualAddr) } if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext { @@ -134,3 +150,75 @@ func TestFastcgiParse(t *testing.T) { } } + +func TestFastCGIResolveSRV(t *testing.T) { + tests := []struct { + inputFastcgiConfig string + locator string + target string + port uint16 + shouldErr bool + }{ + { + `fastcgi / srv://fpm.tcp.service.consul { + upstream yolo + }`, + "fpm.tcp.service.consul", + "127.0.0.1", + 9000, + true, + }, + { + `fastcgi / srv://fpm.tcp.service.consul`, + "fpm.tcp.service.consul", + "127.0.0.1", + 9000, + false, + }, + } + + for i, test := range tests { + actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + + for _, actualFastcgiConfig := range actualFastcgiConfigs { + resolver, ok := (actualFastcgiConfig.balancer).(*srv) + if !ok { + t.Errorf("Test %d upstream balancer is not srv", i) + } + resolver.resolver = buildTestResolver(test.target, test.port) + + addr, err := actualFastcgiConfig.Address() + if err != nil { + t.Errorf("Test %d failed to retrieve upstream address. %s", i, err.Error()) + } + + expectedAddr := fmt.Sprintf("%s:%d", test.target, test.port) + if addr != expectedAddr { + t.Errorf("Test %d expected upstream address to be %s, got %s", i, expectedAddr, addr) + } + } + } +} + +func buildTestResolver(target string, port uint16) srvResolver { + return &testSRVResolver{target, port} +} + +type testSRVResolver struct { + target string + port uint16 +} + +func (r *testSRVResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + return "", []*net.SRV{ + {Target: r.target, + Port: r.port, + Priority: 1, + Weight: 1}}, nil +}