oapi-codegen icon indicating copy to clipboard operation
oapi-codegen copied to clipboard

Strict servers should have some support for early middlewares

Open · AidanWelch opened this issue 1 year ago · 1 comment

By default, request bodies are decoded before any strict middleware runs, which risks processing maliciously large or non-validated bodies. It should be (and is) possible to filter such requests. Unfortunately, the only way to do so is to wrap api.HandlerFromMux with a custom handler. The problem with this is that potentially "expensive" (marginally expensive, but still needless) validation work — say, looking up a whitelisted/blacklisted IP — can be wasted on paths that don't even exist.

FYI, for anyone who needs this, it can be done with a custom template (for net/http). Here is my somewhat-tested strict-http.tmpl:

// Handler and middleware signatures used by the strict server. The plain
// handler/middleware types are re-exported from strictnethttp; the "early"
// variants wrap the step that decodes the request body, so they run before it.
type (
	// StrictHandlerFunc is the strict handler signature from strictnethttp.
	StrictHandlerFunc = strictnethttp.StrictHTTPHandlerFunc

	// StrictMiddlewareFunc wraps a StrictHandlerFunc; it sees the decoded request object.
	StrictMiddlewareFunc = strictnethttp.StrictHTTPMiddlewareFunc

	// StrictEarlyHandlerFunc handles a raw request. It returns the response object,
	// a request error (reported via RequestErrorHandlerFunc) and a response error
	// (reported via ResponseErrorHandlerFunc).
	StrictEarlyHandlerFunc = func(ctx context.Context, w http.ResponseWriter, r *http.Request) (response interface{}, requestErr error, responseErr error)

	// StrictEarlywareFunc wraps a StrictEarlyHandlerFunc for the named operation.
	StrictEarlywareFunc = func(f StrictEarlyHandlerFunc, operationID string) StrictEarlyHandlerFunc
)

// StrictHTTPServerOptions configures how a strict handler reports failures.
type StrictHTTPServerOptions struct {
	// RequestErrorHandlerFunc is called with the request error: a body that
	// could not be decoded, or a rejection returned by an earlyware.
	RequestErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)

	// ResponseErrorHandlerFunc is called when the operation or a middleware
	// fails, or when the response object cannot be written.
	ResponseErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// NewStrictHandler adapts ssi to ServerInterface, applying the given strict
// middlewares and earlywares. Request errors are answered with
// 400 Bad Request and response errors with 500 Internal Server Error, in both
// cases with err.Error() as the body.
func NewStrictHandler(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc, earlywares []StrictEarlywareFunc) ServerInterface {
	badRequest := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, err.Error(), http.StatusBadRequest)
	}
	internalError := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
	defaults := StrictHTTPServerOptions{
		RequestErrorHandlerFunc:  badRequest,
		ResponseErrorHandlerFunc: internalError,
	}
	return NewStrictHandlerWithOptions(ssi, middlewares, earlywares, defaults)
}

// NewStrictHandlerWithOptions is like NewStrictHandler but lets the caller
// supply the error handlers.
//
// A nil handler in options falls back to the default: 400 Bad Request for
// request errors and 500 Internal Server Error for response errors, each with
// err.Error() as the body. Without this fallback, the first failing request
// would panic by calling a nil func.
func NewStrictHandlerWithOptions(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc, earlywares []StrictEarlywareFunc, options StrictHTTPServerOptions) ServerInterface {
	if options.RequestErrorHandlerFunc == nil {
		options.RequestErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		}
	}
	if options.ResponseErrorHandlerFunc == nil {
		options.ResponseErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	}
	return &strictHandler{ssi: ssi, middlewares: middlewares, earlywares: earlywares, options: options}
}

// strictHandler implements ServerInterface on top of a StrictServerInterface.
// For each operation, the earlywares wrap request-body decoding and the
// middlewares wrap the call into ssi.
type strictHandler struct {
	// ssi provides the operation implementations.
	ssi StrictServerInterface
	// middlewares run after the request body has been decoded.
	middlewares []StrictMiddlewareFunc
	// earlywares run before the request body is decoded.
	earlywares []StrictEarlywareFunc
	// options holds the request/response error reporting hooks.
	options StrictHTTPServerOptions
}

