taskiq

How to cancel sending a task using middleware

Open · Bohdan-Ilchyshyn opened this issue 1 year ago · 1 comment

I created a singleton middleware. In `pre_send` it checks whether an identical task already exists and, if so, should cancel sending the new one. What is the correct way to do this: return `None` or raise an exception?

Middleware code



import inspect
import time
from hashlib import md5
from typing import Any, Coroutine, Union

from cashews import cache
from loguru import logger
from orjson import orjson
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult


class DuplicateTaskError(Exception):
    """Raised by ``SingletonMiddleware.pre_send`` to cancel sending a duplicate.

    ``pre_send`` must hand back a ``TaskiqMessage`` (see its signature), so it
    cannot cancel a send by returning ``None``. Raising propagates out of the
    send call (``kiq()``), and callers may catch this exception to detect that
    an identical singleton task is already queued or running.

    Attributes:
        lock_key: Cache key of the lock that is already held.
        existing_task_id: Value read from the cache under ``lock_key`` (the
            task ID stored by the lock holder). It may be empty if the lock
            was released between the failed acquire and the read.
    """

    def __init__(self, lock_key: str, existing_task_id: Any) -> None:
        # Pass every constructor argument to ``Exception`` so the error
        # round-trips through pickling unchanged.
        super().__init__(lock_key, existing_task_id)
        self.lock_key = lock_key
        self.existing_task_id = existing_task_id

    def __str__(self) -> str:
        return f"Attempted to queue a duplicate of task ID {self.existing_task_id}"


