djangorestframework-simplejwt
Allow other header claims in tokens
As defined in RFC 7515, section 4.1, tokens can include many more header parameters than just `typ` and `alg`, which are the only ones this library currently allows.
I tried to include a `kid` header, since I use signed tokens, but I couldn't.
Using PyJWT I was able to add it to the token string, but when I passed that string to the `RefreshToken(token)` constructor, all custom headers were removed.
I have checked the documentation and nothing seems to cover this use case.
I haven't dug much into the code, though.
As for the `kid` header, I suggest including it by default whenever the token is signed.
(AuthLib documentation for reference)
This is somewhat related to #491, since `kid` would be useful when combined with a JWK endpoint.
It seems the first part of my issue can be addressed with the changes in PR #517.
Sadly I couldn't find any issue related to this.
Any idea when this will be released on PyPI?
After digging into the code, I managed to include the `kid` header as I wanted, without using the changes from PR #517.
I had to redefine quite a few classes.
views.py
import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import (TokenObtainPairSerializer,
TokenRefreshSerializer)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken, Token
class TokenBackendWithHeaders(TokenBackend):
    """``TokenBackend`` variant whose ``encode`` accepts extra JWT headers."""

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.

        :param payload: claims to encode; the dict itself is not modified.
        :param headers: optional dict of additional JOSE header parameters
            (e.g. ``kid``) forwarded to ``jwt.encode``. ``None`` (the
            default) or an empty dict adds no extra headers.
        :returns: the encoded token as a ``str``.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(
            jwt_payload,
            self.signing_key,
            algorithm=self.algorithm,
            headers=headers or None,
            # Honour a custom JSON encoder when the installed simplejwt
            # TokenBackend provides one; older releases have no such attribute.
            json_encoder=getattr(self, "json_encoder", None),
        )
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token
class TokenWithAnotherTokenBackend(Token):
    """Token base class that can carry extra JOSE header parameters.

    Entries placed in ``self.headers`` (e.g. ``kid``) are passed to
    ``TokenBackendWithHeaders.encode`` when the token is serialized, so they
    end up in the encoded JWT header.
    """

    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

    def __init__(self, token=None, verify=True):
        super().__init__(token, verify)
        self.headers = {}
        if token is not None:
            # Keep the custom header parameters of an encoded token that was
            # passed in. Without this, a token rebuilt from a string (e.g. in
            # the refresh flow) would lose its ``kid``. The parent constructor
            # has already decoded (and, when ``verify`` is true, validated)
            # the token at this point. ``alg``/``typ`` are left to PyJWT so a
            # copied header can never override the signing algorithm.
            # NOTE(review): a copied ``kid`` names the key that signed the
            # incoming token; revisit this if signing keys are rotated.
            reserved = ("alg", "typ")
            self.headers = {
                name: value
                for name, value in jwt.get_unverified_header(token).items()
                if name not in reserved
            }

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string, including any
        extra header parameters stored in ``self.headers``.
        """
        return self.get_token_backend().encode(self.payload, self.headers)
class AccessTokenWithAnotherTokenBackend(AccessToken, TokenWithAnotherTokenBackend):
    """Access token combining simplejwt's ``AccessToken`` with the
    header-aware ``TokenWithAnotherTokenBackend`` base."""
class RefreshTokenWithAnotherTokenBackend(RefreshToken, TokenWithAnotherTokenBackend):
    """Refresh token whose derived access tokens inherit its custom headers."""

    @property
    def access_token(self):
        """
        Build an ``AccessTokenWithAnotherTokenBackend`` from this refresh token.

        Every payload claim except those named in ``no_copy_claims`` is copied
        over, and all entries of ``self.headers`` are copied into the new
        token's headers.
        """
        access = AccessTokenWithAnotherTokenBackend()

        # Expire relative to this refresh token's instantiation time, so a
        # refresh/access pair created together expires from the same instant.
        access.set_exp(from_time=self.current_time)

        skipped = self.no_copy_claims
        for name, claim_value in self.payload.items():
            if name not in skipped:
                access[name] = claim_value

        # Propagate custom header parameters (e.g. ``kid``).
        access.headers.update(self.headers)
        return access
class TokenObtainPairSerializerDifferentToken(TokenObtainPairSerializer):
    """Obtain-pair serializer whose refresh token carries a ``kid`` header.

    The ``kid`` value is the authlib ``thumbprint()`` of
    ``SIMPLE_JWT['VERIFYING_KEY']``, which is imported as an RSA key.
    """

    token_class = RefreshTokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):
        verifying_jwk = JsonWebKey.import_key(
            settings.SIMPLE_JWT['VERIFYING_KEY'],
            {'kty': 'RSA'},
        )
        refresh = cls.token_class.for_user(user)
        # Custom header claim identifying the verifying key.
        refresh.headers.update(kid=verifying_jwk.thumbprint())
        return refresh
class TokenRefreshSerializerDifferentToken(TokenRefreshSerializer):
    """Refresh serializer that rebuilds tokens with header support.

    Redefined because the original ``validate`` hard-codes ``RefreshToken``;
    here the class comes from ``token_class``. Upstream PR
    https://github.com/jazzband/djangorestframework-simplejwt/pull/517
    addresses the hard-coded class in newer simplejwt releases.
    """

    token_class = RefreshTokenWithAnotherTokenBackend

    def validate(self, attrs):
        """
        Return ``{'access': ...}`` for the submitted refresh token, plus a
        rotated ``'refresh'`` token when ``ROTATE_REFRESH_TOKENS`` is enabled.
        """
        refresh = self.token_class(attrs['refresh'])
        data = {'access': str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                # If the blacklist app is not installed, the `blacklist`
                # method is not present. Check for that explicitly rather than
                # catching AttributeError, which would also hide errors raised
                # *inside* blacklist() and silently skip blacklisting.
                blacklist = getattr(refresh, 'blacklist', None)
                if blacklist is not None:
                    blacklist()

            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()

            data['refresh'] = str(refresh)

        return data
class TokenObtainPairView(original_views.TokenObtainPairView):
    """simplejwt obtain-pair view using ``TokenObtainPairSerializerDifferentToken``."""

    serializer_class = TokenObtainPairSerializerDifferentToken
class TokenRefreshView(original_views.TokenRefreshView):
    """simplejwt refresh view using ``TokenRefreshSerializerDifferentToken``."""

    serializer_class = TokenRefreshSerializerDifferentToken
urls.py
"""e_abeilles URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.0/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include
from rest_framework_simplejwt import views as jwt_views
from my_package import views
# Admin site plus the custom JWT obtain/refresh endpoints from `views`.
urlpatterns = [
    path('admin/', admin.site.urls),
    path(
        'api/token/',
        views.TokenObtainPairView.as_view(),
        name='token_obtain_pair',
    ),
    path(
        'api/token/refresh/',
        views.TokenRefreshView.as_view(),
        name='token_refresh',
    ),
]
Could this be included in the base code? I can open a PR if you wish!
What we’ve done in the past is have a callable or a dotted import string in SIMPLE_JWT settings. In the serializer, we can pass the token to your function. This is similar to the authorization callable.
@Andrew-Chen-Wang That might be possible, but I don't think it is the way to go, because it involves encoding a token, sending it to the callback, decoding it, re-encoding it with the added header, and sending it back.
Performance-wise, adding the header before encoding would be much better, don't you think?
Yes, it definitely would be. I just worry about the ordering and people missing something with override classes. But please open a PR and we shall deliberate :)
For anyone interested, here is the same approach for sliding tokens:
import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import TokenObtainSlidingSerializer
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import SlidingToken, Token
class TokenBackendWithHeaders(TokenBackend):
    """``TokenBackend`` variant whose ``encode`` accepts extra JWT headers."""

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.

        :param payload: claims to encode; the dict itself is not modified.
        :param headers: optional dict of additional JOSE header parameters
            (e.g. ``kid``) forwarded to ``jwt.encode``. ``None`` (the
            default) or an empty dict adds no extra headers.
        :returns: the encoded token as a ``str``.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(
            jwt_payload,
            self.signing_key,
            algorithm=self.algorithm,
            headers=headers or None,
            # Honour a custom JSON encoder when the installed simplejwt
            # TokenBackend provides one; older releases have no such attribute.
            json_encoder=getattr(self, "json_encoder", None),
        )
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token
class TokenWithAnotherTokenBackend(Token):
    """Token base class that can carry extra JOSE header parameters.

    Entries placed in ``self.headers`` (e.g. ``kid``) are passed to
    ``TokenBackendWithHeaders.encode`` when the token is serialized, so they
    end up in the encoded JWT header.
    """

    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

    def __init__(self, token=None, verify=True):
        super().__init__(token, verify)
        self.headers = {}
        if token is not None:
            # Keep the custom header parameters of an encoded token that was
            # passed in; otherwise a token rebuilt from a string would lose
            # its ``kid``. The parent constructor has already decoded (and,
            # when ``verify`` is true, validated) the token at this point.
            # ``alg``/``typ`` are left to PyJWT so a copied header can never
            # override the signing algorithm.
            # NOTE(review): a copied ``kid`` names the key that signed the
            # incoming token; revisit this if signing keys are rotated.
            reserved = ("alg", "typ")
            self.headers = {
                name: value
                for name, value in jwt.get_unverified_header(token).items()
                if name not in reserved
            }

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string, including any
        extra header parameters stored in ``self.headers``.
        """
        return self.get_token_backend().encode(self.payload, self.headers)
class SlidingTokenWithAnotherTokenBackend(SlidingToken, TokenWithAnotherTokenBackend):
    """Sliding token combining simplejwt's ``SlidingToken`` with the
    header-aware ``TokenWithAnotherTokenBackend`` base."""


# Backward-compatible alias for the original, misspelled class name.
SlidingokenWithAnotherTokenBackend = SlidingTokenWithAnotherTokenBackend
class TokenObtainSlidingSerializerDifferentToken(TokenObtainSlidingSerializer):
    """Sliding-token serializer whose token carries a ``kid`` header.

    The ``kid`` value is the authlib ``thumbprint()`` of
    ``SIMPLE_JWT['VERIFYING_KEY']``, which is imported as an RSA key.
    """

    token_class = SlidingokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):
        verifying_jwk = JsonWebKey.import_key(
            settings.SIMPLE_JWT['VERIFYING_KEY'],
            {'kty': 'RSA'},
        )
        sliding = cls.token_class.for_user(user)
        # Custom header claim identifying the verifying key.
        sliding.headers.update(kid=verifying_jwk.thumbprint())
        return sliding
class TokenObtainSlidingView(original_views.TokenObtainSlidingView):
    """Sliding-token obtain view using ``TokenObtainSlidingSerializerDifferentToken``.

    Derives from simplejwt's ``TokenObtainSlidingView``, not
    ``TokenObtainPairView``, so that everything except the overridden
    serializer (defaults, schema/docs metadata) matches the sliding-token
    endpoint.
    """

    serializer_class = TokenObtainSlidingSerializerDifferentToken
Has this been implemented or otherwise solved in the latest codebase? I am currently facing the exact same issue while trying to add a `kid` parameter to the header of the signed token, and it is strange that the docs don't mention this anywhere.
This is not implemented.
Would a new `EXTRA_JWT_HEADERS` setting be a solution?
I'm facing this problem too: I want to add `kid` to the headers.
At this point, wouldn't it be good to just always add the `kid` header by default?
It's part of the JWK standard (RFC 7517, section 4.5, "kid" parameter):
https://datatracker.ietf.org/doc/html/rfc7517#section-4.5