Strict servers should have some support for early middlewares
By default, request bodies are decoded before any strict middleware runs, which risks accepting maliciously large or unvalidated request bodies. It should be (and is) possible to filter such requests. Unfortunately, the only way to do so is to wrap api.HandlerFromMux with a custom handler. The problem with this is that potentially "expensive" validation (only marginally expensive, but still needless waste), such as checking the client IP against an allowlist/blocklist, may be performed for paths that don't even exist.
FYI, for anyone who needs this: it can be done with the following template (for net/http). Here is my somewhat-tested strict-http.tmpl:
// Handler and middleware types used by the strict server.
//
// The "early" variants operate on the raw *http.Request, before the request
// body has been decoded into a request object.
type (
	// StrictHandlerFunc is the handler type wrapped by the regular strict middlewares.
	StrictHandlerFunc = strictnethttp.StrictHTTPHandlerFunc
	// StrictMiddlewareFunc wraps a StrictHandlerFunc for one operation.
	StrictMiddlewareFunc = strictnethttp.StrictHTTPMiddlewareFunc

	// StrictEarlyHandlerFunc handles a request before its body is decoded. It
	// returns the response object, an error to report as a request error, and
	// an error to report as a response error.
	StrictEarlyHandlerFunc = func(ctx context.Context, w http.ResponseWriter, r *http.Request) (response interface{}, requestErr error, responseErr error)
	// StrictEarlywareFunc wraps a StrictEarlyHandlerFunc for the operation
	// identified by operationID.
	StrictEarlywareFunc = func(f StrictEarlyHandlerFunc, operationID string) StrictEarlyHandlerFunc
)

// StrictHTTPServerOptions configures how the strict handler reports errors.
type StrictHTTPServerOptions struct {
	// RequestErrorHandlerFunc is called when the request is rejected or its body cannot be decoded.
	RequestErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
	// ResponseErrorHandlerFunc is called when the operation fails or its response cannot be written.
	ResponseErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}