class SingletonMiddleware(TaskiqMiddleware):
    """Refuse to send a task while an identical one is still locked.

    Applies only to tasks carrying the ``singleton`` label. The label's
    presence is what counts; its value is not inspected.

    Before sending, a lock is set in the ``cashews`` cache under a key derived
    from the task name and its arguments, with the task ID as the value. If
    that lock is already held, ``pre_send`` raises :class:`DuplicateTaskError`.
    ``post_execute`` and ``on_error`` release the lock.

    Other labels read from the message:

    * ``unique_on`` -- a parameter name (or list of names) that identifies a
      duplicate. Without it, all positional and keyword arguments are used.
    * ``lock_expire`` -- lock TTL in seconds. If absent, the ``timeout`` label
      plus a five-minute margin is used, otherwise ``default_lock_expire``.
    """

    SINGLETON_LABEL = "singleton"
    UNIQUE_ON_LABEL = "unique_on"
    LOCK_EXPIRE_LABEL = "lock_expire"
    KEY_PREFIX = "TKQ_SINGLETON_LOCK_"

    def __init__(
        self,
        default_lock_expire: int = 60,
    ) -> None:
        """Create the middleware.

        :param default_lock_expire: lock TTL in seconds, used when the task has
            neither a ``lock_expire`` nor a ``timeout`` label.
        """
        super().__init__()
        self.default_lock_expire = default_lock_expire

    def pre_send(
        self,
        message: "TaskiqMessage",
    ) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
        """Pass non-singleton messages through unchanged; lock singleton ones.

        For singleton tasks this returns a coroutine that resolves to
        ``message`` once the lock is taken, or raises
        :class:`DuplicateTaskError` if it is already held.
        """
        if self.is_singleton_task(message):
            return self.lock_and_run(message)
        return message

    async def post_execute(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock after the task has run."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock when the task fails."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    def is_singleton_task(self, message: "TaskiqMessage") -> bool:
        """Return True if the message carries the ``singleton`` label."""
        return self.SINGLETON_LABEL in message.labels

    @staticmethod
    async def unlock(lock_key: str, task_id: str) -> bool:
        """Release ``lock_key`` via ``cache.unlock`` with ``task_id`` as its value."""
        return await cache.unlock(lock_key, task_id)

    @staticmethod
    async def lock(lock_key: str, task_id: str, expire: int) -> bool:
        """Try to set ``lock_key`` to ``task_id`` for ``expire`` seconds."""
        return await cache.set_lock(key=lock_key, value=task_id, expire=expire)

    @staticmethod
    async def locked(lock_key: str) -> bool:
        """Return whether ``lock_key`` is currently locked."""
        return await cache.is_locked(key=lock_key)

    @staticmethod
    async def get_existing_task_id(lock_key: str) -> "Union[str, None]":
        """Return the value stored under ``lock_key`` (the holder's task ID)."""
        return await cache.get(key=lock_key)

    async def lock_and_run(self, message: "TaskiqMessage") -> "TaskiqMessage":
        """Acquire the singleton lock for ``message`` and return it.

        :raises DuplicateTaskError: if the lock is already held. Raising is the
            way to cancel the send: ``pre_send`` must yield a message, so
            returning ``None`` would pass ``None`` on as the message to send.
        """
        if await self.acquire_lock(message):
            return message

        lock_key = self.generate_lock(message)
        # ``get_existing_task_id`` is a coroutine function and must be awaited.
        # The ID is informational only: the holder may release the lock
        # between the failed acquire and this read.
        existing_task_id = await self.get_existing_task_id(lock_key)
        logger.warning(f"Attempted to queue a duplicate of task ID {existing_task_id}")
        raise DuplicateTaskError(lock_key, existing_task_id)

    async def get_lock_expire(self, message: "TaskiqMessage") -> int:
        """Return the lock TTL in seconds for ``message``.

        The ``lock_expire`` label is used if present. Otherwise the ``timeout``
        label plus 300 s is used, otherwise ``default_lock_expire``. Label
        values are cast with ``int()`` in both cases, so numeric strings work.
        """
        labels = message.labels
        if self.LOCK_EXPIRE_LABEL in labels:
            return int(labels[self.LOCK_EXPIRE_LABEL])
        if "timeout" in labels:
            # Hold the lock a bit longer than the task is allowed to run.
            return int(labels["timeout"]) + 5 * 60
        return self.default_lock_expire

    async def release_lock(self, message: "TaskiqMessage") -> bool:
        """Release this message's lock, using its task ID as the lock value."""
        lock_key = self.generate_lock(message)
        return await self.unlock(lock_key, message.task_id)

    async def acquire_lock(self, message: "TaskiqMessage") -> bool:
        """Try to take this message's lock; return whether it was acquired."""
        lock_key = self.generate_lock(message)
        lock_expire = await self.get_lock_expire(message)
        return await self.lock(lock_key, message.task_id, lock_expire)

    @staticmethod
    def generate_lock_key(task_name: str, task_args: list, task_kwargs: dict, key_prefix: str) -> str:
        """Return ``key_prefix`` + MD5 hex digest of the name and encoded arguments.

        Arguments are serialized with ``orjson`` using sorted keys, so kwargs
        order does not matter. ``str()`` of the resulting bytes is their
        ``b'...'`` repr; that form is kept so existing lock keys stay valid.
        MD5 only shortens the key here; it is not used for security.
        """
        str_args = str(orjson.dumps(task_args, option=orjson.OPT_SORT_KEYS))
        str_kwargs = str(orjson.dumps(task_kwargs, option=orjson.OPT_SORT_KEYS))
        task_hash = md5((task_name + str_args + str_kwargs).encode()).hexdigest()
        return key_prefix + task_hash

    def generate_lock(self, message: "TaskiqMessage") -> str:
        """Return the cache key of the singleton lock for ``message``.

        With a ``unique_on`` label, only the named parameters are hashed. They
        are hashed as keyword arguments, after binding the call to the task's
        signature. Otherwise all of ``message.args`` and ``message.kwargs``
        are hashed.

        :raises ValueError: if ``unique_on`` is set but ``self.broker`` cannot
            find the task.
        :raises KeyError: if ``unique_on`` names a parameter the task lacks.
        """
        unique_on = message.labels.get(self.UNIQUE_ON_LABEL)
        if unique_on:
            if isinstance(unique_on, str):
                unique_on = [unique_on]

            task = self.broker.find_task(message.task_name)
            if task is None:
                raise ValueError(
                    f"Task {message.task_name!r} is not registered with the broker; "
                    f"cannot resolve its {self.UNIQUE_ON_LABEL!r} label."
                )

            bound = inspect.signature(task.original_func).bind(
                *message.args, **message.kwargs
            )
            # ``arguments`` holds only explicitly passed values. Fill in the
            # defaults so that omitted parameters named in ``unique_on``
            # still resolve.
            bound.apply_defaults()

            unique_args: list = []
            unique_kwargs = {key: bound.arguments[key] for key in unique_on}
        else:
            unique_args = message.args
            unique_kwargs = message.kwargs

        return self.generate_lock_key(
            task_name=str(message.task_name),
            task_args=unique_args,
            task_kwargs=unique_kwargs,
            key_prefix=self.KEY_PREFIX,
        )

Task example

@broker.task(singleton=True, unique_on=["id", "name"])
async def my_singleton_task(id: str, name: str) -> None:
    """Example task for ``SingletonMiddleware``.

    The ``singleton`` label opts the task into the middleware. The
    ``unique_on`` label makes calls with the same ``id`` and ``name`` count
    as duplicates.
    """

— Bohdan-Ilchyshyn, May 20 '24 09:05

@Bohdan-Ilchyshyn that is a really helpful example, thanks for sharing. I'm curious whether you ended up using it. I'm looking for something similar, but I also want to check whether the result backend already holds output for that idempotent task ID and, if so, cancel task execution. My understanding is that `wait_result()` will get the result immediately from the result backend even if the task is re-submitted, so I want to skip the resubmission step.

— dobrych, Oct 30 '25 21:10