2023-05-09 18:19:48 +01:00
package runtime
import (
"context"
"errors"
"fmt"
"net/http"
"net/textproto"
2023-09-07 12:20:37 +01:00
"regexp"
2023-05-09 18:19:48 +01:00
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
"google.golang.org/grpc/codes"
2023-09-07 12:20:37 +01:00
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/health/grpc_health_v1"
2023-05-09 18:19:48 +01:00
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int
const (
// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
// path string before doing any routing.
UnescapingModeLegacy UnescapingMode = iota
2023-09-07 12:20:37 +01:00
// UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
2023-05-09 18:19:48 +01:00
// reserved characters.
UnescapingModeAllExceptReserved
2023-09-07 12:20:37 +01:00
// UnescapingModeAllExceptSlash unescapes URL path parameters except path
// separators, which will be left as "%2F".
2023-05-09 18:19:48 +01:00
UnescapingModeAllExceptSlash
2023-09-07 12:20:37 +01:00
// UnescapingModeAllCharacters unescapes all URL path parameters.
2023-05-09 18:19:48 +01:00
UnescapingModeAllCharacters
// UnescapingModeDefault is the default escaping type.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
UnescapingModeDefault = UnescapingModeLegacy
)
2023-09-07 12:20:37 +01:00
var encodedPathSplitter = regexp . MustCompile ( "(/|%2F)" )
2023-05-09 18:19:48 +01:00
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func ( w http . ResponseWriter , r * http . Request , pathParams map [ string ] string )
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map [ string ] [ ] handler
forwardResponseOptions [ ] func ( context . Context , http . ResponseWriter , proto . Message ) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators [ ] func ( context . Context , * http . Request ) metadata . MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
}
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func ( * ServeMux )
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
func WithForwardResponseOption ( forwardResponseOption func ( context . Context , http . ResponseWriter , proto . Message ) error ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . forwardResponseOptions = append ( serveMux . forwardResponseOptions , forwardResponseOption )
}
}
2023-09-07 12:20:37 +01:00
// WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode
2023-05-09 18:19:48 +01:00
// for more information.
func WithUnescapingMode ( mode UnescapingMode ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . unescapingMode = mode
}
}
// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
func SetQueryParameterParser ( queryParameterParser QueryParameterParser ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
currentQueryParser = queryParameterParser
}
}
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
type HeaderMatcherFunc func ( string ) ( string , bool )
// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
2023-09-07 12:20:37 +01:00
// keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function.
// HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'.
// Other headers are not added to the gRPC metadata.
2023-05-09 18:19:48 +01:00
func DefaultHeaderMatcher ( key string ) ( string , bool ) {
2023-09-07 12:20:37 +01:00
switch key = textproto . CanonicalMIMEHeaderKey ( key ) ; {
case isPermanentHTTPHeader ( key ) :
2023-05-09 18:19:48 +01:00
return MetadataPrefix + key , true
2023-09-07 12:20:37 +01:00
case strings . HasPrefix ( key , MetadataHeaderPrefix ) :
2023-05-09 18:19:48 +01:00
return key [ len ( MetadataHeaderPrefix ) : ] , true
}
return "" , false
}
// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
func WithIncomingHeaderMatcher ( fn HeaderMatcherFunc ) ServeMuxOption {
2023-09-07 12:20:37 +01:00
for _ , header := range fn . matchedMalformedHeaders ( ) {
grpclog . Warningf ( "The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information." , header )
}
2023-05-09 18:19:48 +01:00
return func ( mux * ServeMux ) {
mux . incomingHeaderMatcher = fn
}
}
2023-09-07 12:20:37 +01:00
// matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server.
func ( fn HeaderMatcherFunc ) matchedMalformedHeaders ( ) [ ] string {
if fn == nil {
return nil
}
headers := make ( [ ] string , 0 )
for header := range malformedHTTPHeaders {
out , accept := fn ( header )
if accept && isMalformedHTTPHeader ( out ) {
headers = append ( headers , out )
}
}
return headers
}
2023-05-09 18:19:48 +01:00
// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return modified header.
func WithOutgoingHeaderMatcher ( fn HeaderMatcherFunc ) ServeMuxOption {
return func ( mux * ServeMux ) {
mux . outgoingHeaderMatcher = fn
}
}
// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
// is reading token from cookie and adding it in gRPC context.
func WithMetadata ( annotator func ( context . Context , * http . Request ) metadata . MD ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . metadataAnnotators = append ( serveMux . metadataAnnotators , annotator )
}
}
// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
//
// This can be used to configure a custom error response.
func WithErrorHandler ( fn ErrorHandlerFunc ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . errorHandler = fn
}
}
// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler ( fn StreamErrorHandlerFunc ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . streamErrorHandler = fn
}
}
// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
//
// Method called for errors which can happen before gRPC route selected or executed.
// The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
func WithRoutingErrorHandler ( fn RoutingErrorHandlerFunc ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . routingErrorHandler = fn
}
}
// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback ( ) ServeMuxOption {
return func ( serveMux * ServeMux ) {
serveMux . disablePathLengthFallback = true
}
}
2023-09-07 12:20:37 +01:00
// WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
// When called the handler will forward the request to the upstream grpc service health check (defined in the
// gRPC Health Checking Protocol).
//
// See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how
// to setup the protocol in the grpc server.
//
// If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest.
func WithHealthEndpointAt ( healthCheckClient grpc_health_v1 . HealthClient , endpointPath string ) ServeMuxOption {
return func ( s * ServeMux ) {
// error can be ignored since pattern is definitely valid
_ = s . HandlePath (
http . MethodGet , endpointPath , func ( w http . ResponseWriter , r * http . Request , _ map [ string ] string ,
) {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
resp , err := healthCheckClient . Check ( r . Context ( ) , & grpc_health_v1 . HealthCheckRequest {
Service : r . URL . Query ( ) . Get ( "service" ) ,
} )
if err != nil {
s . errorHandler ( r . Context ( ) , s , outboundMarshaler , w , r , err )
return
}
w . Header ( ) . Set ( "Content-Type" , "application/json" )
if resp . GetStatus ( ) != grpc_health_v1 . HealthCheckResponse_SERVING {
switch resp . GetStatus ( ) {
case grpc_health_v1 . HealthCheckResponse_NOT_SERVING , grpc_health_v1 . HealthCheckResponse_UNKNOWN :
err = status . Error ( codes . Unavailable , resp . String ( ) )
case grpc_health_v1 . HealthCheckResponse_SERVICE_UNKNOWN :
err = status . Error ( codes . NotFound , resp . String ( ) )
}
s . errorHandler ( r . Context ( ) , s , outboundMarshaler , w , r , err )
return
}
_ = outboundMarshaler . NewEncoder ( w ) . Encode ( resp )
} )
}
}
// WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
//
// See WithHealthEndpointAt for the general implementation.
func WithHealthzEndpoint ( healthCheckClient grpc_health_v1 . HealthClient ) ServeMuxOption {
return WithHealthEndpointAt ( healthCheckClient , "/healthz" )
}
2023-05-09 18:19:48 +01:00
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux ( opts ... ServeMuxOption ) * ServeMux {
serveMux := & ServeMux {
handlers : make ( map [ string ] [ ] handler ) ,
forwardResponseOptions : make ( [ ] func ( context . Context , http . ResponseWriter , proto . Message ) error , 0 ) ,
marshalers : makeMarshalerMIMERegistry ( ) ,
errorHandler : DefaultHTTPErrorHandler ,
streamErrorHandler : DefaultStreamErrorHandler ,
routingErrorHandler : DefaultRoutingErrorHandler ,
unescapingMode : UnescapingModeDefault ,
}
for _ , opt := range opts {
opt ( serveMux )
}
if serveMux . incomingHeaderMatcher == nil {
serveMux . incomingHeaderMatcher = DefaultHeaderMatcher
}
if serveMux . outgoingHeaderMatcher == nil {
serveMux . outgoingHeaderMatcher = func ( key string ) ( string , bool ) {
return fmt . Sprintf ( "%s%s" , MetadataHeaderPrefix , key ) , true
}
}
return serveMux
}
// Handle associates "h" to the pair of HTTP method and path pattern.
func ( s * ServeMux ) Handle ( meth string , pat Pattern , h HandlerFunc ) {
s . handlers [ meth ] = append ( [ ] handler { { pat : pat , h : h } } , s . handlers [ meth ] ... )
}
// HandlePath allows users to configure custom path handlers.
// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
func ( s * ServeMux ) HandlePath ( meth string , pathPattern string , h HandlerFunc ) error {
compiler , err := httprule . Parse ( pathPattern )
if err != nil {
return fmt . Errorf ( "parsing path pattern: %w" , err )
}
tp := compiler . Compile ( )
pattern , err := NewPattern ( tp . Version , tp . OpCodes , tp . Pool , tp . Verb )
if err != nil {
return fmt . Errorf ( "creating new pattern: %w" , err )
}
s . Handle ( meth , pattern , h )
return nil
}
2023-09-07 12:20:37 +01:00
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
2023-05-09 18:19:48 +01:00
func ( s * ServeMux ) ServeHTTP ( w http . ResponseWriter , r * http . Request ) {
ctx := r . Context ( )
path := r . URL . Path
if ! strings . HasPrefix ( path , "/" ) {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . routingErrorHandler ( ctx , s , outboundMarshaler , w , r , http . StatusBadRequest )
return
}
// TODO(v3): remove UnescapingModeLegacy
if s . unescapingMode != UnescapingModeLegacy && r . URL . RawPath != "" {
path = r . URL . RawPath
}
if override := r . Header . Get ( "X-HTTP-Method-Override" ) ; override != "" && s . isPathLengthFallback ( r ) {
r . Method = strings . ToUpper ( override )
if err := r . ParseForm ( ) ; err != nil {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
sterr := status . Error ( codes . InvalidArgument , err . Error ( ) )
s . errorHandler ( ctx , s , outboundMarshaler , w , r , sterr )
return
}
}
2023-09-07 12:20:37 +01:00
var pathComponents [ ] string
// since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F"
// in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the
// path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default
// behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved
if s . unescapingMode == UnescapingModeAllCharacters {
pathComponents = encodedPathSplitter . Split ( path [ 1 : ] , - 1 )
} else {
pathComponents = strings . Split ( path [ 1 : ] , "/" )
}
lastPathComponent := pathComponents [ len ( pathComponents ) - 1 ]
2023-05-09 18:19:48 +01:00
for _ , h := range s . handlers [ r . Method ] {
// If the pattern has a verb, explicitly look for a suffix in the last
// component that matches a colon plus the verb. This allows us to
// handle some cases that otherwise can't be correctly handled by the
// former LastIndex case, such as when the verb literal itself contains
// a colon. This should work for all cases that have run through the
// parser because we know what verb we're looking for, however, there
// are still some cases that the parser itself cannot disambiguate. See
// the comment there if interested.
2023-09-07 12:20:37 +01:00
var verb string
2023-05-09 18:19:48 +01:00
patVerb := h . pat . Verb ( )
2023-09-07 12:20:37 +01:00
idx := - 1
if patVerb != "" && strings . HasSuffix ( lastPathComponent , ":" + patVerb ) {
idx = len ( lastPathComponent ) - len ( patVerb ) - 1
2023-05-09 18:19:48 +01:00
}
if idx == 0 {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . routingErrorHandler ( ctx , s , outboundMarshaler , w , r , http . StatusNotFound )
return
}
2023-09-07 12:20:37 +01:00
comps := make ( [ ] string , len ( pathComponents ) )
copy ( comps , pathComponents )
2023-05-09 18:19:48 +01:00
if idx > 0 {
2023-09-07 12:20:37 +01:00
comps [ len ( comps ) - 1 ] , verb = lastPathComponent [ : idx ] , lastPathComponent [ idx + 1 : ]
2023-05-09 18:19:48 +01:00
}
2023-09-07 12:20:37 +01:00
pathParams , err := h . pat . MatchAndEscape ( comps , verb , s . unescapingMode )
2023-05-09 18:19:48 +01:00
if err != nil {
var mse MalformedSequenceError
if ok := errors . As ( err , & mse ) ; ok {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . errorHandler ( ctx , s , outboundMarshaler , w , r , & HTTPStatusError {
HTTPStatus : http . StatusBadRequest ,
Err : mse ,
} )
}
continue
}
h . h ( w , r , pathParams )
return
}
2023-09-07 12:20:37 +01:00
// if no handler has found for the request, lookup for other methods
// to handle POST -> GET fallback if the request is subject to path
// length fallback.
// Note we are not eagerly checking the request here as we want to return the
// right HTTP status code, and we need to process the fallback candidates in
// order to do that.
2023-05-09 18:19:48 +01:00
for m , handlers := range s . handlers {
if m == r . Method {
continue
}
for _ , h := range handlers {
2023-09-07 12:20:37 +01:00
var verb string
patVerb := h . pat . Verb ( )
idx := - 1
if patVerb != "" && strings . HasSuffix ( lastPathComponent , ":" + patVerb ) {
idx = len ( lastPathComponent ) - len ( patVerb ) - 1
}
comps := make ( [ ] string , len ( pathComponents ) )
copy ( comps , pathComponents )
if idx > 0 {
comps [ len ( comps ) - 1 ] , verb = lastPathComponent [ : idx ] , lastPathComponent [ idx + 1 : ]
}
pathParams , err := h . pat . MatchAndEscape ( comps , verb , s . unescapingMode )
2023-05-09 18:19:48 +01:00
if err != nil {
var mse MalformedSequenceError
if ok := errors . As ( err , & mse ) ; ok {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . errorHandler ( ctx , s , outboundMarshaler , w , r , & HTTPStatusError {
HTTPStatus : http . StatusBadRequest ,
Err : mse ,
} )
}
continue
}
2023-09-07 12:20:37 +01:00
2023-05-09 18:19:48 +01:00
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
2023-09-07 12:20:37 +01:00
// Also, only consider POST -> GET fallbacks, and avoid falling back to
// potentially dangerous operations like DELETE.
if s . isPathLengthFallback ( r ) && m == http . MethodGet {
2023-05-09 18:19:48 +01:00
if err := r . ParseForm ( ) ; err != nil {
_ , outboundMarshaler := MarshalerForRequest ( s , r )
sterr := status . Error ( codes . InvalidArgument , err . Error ( ) )
s . errorHandler ( ctx , s , outboundMarshaler , w , r , sterr )
return
}
h . h ( w , r , pathParams )
return
}
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . routingErrorHandler ( ctx , s , outboundMarshaler , w , r , http . StatusMethodNotAllowed )
return
}
}
_ , outboundMarshaler := MarshalerForRequest ( s , r )
s . routingErrorHandler ( ctx , s , outboundMarshaler , w , r , http . StatusNotFound )
}
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
func ( s * ServeMux ) GetForwardResponseOptions ( ) [ ] func ( context . Context , http . ResponseWriter , proto . Message ) error {
return s . forwardResponseOptions
}
func ( s * ServeMux ) isPathLengthFallback ( r * http . Request ) bool {
return ! s . disablePathLengthFallback && r . Method == "POST" && r . Header . Get ( "Content-Type" ) == "application/x-www-form-urlencoded"
}
type handler struct {
pat Pattern
h HandlerFunc
}