taskiq
How to cancel sending a task using middleware
I created a singleton middleware. It checks whether an identical task already exists and, if so, should cancel sending it in the `pre_send` hook. What is the correct way to do this: return `None` or raise an exception?
Middleware code
import inspect
import time
from hashlib import md5
from typing import Any, Coroutine, Union
from cashews import cache
from loguru import logger
from orjson import orjson
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
class DuplicateTaskError(Exception):
    """Raised from ``SingletonMiddleware.pre_send`` to cancel sending a duplicate.

    Returning ``None`` from ``pre_send`` does not cancel anything: taskiq's
    kicker feeds the value returned by ``pre_send`` into the next middleware
    and then into the broker's formatter, which cannot handle ``None``.
    Raising aborts ``kiq()`` and lets the caller catch this exception and
    decide what to do with the duplicate.

    :ivar lock_key: cache key of the lock that is already held.
    :ivar existing_task_id: value read from the lock (the ID of the task that
        holds it); possibly ``None`` if the lock was released or expired
        between the failed acquire and the read.
    """

    def __init__(self, lock_key: str, existing_task_id: "str | None") -> None:
        super().__init__(
            f"Attempted to queue a duplicate of task ID {existing_task_id}",
        )
        self.lock_key = lock_key
        self.existing_task_id = existing_task_id


