aiorpcX

`TaskGroup(wait=all)`, weird behaviour and API of `completed`

Open SomberNight opened this issue 4 years ago • 3 comments

(Re git tag 0.22.1.) When using TaskGroups with the default wait=all, it looks like taskgroup.completed is a reference to the first task that completes, even if that task completes gracefully and does not stop the task group. If a later task then raises an exception, there does not seem to be a way to get a reference to it: taskgroup.exception does not get set.

So my question is: how do I get a reference to the first exception found by the join? (One could of course iterate taskgroup.exceptions but there might be multiple non-cancellation exceptions there in arbitrary order.)

Example 1:

import asyncio
from aiorpcx import TaskGroup

async def main():
    async def foo():
        await asyncio.sleep(0.5)
        raise Exception("fff")

    async with TaskGroup() as group:
        await group.spawn(asyncio.sleep(0))
        await group.spawn(asyncio.sleep(10))
        await group.spawn(foo())
    print(f"{group.exception=}")
    print(f"{group.completed=}")

asyncio.run(main())
$ python ex1.py 
group.exception=None
group.completed=<Task finished name='Task-2' coro=<sleep() done, defined at /usr/lib/python3.8/asyncio/tasks.py:641> result=None>
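For comparison, here is a minimal sketch in plain asyncio (no aiorpcX involved) of telling "the first task to finish" apart from "the first task to fail" for the same three tasks as Example 1. It relies on `asyncio.wait(..., return_when=asyncio.FIRST_EXCEPTION)`, which returns as soon as any task raises. The helper name `first_exception` is hypothetical, and if several tasks have already failed by the time the wait returns, it simply picks the earliest-spawned one, so "first" is only approximate.

```python
import asyncio

async def first_exception(coros):
    """Hypothetical helper: run `coros` concurrently and return the exception
    of the first task to fail, or None if every task succeeds."""
    tasks = [asyncio.create_task(c) for c in coros]
    try:
        # Returns as soon as any task raises, or once all tasks have finished.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in tasks:  # scan in spawn order for determinism
            if task in done and not task.cancelled() and task.exception() is not None:
                return task.exception()
        return None
    finally:
        # Cancel anything still running (e.g. the 10-second sleep) and wait for it.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

async def foo():
    await asyncio.sleep(0.05)
    raise Exception("fff")

async def main():
    # Same shape as Example 1: one task finishes at once, one is slow, one fails.
    return await first_exception([asyncio.sleep(0), asyncio.sleep(10), foo()])

print(repr(asyncio.run(main())))  # Exception('fff')
```

Here the `sleep(0)` task is the first to complete, but the helper skips it because it finished without an exception.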

In particular, ideally something like this would work, with exceptions propagated:

Example 2:

import asyncio
from aiorpcx import TaskGroup

class OldTaskGroup(TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        await super().join()
        if self.completed is not None:
            try:
                self.result  # raise exception, if any
            except asyncio.CancelledError:
                pass

async def main():
    async def f():
        raise Exception("fff")

    async def g():
        await group.spawn(f())

    async with OldTaskGroup() as group:
        #await group.spawn(f())  # note: this would get propagated
        await group.spawn(g())   # this does NOT get propagated
    print(f"{group.exception=}")
    print(f"{group.completed=}")

asyncio.run(main())
$ python ex2.py 
group.exception=None
group.completed=<Task finished name='Task-2' coro=<main.<locals>.g() done, defined at /media/sf_shared_folder/ex2.py:21> result=None>

SomberNight, Jan 28 '22 13:01

I believe the intended use is something like this: https://github.com/kyuupichan/electrumx/blob/master/electrumx/server/session.py#L615-L617
At least, that is what I use, and what I gleaned from https://curio.readthedocs.io/en/latest/

kyuupichan, Jan 31 '22 13:01

> So my question is: how do I get a reference to the first exception found by the join? (One could of course iterate taskgroup.exceptions but there might be multiple non-cancellation exceptions there in arbitrary order.)

> I believe the intended use is something like this: https://github.com/kyuupichan/electrumx/blob/master/electrumx/server/session.py#L615-L617

async with TaskGroup() as group:
    await group.spawn(self.peer_mgr.discover_peers())
    await group.spawn(self._clear_stale_sessions())
    await group.spawn(self._handle_chain_reorgs())
    await group.spawn(self._recalc_concurrency())
    await group.spawn(self._log_sessions())
    await group.spawn(self._manage_servers())

    async for task in group:
        if not task.cancelled():
            task.result()
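The `async for task in group` loop above consumes tasks in the order they finish and calls `task.result()` on each, which re-raises a stored exception. A rough stdlib-only analogue of that iteration can be built from `asyncio.wait` with `FIRST_COMPLETED`. This is a sketch that assumes nothing about aiorpcX internals; `in_completion_order` is a hypothetical helper, not aiorpcX API.

```python
import asyncio

async def in_completion_order(tasks):
    """Hypothetical helper: yield each task once it has finished, earliest first.

    A stdlib sketch of the idea behind `async for task in group`,
    not aiorpcX's actual implementation.
    """
    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:  # tasks finishing in the same step come out in arbitrary order
            yield task

async def demo():
    # asyncio.sleep(delay, result=...) returns `result` after sleeping.
    tasks = [asyncio.create_task(asyncio.sleep(d, result=d)) for d in (0.05, 0.01, 0.03)]
    order = []
    async for task in in_completion_order(tasks):
        if not task.cancelled():
            order.append(task.result())  # result() would re-raise a task's exception
    return order

print(asyncio.run(demo()))
```

Once such a loop has drained every task, there is nothing left for a later wait-for-all step to find.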

Hmm... that means join() effectively becomes a no-op: by the time it runs, next_done() will already have waited for and popped all tasks from _done, right?

So I guess I could achieve what Example 2 wants with something like this:

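from aiorpcx import TaskGroup

class OldTaskGroup(TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        if self._wait is all:
            try:
                # Drain tasks in completion order; the first task that raised
                # (and was not cancelled) re-raises its exception here.
                async for task in self:
                    if not task.cancelled():
                        task.result()
            finally:
                await super().join()
        else:
            await super().join()
            # Re-raise the exception (if any) of the task recorded in .completed.
            if self.completed:
                self.completed.result()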

I think .completed and .exception have somewhat unintuitive, and not very useful, behaviour in the wait=all case, though their docstrings do match that behaviour. I guess that is fine if you consider them to exist for the other wait values. Feel free to close if the above snippet looks reasonable.

SomberNight, Feb 07 '22 11:02

I've tweaked it a bit more; I'm currently using:

import aiorpcx

class OldTaskGroup(aiorpcx.TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        if self._wait is all:
            exc = False
            try:
                # Re-raise the first non-cancellation exception, in completion order.
                async for task in self:
                    if not task.cancelled():
                        task.result()
            except BaseException:  # including asyncio.CancelledError
                exc = True
                raise
            finally:
                # On error, cancel the remaining tasks so super().join() does not block on them.
                if exc:
                    await self.cancel_remaining()
                await super().join()
        else:
            await super().join()
            if self.completed:
                self.completed.result()
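To make the intended semantics concrete without depending on aiorpcX, the following sketch models a tiny task group in plain asyncio. `MiniGroup` and its methods are hypothetical and do not mirror aiorpcX's implementation (for one, its `spawn` is synchronous). It only illustrates the behaviour the `OldTaskGroup` above aims for: `join()` drains tasks in completion order, including tasks spawned by other tasks as `g()` does in Example 2, re-raises the first non-cancellation exception it sees, and cancels whatever is still pending.

```python
import asyncio

class MiniGroup:
    """Hypothetical stdlib-only task group (NOT aiorpcX's TaskGroup)."""

    def __init__(self):
        self._pending = set()

    def spawn(self, coro):
        # Unlike aiorpcX's `await group.spawn(...)`, this spawn is synchronous.
        task = asyncio.create_task(coro)
        self._pending.add(task)
        return task

    async def join(self):
        try:
            while self._pending:
                # asyncio.wait copies the set, so tasks spawned meanwhile are
                # picked up on the next iteration of the loop.
                done, _ = await asyncio.wait(
                    self._pending, return_when=asyncio.FIRST_COMPLETED)
                self._pending -= done
                for task in done:
                    if not task.cancelled():
                        task.result()  # re-raises the task's exception, if any
        except BaseException:  # including asyncio.CancelledError
            await self.cancel_remaining()
            raise

    async def cancel_remaining(self):
        for task in self._pending:
            task.cancel()
        # Wait for the cancellations to finish; ignore their outcomes.
        await asyncio.gather(*self._pending, return_exceptions=True)
        self._pending.clear()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if exc_type is None:
            await self.join()
        else:
            await self.cancel_remaining()
        return False  # never swallow the exception

async def example2():
    async def f():
        raise Exception("fff")

    async def g():
        group.spawn(f())  # nested spawn, as in Example 2

    async with MiniGroup() as group:
        group.spawn(g())

try:
    asyncio.run(example2())
except Exception as e:
    print(f"propagated: {e!r}")  # propagated: Exception('fff')
```

Tasks that finish in the same event-loop step are examined in arbitrary order, which is the same ambiguity noted above for taskgroup.exceptions.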

SomberNight, Feb 10 '22 13:02