oapi-codegen

Adding Values to Context in Chi Middleware

Open. petersng opened this issue 3 years ago.

Hi, I've been generating chi code with authentication (OapiRequestValidatorWithOptions with an AuthenticationFunc). I can get as far as verifying the bearer token, but is it possible to add a value to the context within the AuthenticationFunc? For instance, a user authenticates, we validate the token, and then we store the user ID or other user data in the context so that it is passed to the next handler and the rest of the request.

Is this possible with the chi code? I see that the echo and gin code give access to their framework contexts.

Thanks!

petersng, Nov 20 '22

Actually, I can't get it working in gin. The auth middleware receives a context.Context, which cannot be modified in place, and it can't be type-asserted to *gin.Context.

NikSays, Jan 07 '23

@NikSays were you able to solve the issue of the context.Context not being convertible to a gin.Context?

marianozunino, Feb 15 '23

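@marianozunino Thankfully, yes. The gin-middleware package exports a GinContextKey constant; the *gin.Context is stored in the context.Context under that key, so you can retrieve it and set values on it: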

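import (
	"context"
	"fmt"

	gmw "github.com/deepmap/oapi-codegen/pkg/gin-middleware"
	"github.com/gin-gonic/gin"
)

// SomeHandler receives the context.Context built by the gin validator
// middleware (for example, inside an AuthenticationFunc).
func SomeHandler(ctx context.Context) error {
	// The *gin.Context is stored under GinContextKey.
	ginCtx, ok := ctx.Value(gmw.GinContextKey).(*gin.Context)
	if !ok {
		return fmt.Errorf("couldn't get gin context")
	}
	// Later gin handlers can read this with ginCtx.Get("key").
	ginCtx.Set("key", "value")

	return nil
}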

NikSays, Feb 15 '23

@petersng I have the same problem. The only workaround I see is to create your own context store (a pointer to a struct or a map) that you can mutate to add authentication information.

ilya-hontarau, Jun 20 '23

@petersng did you find a solution for this?

pcriv, Jun 26 '23

Same here; I have trouble adding values to the context when using chi.

marmiha, Sep 12 '23

I had the same issue (not being able to add values to the context). My temporary solution was to add a middleware that checks whether the security scopes have been set by the oapi-generated operation handler. If they have, my authentication checks passed, so I can safely add auth details (in my case, a decoded JWT) to the request context.

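import (
	"context"
	"fmt"
	"log"
	"net/http"
	"os"

	oapi_middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware"
	"github.com/getkin/kin-openapi/openapi3filter"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/jwtauth/v5"
	// plus the project's own packages: api (generated code) and config
)

func main() {
	swagger, err := api.GetSwagger()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error loading swagger spec: %s\n", err)
		os.Exit(1)
	}

	// Create an instance of our handler which satisfies the generated interface
	serverImpl := api.NewServerImpl()

	// Set up the chi router
	r := chi.NewRouter()

	// Router-level validator: requests that fail AuthenticationFunc are
	// rejected here and never reach the generated handlers.
	r.Use(oapi_middleware.OapiRequestValidatorWithOptions(swagger,
		&oapi_middleware.Options{
			Options: openapi3filter.Options{
				AuthenticationFunc: func(c context.Context, input *openapi3filter.AuthenticationInput) error {
					// your token validation logic here based on security scopes...
					return nil
				},
			},
		},
	))
	api.HandlerWithOptions(serverImpl, api.ChiServerOptions{
		BaseRouter:  r,
		BaseURL:     config.ServerBaseUrl,
		Middlewares: []api.MiddlewareFunc{WithTokenAuth}, // <----- WithTokenAuth setter middleware
	})

	log.Fatal(http.ListenAndServe(":8080", r))
}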


func WithTokenAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, ok := r.Context().Value(api.BearerAuthScopes).([]string); ok {
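			// BearerAuthScopes is set, so this is a secured operation that
			// already passed the validator; add the token to the request context.
			tokenString := jwtauth.TokenFromHeader(r)
			token, err := api.AuthStore.Ja.Decode(tokenString)
			// Keep the decode error too, so handlers can inspect it via jwtauth.FromContext.
			r = r.WithContext(jwtauth.NewContext(r.Context(), token, err))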
		}
		next.ServeHTTP(w, r)
	})
}

bartventer, Sep 13 '23

I've added a middleware, placed before the one that runs the openapi3 authentication function, that injects a mutable user-context struct into the request context. The authentication function populates it, and the struct can then be read after the authentication call.

Initialize is called first, in a middleware that runs before the authentication function; after that you can set and get the account info.

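import (
	"context"
	"sync"
)

// Not shown in this snippet: authContextField (a ContextField constant used
// as the context key), the AccountInfo type, and the errors
// ErrAlreadyInitialized, ErrMissingAuthContext and ErrInvalidAuthContextFormat.

type (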
	ContextField string

	AuthContextSetter interface {
		SetAccountInfo(ctx context.Context, accountInfo AccountInfo) error
	}

	AuthContextGetter interface {
		GetAccountInfo(ctx context.Context) (AccountInfo, error)
	}

	AuthContextInitializer interface {
		Initialize(ctx context.Context) (context.Context, error)
	}

	AuthContextProvider interface {
		AuthContextSetter
		AuthContextGetter
		AuthContextInitializer
	}

	authContextProvider struct{}

	AuthContext struct {
		mu          sync.RWMutex
		accountInfo AccountInfo
	}
)

var (
	_ AuthContextSetter      = (*authContextProvider)(nil)
	_ AuthContextGetter      = (*authContextProvider)(nil)
	_ AuthContextProvider    = (*authContextProvider)(nil)
	_ AuthContextInitializer = (*authContextProvider)(nil)
)

func NewAuthContextProvider() *authContextProvider {
	return &authContextProvider{}
}

// Initialize implements AuthContextInitializer.
func (*authContextProvider) Initialize(ctx context.Context) (context.Context, error) {
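	// Check if the auth context is already set
	authContextValue := ctx.Value(authContextField)
	if authContextValue != nil {
		// Return the original context (not nil) so a caller that
		// ignores the error still has a usable context.
		return ctx, ErrAlreadyInitialized
	}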

	// Initialize the auth context
	authContext := &AuthContext{}
	ctx = context.WithValue(ctx, authContextField, authContext)

	return ctx, nil
}

// GetAccountInfo implements AuthContextGetter.
func (*authContextProvider) GetAccountInfo(ctx context.Context) (AccountInfo, error) {
	authContextValue := ctx.Value(authContextField)

	if authContextValue == nil {
		return AccountInfo{}, ErrMissingAuthContext
	}

	authContext, ok := authContextValue.(*AuthContext)
	if !ok {
		return AccountInfo{}, ErrInvalidAuthContextFormat
	}

	authContext.mu.RLock()
	defer authContext.mu.RUnlock()

	return authContext.accountInfo, nil
}

// SetAccountInfo implements AuthContextSetter.
func (*authContextProvider) SetAccountInfo(ctx context.Context, accountInfo AccountInfo) error {
	authContextValue := ctx.Value(authContextField)

	if authContextValue == nil {
		return ErrMissingAuthContext
	}

	authContext, ok := authContextValue.(*AuthContext)
	if !ok {
		return ErrInvalidAuthContextFormat
	}

	// Set the account info
	authContext.mu.Lock()
	defer authContext.mu.Unlock()

	authContext.accountInfo = accountInfo
	return nil
}

marmiha, Sep 13 '23

@marmiha Based on the code you shared, I tried writing a slightly simpler version of the code.

First, write your own context type and a middleware that inserts it into the request context.

package webcontext

import (
	"context"
	"net/http"
	"sync"
)

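// contextKey is unexported so that keys from other packages cannot collide
// with Key (staticcheck reports plain string keys as SA1029).
type contextKey string

const (
	Key contextKey = "web-context"
)

// WebContext is a mutable key/value store, safe for concurrent use,
// that lives in the request context.
type WebContext struct {
	mu   sync.RWMutex
	Keys map[string]any
}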

func (c *WebContext) Set(key string, value any) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.Keys == nil {
		c.Keys = make(map[string]any)
	}

	c.Keys[key] = value
}

func (c *WebContext) Get(key string) (value any, exists bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	value, exists = c.Keys[key]
	return
}

func Middleware() func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r = r.WithContext(context.WithValue(r.Context(), Key, &WebContext{}))
			next.ServeHTTP(w, r)
		})
	}
}

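// ContextFromRequest returns the request's *WebContext. It panics if
// Middleware has not run for this request.
func ContextFromRequest(r *http.Request) *WebContext {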
	return r.Context().Value(Key).(*WebContext)
}

Then register that middleware so it runs before the OapiRequestValidatorWithOptions middleware.

func main() {
	swagger, err := gen.GetSwagger()
	if err != nil {
		log.Fatalf("Unable to get swagger: %v", err)
	}

	router := chi.NewRouter()

	router.Use(webcontext.Middleware())
	router.Use(oapimiddleware.OapiRequestValidatorWithOptions(swagger,
		&oapimiddleware.Options{
			ErrorHandler: errorHandler,
			Options: openapi3filter.Options{
				AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
					// Your authentication logic here...
					token := "some token or object"
					webcontext.ContextFromRequest(input.RequestValidationInput.Request).Set("your_key", token)

					return nil
				},
			},
		}),
	)
	
	// Rest of your code here...
}

Finally, you can read the stored value in your handler:

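func (s Server) GetTokenInfo(w http.ResponseWriter, r *http.Request) {
	customCtx := webcontext.ContextFromRequest(r)
	rawToken, ok := customCtx.Get("your_key")
	if !ok {
		// The AuthenticationFunc did not store a token for this request.
		http.Error(w, "missing token", http.StatusUnauthorized)
		return
	}
	// Type-assert rawToken to whatever type the AuthenticationFunc stored.
	// Rest of your code here...
}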

ugabiga, Nov 29 '23

I have published a module for this: https://github.com/induzo/gocom/tree/main/http/middleware/writablecontext

It works well with the AuthenticationFunc.

I'm not sure why a mutex would be needed, since the context is only accessed serially within a request.

vincentserpoul, Dec 07 '23