fastapi-jwt

Need the ability to handle revoked tokens

Open · Seluj78 opened this issue 1 year ago · 1 comment

Hi,

This is my first time using FastAPI and your library. I come from Flask, where I used flask-jwt-extended. It has many nice-to-haves that seem to be missing here (like @jwt_required, get_current_user, the current_user proxy, get_jti, and many more). More specifically to this issue, I would like to be able to check whether an access token is still valid or has been revoked.

For reference:

https://flask-jwt-extended.readthedocs.io/en/stable/blocklist_and_token_revoking.html

I think it would be a great addition!

Seluj78 · Oct 02 '24 09:10

Your examples documentation doesn't include any example of how to log out a user: https://k4black.github.io/fastapi-jwt/user_guide/examples/

With Flask, I would usually do it like this:

import logging
import math
from datetime import timedelta
from typing import Optional

from flask_jwt_extended import get_jti

from nexus import settings
from nexus.utils.redis import get_redis_client


# NOTE(review): the force-logout branch of do_logout_user only re-extends
# blocklist entries that already exist for the user, i.e. tokens that were
# revoked earlier. A token that is still active has no `is_revoked_jti:` key,
# so a force logout does NOT revoke it.
# TODO: to revoke every outstanding token, store a per-user marker (e.g.
# "revoked before <timestamp>") that the token-validation callback compares
# against each token's `iat` claim.

# Blocklist entries must outlive the tokens they revoke, so each token
# lifetime is padded by this factor when computing the Redis TTL.
_REVOCATION_TTL_FACTOR = 1.2


def _revocation_ttl(lifetime: "timedelta | float") -> int:
    """Return a Redis TTL, in whole seconds, for a blocklist entry.

    ``lifetime`` is a token lifetime given either as a ``datetime.timedelta``
    or as a number of seconds. It is padded by ``_REVOCATION_TTL_FACTOR`` and
    rounded up, so the entry never expires before the token it revokes. An
    ``int`` is returned because redis-py's ``ex`` argument accepts only an
    int or a timedelta, not a float.
    """
    seconds = lifetime.total_seconds() if isinstance(lifetime, timedelta) else lifetime
    return math.ceil(seconds * _REVOCATION_TTL_FACTOR)


def do_logout_user(
    user_id: int, access_token: Optional[str] = None, refresh_token: Optional[str] = None, force_logout: bool = False
) -> None:
    """Revoke a user's JWTs by recording their JTIs in the Redis blocklist.

    Blocklist keys have the form ``is_revoked_jti:{jti}|{user_id}`` and hold
    the value ``"true"``.

    Args:
        user_id: ID of the user being logged out; embedded in every key.
        access_token: Encoded access token to revoke (normal logout only).
        refresh_token: Encoded refresh token to revoke (normal logout only).
        force_logout: If false, revoke exactly the two given tokens. If true,
            re-set every existing blocklist key for this user with a
            refresh-token-length TTL (see the review note above this
            function about what this does not cover).

    Raises:
        ValueError: On a normal logout when either token is missing.
    """
    logging.info("Logging out user %s", user_id)
    # Validate before acquiring a Redis client so bad calls fail fast.
    if not force_logout and (access_token is None or refresh_token is None):  # pragma: no cover
        raise ValueError("Access token and refresh token are required for a normal logout")

    redis_client = get_redis_client()
    if not force_logout:
        access_jti = get_jti(access_token)
        refresh_jti = get_jti(refresh_token)
        # Each entry's TTL is derived from the lifetime of the token it revokes.
        # Refresh tokens typically live much longer than access tokens. Giving
        # the refresh entry the access lifetime would let a revoked refresh
        # token become usable again once its blocklist entry expired.
        redis_client.set(
            f"is_revoked_jti:{access_jti}|{user_id}",
            "true",
            _revocation_ttl(settings.JWT_ACCESS_TOKEN_EXPIRES),
        )
        redis_client.set(
            f"is_revoked_jti:{refresh_jti}|{user_id}",
            "true",
            _revocation_ttl(settings.JWT_REFRESH_TOKEN_EXPIRES),
        )
    else:
        # Batch all writes into one pipeline. scan_iter drives the SCAN
        # cursor (COUNT hint of 100 per call) so Redis is never blocked the
        # way a single KEYS call would block it.
        pipe = redis_client.pipeline()
        ttl = _revocation_ttl(settings.JWT_REFRESH_TOKEN_EXPIRES)
        for key in redis_client.scan_iter(match=f"is_revoked_jti:*|{user_id}", count=100):
            pipe.set(key, "true", ttl)
        pipe.execute()
    logging.info("User %s logged out", user_id)


import logging

from flask import Response
from flask import Blueprint
from flask import jsonify
from flask import request
from flask_utils import validate_params
from flask_jwt_extended import current_user
from flask_jwt_extended import decode_token
from flask_jwt_extended import jwt_required

from nexus.modules.auth import do_logout_user

# Blueprint that groups the logout endpoints of this module.
auth_logout_v1_bp = Blueprint(name="auth_logout_v1_bp", import_name=__name__)


@auth_logout_v1_bp.route("/logout", methods=["POST"])
@validate_params({"access_token": str, "refresh_token": str})
def logout() -> Response:
    """Revoke the access/refresh token pair supplied in the JSON body.

    ``validate_params`` checks that the body carries ``access_token`` and
    ``refresh_token`` strings. The user ID is read from the access token's
    ``sub["id"]`` claim. Both tokens are then passed to ``do_logout_user``
    for a normal (non-forced) logout.

    Returns:
        A 200 JSON response with a confirmation message.
    """
    body = request.get_json()
    access_token, refresh_token = body["access_token"], body["refresh_token"]

    # NOTE(review): decode_token presumably rejects an expired access token,
    # which would stop such a client from logging out here. Confirm whether
    # that is intended.
    claims = decode_token(access_token)
    user_id = int(claims["sub"]["id"])

    do_logout_user(user_id, access_token, refresh_token, force_logout=False)
    logging.info(f"Logout success for user {user_id}")

    response = jsonify({"message": "Successfully logged out."})
    response.status_code = 200
    return response


@auth_logout_v1_bp.route("/logout/force", methods=["POST"])
@jwt_required()
def force_logout() -> Response:
    """Force-log-out the currently authenticated user.

    Requires a valid JWT (enforced by ``@jwt_required()``). The call is
    delegated to ``do_logout_user`` with ``force_logout=True``, and no
    tokens are passed.

    Returns:
        A 200 JSON response with a confirmation message.
    """
    user_id = current_user.id
    do_logout_user(user_id, None, None, force_logout=True)
    logging.info(f"Force logout success for user {user_id}")

    response = jsonify({"message": "Successfully logged out."})
    response.status_code = 200
    return response

We definitely need something like this.

Seluj78 · Oct 02 '24 09:10