{{/* Generate one ServerInterface method on strictHandler per operation: earlywares first, then body decoding, then strict middlewares and the StrictServerInterface call. */}}
{{range .}}
    {{$opid := .OperationId}}
    // {{$opid}} implements ServerInterface for this operation. It runs the
    // earlywares first, then decodes the request body, then runs the strict
    // middlewares and the StrictServerInterface method, and finally writes the
    // response (or reports a request/response error via sh.options).
	func (sh *strictHandler) {{.OperationId}}(w http.ResponseWriter, r *http.Request{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}}) {
		// handler is the innermost early handler. It builds the request object
		// and decodes the body; decode failures are returned as the second
		// result (requestErr). It only runs if every earlyware calls through.
		handler := func (ctx context.Context, w http.ResponseWriter, r *http.Request) (interface{}, error, error) {
			var request {{$opid | ucFirst}}RequestObject

			{{range .PathParams -}}
				request.{{.GoName}} = {{.GoVariableName}}
			{{end -}}

			{{if .RequiresParamObject -}}
				request.Params = params
			{{end -}}

			{{ if .HasMaskedRequestContentTypes -}}
				request.ContentType = r.Header.Get("Content-Type")
			{{end -}}

			// Decode the body for the content type(s) the spec declares. When there
			// are several, the Content-Type header prefix selects one; content types
			// not handled below receive r.Body unread.
			{{$multipleBodies := gt (len .Bodies) 1 -}}
			{{range .Bodies -}}
				{{if $multipleBodies}}if strings.HasPrefix(r.Header.Get("Content-Type"), "{{.ContentType}}") { {{end}}
					{{if .IsJSON }}
						var body {{$opid}}{{.NameTag}}RequestBody
						if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
							return nil, fmt.Errorf("can't decode JSON body: %w", err), nil
						}
						request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
					{{else if eq .NameTag "Formdata" -}}
						if err := r.ParseForm(); err != nil {
							return nil, fmt.Errorf("can't decode formdata: %w", err), nil
						}
						var body {{$opid}}{{.NameTag}}RequestBody
						if err := runtime.BindForm(&body, r.Form, nil, nil); err != nil {
							return nil, fmt.Errorf("can't bind formdata: %w", err), nil
						}
						request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
					{{else if eq .NameTag "Multipart" -}}
						{{if eq .ContentType "multipart/form-data" -}}
						if reader, err := r.MultipartReader(); err != nil {
							return nil, fmt.Errorf("can't decode multipart body: %w", err), nil
						} else {
							request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = reader
						}
						{{else -}}
						if _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")); err != nil {
							return nil, err, nil
						} else if boundary := params["boundary"]; boundary == "" {
							return nil, http.ErrMissingBoundary, nil
						} else {
							request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = multipart.NewReader(r.Body, boundary)
						}
						{{end -}}
					{{else if eq .NameTag "Text" -}}
						data, err := io.ReadAll(r.Body)
						if err != nil {
							return nil, fmt.Errorf("can't read body: %w", err), nil
						}
						body := {{$opid}}{{.NameTag}}RequestBody(data)
						request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
					{{else -}}
						request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = r.Body
					{{end}}{{/* if eq .NameTag "JSON" */ -}}
				{{if $multipleBodies}}}{{end}}
			{{end}}{{/* range .Bodies */}}

			// This inner handler shadows the outer one inside this closure. It calls
			// the StrictServerInterface method and is wrapped by the strict
			// middlewares (the last one in sh.middlewares ends up outermost), so
			// middlewares only ever see the already-decoded request object.
			handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
				resp, respErr := sh.ssi.{{.OperationId}}(ctx, request.({{$opid | ucFirst}}RequestObject))
				return resp, respErr
			}
			for _, middleware := range sh.middlewares {
				handler = middleware(handler, "{{.OperationId}}")
			}

			// Errors from the middleware chain are reported as response errors.
			resp, respErr := handler(ctx, w, r, request)
			return resp, nil, respErr
		}

		// Wrap with the earlywares. The last one in sh.earlywares ends up
		// outermost and runs first; any of them can return without calling
		// through, in which case the body is never decoded.
		for _, earlyware := range sh.earlywares {
			handler = earlyware(handler, "{{.OperationId}}")
		}

		response, reqErr, respErr := handler(r.Context(), w, r)

		// A request error takes precedence over a response error. A nil response
		// with no errors writes nothing here (NOTE(review): presumably the
		// returning handler has written to w itself — confirm that is intended).
		if reqErr != nil {
			sh.options.RequestErrorHandlerFunc(w, r, reqErr)
		} else if respErr != nil {
			sh.options.ResponseErrorHandlerFunc(w, r, respErr)
		} else if validResponse, ok := response.({{$opid | ucFirst}}ResponseObject); ok {
			if err := validResponse.Visit{{$opid}}Response(w); err != nil {
				sh.options.ResponseErrorHandlerFunc(w, r, err)
			}
		} else if response != nil {
			sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
		}
	}
{{end}}