// NewStrictHandler adapts ssi to ServerInterface using the default error
// handlers: request errors are answered with 400 Bad Request and response
// errors with 500 Internal Server Error, both using the error text as body.
func NewStrictHandler(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc, earlywares []StrictEarlywareFunc) ServerInterface {
	defaults := StrictHTTPServerOptions{
		RequestErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		},
		ResponseErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		},
	}
	return NewStrictHandlerWithOptions(ssi, middlewares, earlywares, defaults)
}
// NewStrictHandlerWithOptions adapts ssi to ServerInterface with caller-supplied
// error handlers.
//
// Any nil handler in options falls back to the default used by
// NewStrictHandler (400 Bad Request for request errors, 500 Internal Server
// Error for response errors). Without this fallback, a zero-valued or
// partially filled options struct would panic on the first error.
func NewStrictHandlerWithOptions(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc, earlywares []StrictEarlywareFunc, options StrictHTTPServerOptions) ServerInterface {
	if options.RequestErrorHandlerFunc == nil {
		options.RequestErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		}
	}
	if options.ResponseErrorHandlerFunc == nil {
		options.ResponseErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	}
	return &strictHandler{ssi: ssi, middlewares: middlewares, earlywares: earlywares, options: options}
}
// strictHandler implements ServerInterface by delegating each operation to a
// StrictServerInterface, wrapped in the configured earlywares and middlewares.
type strictHandler struct {
	// ssi provides the operation implementations.
	ssi StrictServerInterface
	// middlewares run after the request body has been decoded.
	middlewares []StrictMiddlewareFunc
	// earlywares run on the raw request, before the body is decoded.
	earlywares []StrictEarlywareFunc
	// options holds the request/response error reporting callbacks.
	options StrictHTTPServerOptions
}
{{/* Generate one strictHandler method per operation. Each method runs the
earlywares on the raw *http.Request first; only after they pass is the request
body decoded and the regular strict middlewares run. */}}
{{range .}}
{{$opid := .OperationId}}
// {{$opid}} operation middleware
//
// Earlywares wrap the whole decoding handler, so they can reject a request
// before any work is spent decoding its body.
func (sh *strictHandler) {{.OperationId}}(w http.ResponseWriter, r *http.Request{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}}) {
// handler builds the request object (parameters and decoded body), runs it
// through the strict middlewares and the StrictServerInterface method, and
// returns (response, requestErr, responseErr). Body decoding failures are
// returned as requestErr; errors from the operation as responseErr.
handler := func (ctx context.Context, w http.ResponseWriter, r *http.Request) (interface{}, error, error) {
var request {{$opid | ucFirst}}RequestObject
{{range .PathParams -}}
request.{{.GoName}} = {{.GoVariableName}}
{{end -}}
{{if .RequiresParamObject -}}
request.Params = params
{{end -}}
{{ if .HasMaskedRequestContentTypes -}}
request.ContentType = r.Header.Get("Content-Type")
{{end -}}
{{/* Decode the body. With several request bodies, the decoder is chosen by
Content-Type prefix; with a single body it is decoded unconditionally. */}}
{{$multipleBodies := gt (len .Bodies) 1 -}}
{{range .Bodies -}}
{{if $multipleBodies}}if strings.HasPrefix(r.Header.Get("Content-Type"), "{{.ContentType}}") { {{end}}
{{if .IsJSON }}
var body {{$opid}}{{.NameTag}}RequestBody
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("can't decode JSON body: %w", err), nil
}
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
{{else if eq .NameTag "Formdata" -}}
if err := r.ParseForm(); err != nil {
return nil, fmt.Errorf("can't decode formdata: %w", err), nil
}
var body {{$opid}}{{.NameTag}}RequestBody
if err := runtime.BindForm(&body, r.Form, nil, nil); err != nil {
return nil, fmt.Errorf("can't bind formdata: %w", err), nil
}
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
{{else if eq .NameTag "Multipart" -}}
{{if eq .ContentType "multipart/form-data" -}}
if reader, err := r.MultipartReader(); err != nil {
return nil, fmt.Errorf("can't decode multipart body: %w", err), nil
} else {
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = reader
}
{{else -}}
if _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")); err != nil {
return nil, err, nil
} else if boundary := params["boundary"]; boundary == "" {
return nil, http.ErrMissingBoundary, nil
} else {
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = multipart.NewReader(r.Body, boundary)
}
{{end -}}
{{else if eq .NameTag "Text" -}}
data, err := io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("can't read body: %w", err), nil
}
body := {{$opid}}{{.NameTag}}RequestBody(data)
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = &body
{{else -}}
request.{{if $multipleBodies}}{{.NameTag}}{{end}}Body = r.Body
{{end}}{{/* if eq .NameTag "JSON" */ -}}
{{if $multipleBodies}}}{{end}}
{{end}}{{/* range .Bodies */}}
// Innermost handler: passes the decoded request object to the
// StrictServerInterface implementation.
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
resp, respErr := sh.ssi.{{.OperationId}}(ctx, request.({{$opid | ucFirst}}RequestObject))
return resp, respErr
}
// Each middleware wraps the chain built so far, so the last entry in
// sh.middlewares is the outermost one.
for _, middleware := range sh.middlewares {
handler = middleware(handler, "{{.OperationId}}")
}
resp, respErr := handler(ctx, w, r, request)
return resp, nil, respErr
}
// Wrap the decoding handler in the earlywares; the last entry in
// sh.earlywares is the outermost one.
for _, earlyware := range sh.earlywares {
handler = earlyware(handler, "{{.OperationId}}")
}
response, reqErr, respErr := handler(r.Context(), w, r)
// Request errors go to RequestErrorHandlerFunc; response errors, response
// write failures and unexpected response types go to ResponseErrorHandlerFunc.
// A nil response with no error writes nothing.
if reqErr != nil {
sh.options.RequestErrorHandlerFunc(w, r, reqErr)
} else if respErr != nil {
sh.options.ResponseErrorHandlerFunc(w, r, respErr)
} else if validResponse, ok := response.({{$opid | ucFirst}}ResponseObject); ok {
if err := validResponse.Visit{{$opid}}Response(w); err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
}
} else if response != nil {
sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
}
}
{{end}}
And here is an example "earlyware" using github.com/pb33f/libopenapi-validator:
package api
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/pb33f/libopenapi"
validator "github.com/pb33f/libopenapi-validator"
)
// CreateValidatorEarlyware returns an earlyware that validates each incoming
// request against the embedded OpenAPI spec with libopenapi-validator before
// the strict handler decodes the request body.
//
// A request that fails validation is rejected through the request-error path
// (the second return value of StrictEarlyHandlerFunc) and never reaches the
// next handler. The returned error is non-nil if the embedded spec cannot be
// loaded or parsed, or if no validator can be built from it.
func CreateValidatorEarlyware() (StrictEarlywareFunc, error) {
	// Upper bound on the request body size; ValidateHttpRequest reads the
	// whole body with io.ReadAll.
	const maxBodyBytes = 1_000_000 // 1 MB

	byteSpec, err := rawSpec() // from generated (requires `embedded-spec: true`)
	if err != nil {
		return nil, fmt.Errorf("gen failed to decode raw spec: %w", err)
	}
	spec, err := libopenapi.NewDocument(byteSpec)
	if err != nil {
		return nil, fmt.Errorf("libopenapi failed to parse document from spec: %w", err)
	}
	apiValidator, validatorErrs := validator.NewValidator(spec)
	if len(validatorErrs) > 0 {
		return nil, fmt.Errorf(
			"libopenapi-validator couldn't create a validator: %w",
			errors.Join(validatorErrs...),
		)
	}
	return func(next StrictEarlyHandlerFunc, operation string) StrictEarlyHandlerFunc {
		return func(
			ctx context.Context,
			w http.ResponseWriter,
			r *http.Request,
		) (interface{}, error, error) {
			// Protect from some types of DOS attacks using MaxBytesReader
			// because ValidateHttpRequest uses io.ReadAll
			r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
			// But still use `MaxHeaderBytes` since this is only for Body
			// And use `ReadTimeout` to protect from slow loris
			requestValid, validationErrors := apiValidator.ValidateHttpRequest(r)
			if !requestValid {
				// validationErrors is a slice of a concrete pointer type, not
				// []error (Go slices are not covariant), so convert it before
				// joining into a single flat error.
				errs := make([]error, 0, len(validationErrors))
				for _, validationErr := range validationErrors {
					if validationErr != nil {
						errs = append(errs, validationErr)
					}
				}
				totalErr := errors.Join(errs...)
				if totalErr == nil {
					// Never let an invalid request through with a nil error:
					// the strict handler would then write no response at all.
					totalErr = errors.New("request failed OpenAPI validation")
				}
				return nil, totalErr, nil
			}
			return next(ctx, w, r)
		}
	}, nil
}
Also, I initially thought that the validator reading (and closing) r.Body would be a problem, but in my testing it isn't. Just in case it is in some cases (though I don't think so), the following version of the validator avoids the problem by buffering the body so that the handler can read it again:
// wrapReaderCloser combines an independent Reader and Closer into a single
// io.ReadCloser: Read is served by Reader and Close by Closer. The embedded
// fields keep the names Reader and Closer, so keyed composite literals such as
// &wrapReaderCloser{Reader: r, Closer: c} are unaffected.
type wrapReaderCloser struct {
	io.Reader
	io.Closer
}