class SingletonMiddleware(TaskiqMiddleware):
    """Refuse to send a task while an identical one still holds a lock.

    Tasks opt in with a truthy ``singleton`` label. In ``pre_send`` a lock
    keyed on the task name and its arguments (or only the arguments named by
    the ``unique_on`` label) is taken in the ``cashews`` cache, with the task
    ID as the lock value. If the lock is already held, the send is cancelled
    by raising :class:`DuplicateTaskError`. The lock is released in
    ``post_execute`` / ``on_error``. If that never happens, it expires after
    the ``lock_expire`` label, the ``timeout`` label plus a grace period, or
    ``default_lock_expire``, checked in that order.
    """

    SINGLETON_LABEL = "singleton"
    UNIQUE_ON_LABEL = "unique_on"
    LOCK_EXPIRE_LABEL = "lock_expire"
    KEY_PREFIX = "TKQ_SINGLETON_LOCK_"
    # Added to a task's ``timeout`` label, presumably so the lock outlives a
    # run that hits its timeout.
    TIMEOUT_GRACE_SECONDS = 5 * 60
    # String forms of a ``singleton`` label that mean "off", in case the label
    # value reaches the middleware as a string rather than a bool.
    _FALSE_LABEL_VALUES = frozenset({"", "0", "false", "no", "off", "none"})

    def __init__(
        self,
        default_lock_expire: int = 60,
    ) -> None:
        """Create the middleware.

        :param default_lock_expire: lock TTL passed to ``cache.set_lock`` when
            the task sets neither a ``lock_expire`` nor a ``timeout`` label.
        """
        super().__init__()
        self.default_lock_expire = default_lock_expire

    def pre_send(
        self,
        message: "TaskiqMessage",
    ) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
        """Pass non-singleton messages through; lock singleton ones first.

        For singleton tasks this returns the ``lock_and_run`` coroutine, which
        resolves to *message* or raises :class:`DuplicateTaskError`.
        """
        if not self.is_singleton_task(message):
            return message
        return self.lock_and_run(message)

    async def post_execute(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> None:
        """Release the singleton lock once the task has run."""
        if self.is_singleton_task(message):
            await self.release_lock(message)

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> None:
        """Release the singleton lock when the task failed.

        If ``post_execute`` also runs for the same message, ``unlock`` is
        called twice with the same task ID. This is presumably a no-op the
        second time; verify against the cashews ``unlock`` semantics.
        """
        if self.is_singleton_task(message):
            await self.release_lock(message)

    def is_singleton_task(self, message: "TaskiqMessage") -> bool:
        """Return True if the message carries a truthy ``singleton`` label.

        An explicit ``singleton=False`` (or a string such as ``"false"``)
        disables locking; mere presence of the label is not enough.
        """
        value = message.labels.get(self.SINGLETON_LABEL)
        if isinstance(value, str):
            return value.strip().lower() not in self._FALSE_LABEL_VALUES
        return bool(value)

    @staticmethod
    async def unlock(lock_key: str, task_id: str) -> bool:
        """Release *lock_key* via ``cache.unlock``, passing the owner's task ID."""
        return await cache.unlock(lock_key, task_id)

    @staticmethod
    async def lock(lock_key: str, task_id: str, expire: int) -> bool:
        """Try to take *lock_key* with *task_id* as its value; True on success."""
        return await cache.set_lock(key=lock_key, value=task_id, expire=expire)

    @staticmethod
    async def locked(lock_key: str) -> bool:
        """Return whether *lock_key* is currently held."""
        return await cache.is_locked(key=lock_key)

    @staticmethod
    async def get_existing_task_id(lock_key: str) -> "str | None":
        """Return the task ID stored under *lock_key*, if any."""
        return await cache.get(key=lock_key)

    async def lock_and_run(self, message: TaskiqMessage) -> TaskiqMessage:
        """Acquire the singleton lock and return *message* so it gets sent.

        :raises DuplicateTaskError: if an identical task already holds the
            lock. Raising (rather than returning ``None``) is what actually
            cancels the send.
        """
        if await self.acquire_lock(message):
            return message

        lock_key = self.generate_lock(message)
        existing_task_id = await self.get_existing_task_id(lock_key)
        logger.warning(f"Attempted to queue a duplicate of task ID {existing_task_id}")
        raise DuplicateTaskError(lock_key, existing_task_id)

    async def get_lock_expire(self, message: "TaskiqMessage") -> int:
        """Return the lock TTL for *message*.

        Precedence: the ``lock_expire`` label, then the ``timeout`` label plus
        ``TIMEOUT_GRACE_SECONDS``, then ``default_lock_expire``. Label values
        are coerced with ``int()`` so string-valued labels work too.
        """
        labels = message.labels
        if self.LOCK_EXPIRE_LABEL in labels:
            return int(labels[self.LOCK_EXPIRE_LABEL])
        if "timeout" in labels:
            return int(labels["timeout"]) + self.TIMEOUT_GRACE_SECONDS
        return self.default_lock_expire

    async def release_lock(self, message: "TaskiqMessage") -> bool:
        """Release the lock held for *message*; returns ``unlock``'s result."""
        lock_key = self.generate_lock(message)
        return await self.unlock(lock_key, message.task_id)

    async def acquire_lock(self, message: "TaskiqMessage") -> bool:
        """Try to take the lock for *message*; True if it was acquired."""
        lock_key = self.generate_lock(message)
        lock_expire = await self.get_lock_expire(message)
        return await self.lock(lock_key, message.task_id, lock_expire)

    @staticmethod
    def generate_lock_key(task_name: str, task_args: list, task_kwargs: dict, key_prefix: str) -> str:
        """Build a deterministic lock key from a task name and its arguments.

        Arguments are serialized with orjson (kwargs with sorted keys) and
        hashed with md5.

        :raises TypeError: (``orjson.JSONEncodeError``) if an argument is not
            JSON-serializable by orjson.
        """
        # ``str(bytes)`` yields "b'...'". This is kept unchanged so that lock
        # keys stay identical to those produced by earlier deployments.
        str_args = str(orjson.dumps(task_args, option=orjson.OPT_SORT_KEYS))
        str_kwargs = str(orjson.dumps(task_kwargs, option=orjson.OPT_SORT_KEYS))
        # md5 is only a fingerprint here, not a security primitive. The flag
        # keeps it usable on FIPS-restricted Python builds.
        task_hash = md5(
            (task_name + str_args + str_kwargs).encode(),
            usedforsecurity=False,
        ).hexdigest()
        return key_prefix + task_hash

    def generate_lock(self, message: "TaskiqMessage") -> str:
        """Return the lock key for *message*.

        Without a ``unique_on`` label, all positional and keyword arguments
        form the key. With ``unique_on`` (a parameter name or a list of them),
        only those parameters are used. They are resolved against the task
        function's signature, so they match whether passed positionally, by
        keyword, or left at their defaults.

        :raises LookupError: if ``unique_on`` is set but the task is not
            registered with the broker.
        :raises KeyError: if ``unique_on`` names a parameter the task does not
            have.
        """
        unique_on = message.labels.get(self.UNIQUE_ON_LABEL)
        if unique_on:
            if isinstance(unique_on, str):
                unique_on = [unique_on]
            task = self.broker.find_task(message.task_name)
            if task is None:
                raise LookupError(
                    f"Task {message.task_name!r} is not registered with the broker; "
                    f"cannot resolve {self.UNIQUE_ON_LABEL!r} arguments",
                )
            bound = inspect.signature(task.original_func).bind(
                *message.args,
                **message.kwargs,
            )
            # Without this, a unique_on parameter that kept its default value
            # is absent from ``bound.arguments`` and raises KeyError.
            bound.apply_defaults()
            unique_args = []
            unique_kwargs = {key: bound.arguments[key] for key in unique_on}
        else:
            unique_args = message.args
            unique_kwargs = message.kwargs

        return self.generate_lock_key(
            task_name=str(message.task_name),
            task_args=unique_args,
            task_kwargs=unique_kwargs,
            key_prefix=self.KEY_PREFIX,
        )
Task example
@broker.task(singleton=True, unique_on=['id', 'name'])
async def my_singleton_task(id: str, name: str) -> None:
    """Example task opted into SingletonMiddleware.

    The ``singleton`` label enables the middleware's lock. The ``unique_on``
    label limits the lock key to the ``id`` and ``name`` arguments.
    """
@Bohdan-Ilchyshyn that is a really helpful example, thanks for sharing! I'm curious whether you ended up using it. I'm looking for something similar, but I also want to check whether the result backend already contains output for that idempotent task ID and, if so, cancel the task execution. My understanding is that `wait_result()` will get the result immediately from the result backend even if the task is re-submitted, so I want to skip the resubmission step.