And here is an example "earlyware" using github.com/pb33f/libopenapi-validator:

package api

import (
	"context"
	"errors"
	"fmt"
	"net/http"

	"github.com/pb33f/libopenapi"
	validator "github.com/pb33f/libopenapi-validator"
)

// CreateValidatorEarlyware builds an earlyware that validates every request
// against the embedded OpenAPI spec (via libopenapi-validator) before the
// strict handler decodes the body.
//
// Request bodies are capped at maxBodyBytes. An invalid request is returned
// as a request error, so the strict handler reports it through
// RequestErrorHandlerFunc. A valid request is passed on to next.
//
// Setup failures (reading the embedded spec, parsing it, or building the
// validator) are returned wrapped, so callers can inspect them with errors.Is/As.
func CreateValidatorEarlyware() (StrictEarlywareFunc, error) {
	// ValidateHttpRequest reads the whole body (io.ReadAll), so bound it.
	const maxBodyBytes = 1_000_000 // 1 MB

	byteSpec, err := rawSpec() // from generated code (requires `embedded-spec: true`)
	if err != nil {
		return nil, fmt.Errorf("gen failed to decode raw spec: %w", err)
	}

	spec, err := libopenapi.NewDocument(byteSpec)
	if err != nil {
		return nil, fmt.Errorf("libopenapi failed to parse document from spec: %w", err)
	}

	apiValidator, validatorErrs := validator.NewValidator(spec)
	if len(validatorErrs) > 0 {
		return nil, fmt.Errorf(
			"libopenapi-validator couldn't create a validator: %w",
			errors.Join(validatorErrs...),
		)
	}

	return func(next StrictEarlyHandlerFunc, _ string) StrictEarlyHandlerFunc {
		return func(
			ctx context.Context,
			w http.ResponseWriter,
			r *http.Request,
		) (interface{}, error, error) {
			// Protect against some DoS attacks. This caps only the body; the
			// http.Server still needs MaxHeaderBytes for headers and ReadTimeout
			// against slowloris-style clients.
			r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

			requestValid, validationErrors := apiValidator.ValidateHttpRequest(r)
			if requestValid {
				return next(ctx, w, r)
			}

			// validationErrors is a slice of a concrete pointer type, and Go does
			// not convert such a slice to []error, so join the errors one by one.
			var totalErr error
			for _, validationErr := range validationErrors {
				totalErr = errors.Join(totalErr, validationErr)
			}
			if totalErr == nil {
				// Returning (nil, nil, nil) would make the strict handler skip both
				// error handlers and write no response at all.
				totalErr = errors.New("request failed OpenAPI validation")
			}
			return nil, totalErr, nil
		}
	}, nil
}