// Compile-time check that *wrapReaderCloser satisfies io.ReadCloser.
var _ io.ReadCloser = (*wrapReaderCloser)(nil)
// CreateValidatorEarlyware returns an earlyware that validates each incoming
// request against the embedded OpenAPI spec with libopenapi-validator before
// the strict handler decodes the request body. The request body stays
// readable by the next handler, even though validation consumes it.
//
// A request that fails validation is rejected through the request-error path
// (the second return value of StrictEarlyHandlerFunc) and never reaches the
// next handler. The returned error is non-nil if the embedded spec cannot be
// loaded or parsed, or if no validator can be built from it.
func CreateValidatorEarlyware() (StrictEarlywareFunc, error) {
	// Upper bound on the request body size; ValidateHttpRequest reads the
	// whole body with io.ReadAll.
	const maxBodyBytes = 1_000_000 // 1 MB

	byteSpec, err := rawSpec() // from generated (requires `embedded-spec: true`)
	if err != nil {
		return nil, fmt.Errorf("gen failed to decode raw spec: %w", err)
	}
	spec, err := libopenapi.NewDocument(byteSpec)
	if err != nil {
		return nil, fmt.Errorf("libopenapi failed to parse document from spec: %w", err)
	}
	apiValidator, validatorErrs := validator.NewValidator(spec)
	if len(validatorErrs) > 0 {
		return nil, fmt.Errorf(
			"libopenapi-validator couldn't create a validator: %w",
			errors.Join(validatorErrs...),
		)
	}
	return func(next StrictEarlyHandlerFunc, operation string) StrictEarlyHandlerFunc {
		return func(
			ctx context.Context,
			w http.ResponseWriter,
			r *http.Request,
		) (interface{}, error, error) {
			// Capture the original body before the validator gets a chance
			// to replace r.Body, so it is still the one that gets closed.
			origBody := r.Body
			// Protect from some types of DOS attacks using MaxBytesReader
			// because ValidateHttpRequest uses io.ReadAll
			limitedBody := http.MaxBytesReader(w, origBody, maxBodyBytes)
			// Everything the validator reads is copied into bodyBuf so it
			// can be replayed for the handler.
			bodyBuf := new(bytes.Buffer)
			r.Body = &wrapReaderCloser{
				Reader: io.TeeReader(limitedBody, bodyBuf),
				Closer: origBody,
			}
			// But still use `MaxHeaderBytes` since this is only for Body
			// And use `ReadTimeout` to protect from slow loris
			requestValid, validationErrors := apiValidator.ValidateHttpRequest(r)
			// Replay the bytes the validator consumed, then continue with
			// whatever it left unread (for example, when it did not read the
			// body at all). Replaying only bodyBuf would silently drop that
			// remainder. The remainder still goes through the size limit.
			r.Body = &wrapReaderCloser{
				Reader: io.MultiReader(bodyBuf, limitedBody),
				Closer: origBody,
			}
			if !requestValid {
				// validationErrors is a slice of a concrete pointer type, not
				// []error (Go slices are not covariant), so convert it before
				// joining into a single flat error.
				errs := make([]error, 0, len(validationErrors))
				for _, validationErr := range validationErrors {
					if validationErr != nil {
						errs = append(errs, validationErr)
					}
				}
				totalErr := errors.Join(errs...)
				if totalErr == nil {
					// Never let an invalid request through with a nil error:
					// the strict handler would then write no response at all.
					totalErr = errors.New("request failed OpenAPI validation")
				}
				return nil, totalErr, nil
			}
			return next(ctx, w, r)
		}
	}, nil
}