AidanWelch avatar Aug 29 '24 00:08 AidanWelch

Also, I initially thought that the validator closing r.Body would be a problem, but in my testing it isn't. Just in case it is in some situations (though I don't think so), this version of the validator avoids the problem:

// wrapReaderCloser combines an independent io.Reader and io.Closer into one
// io.ReadCloser: Read is served by Reader and Close by Closer. This lets a
// request body be swapped for a different reader while Close still reaches
// the original body.
//
// Embedding promotes Read and Close directly instead of hand-written
// delegating methods. The embedded fields are still named Reader and Closer,
// so keyed literals like &wrapReaderCloser{Reader: r, Closer: c} keep working.
type wrapReaderCloser struct {
	io.Reader
	io.Closer
}

var _ io.ReadCloser = (*wrapReaderCloser)(nil)

// CreateValidatorEarlyware builds an earlyware that validates every request
// against the embedded OpenAPI spec (via libopenapi-validator) before the
// strict handler decodes the body.
//
// Unlike the simpler variant, this one copies whatever the validator reads
// from the body into a buffer. It then rebuilds r.Body from that buffer plus
// any unread remainder, so the handler sees the full body even if the
// validator consumed and closed the original.
//
// Request bodies are capped at maxBodyBytes. An invalid request is returned
// as a request error; a valid request is passed on to next.
func CreateValidatorEarlyware() (StrictEarlywareFunc, error) {
	// ValidateHttpRequest reads the whole body (io.ReadAll), so bound it.
	const maxBodyBytes = 1_000_000 // 1 MB

	byteSpec, err := rawSpec() // from generated code (requires `embedded-spec: true`)
	if err != nil {
		return nil, fmt.Errorf("gen failed to decode raw spec: %w", err)
	}

	spec, err := libopenapi.NewDocument(byteSpec)
	if err != nil {
		return nil, fmt.Errorf("libopenapi failed to parse document from spec: %w", err)
	}

	apiValidator, validatorErrs := validator.NewValidator(spec)
	if len(validatorErrs) > 0 {
		return nil, fmt.Errorf(
			"libopenapi-validator couldn't create a validator: %w",
			errors.Join(validatorErrs...),
		)
	}

	return func(next StrictEarlyHandlerFunc, _ string) StrictEarlyHandlerFunc {
		return func(
			ctx context.Context,
			w http.ResponseWriter,
			r *http.Request,
		) (interface{}, error, error) {
			original := r.Body
			// Protect against some DoS attacks. This caps only the body; the
			// http.Server still needs MaxHeaderBytes for headers and ReadTimeout
			// against slowloris-style clients.
			limited := http.MaxBytesReader(w, original, maxBodyBytes)

			// Mirror every byte the validator reads into bodyBuf so it can be replayed.
			bodyBuf := new(bytes.Buffer)
			r.Body = &wrapReaderCloser{
				Reader: io.TeeReader(limited, bodyBuf),
				Closer: original,
			}

			requestValid, validationErrors := apiValidator.ValidateHttpRequest(r)

			// Replay what the validator consumed, then whatever it left unread.
			// Replaying bodyBuf alone would silently drop any part of the body the
			// validator did not read, including the whole body if it read none.
			r.Body = &wrapReaderCloser{
				Reader: io.MultiReader(bodyBuf, limited),
				Closer: original,
			}

			if requestValid {
				return next(ctx, w, r)
			}

			// validationErrors is a slice of a concrete pointer type, and Go does
			// not convert such a slice to []error, so join the errors one by one.
			var totalErr error
			for _, validationErr := range validationErrors {
				totalErr = errors.Join(totalErr, validationErr)
			}
			if totalErr == nil {
				// Returning (nil, nil, nil) would make the strict handler skip both
				// error handlers and write no response at all.
				totalErr = errors.New("request failed OpenAPI validation")
			}
			return nil, totalErr, nil
		}
	}, nil
}

AidanWelch avatar Aug 29 '24 01:08 AidanWelch