`TaskGroup(wait=all)`, weird behaviour and API of `completed`
(re git tag 0.22.1)
Using TaskGroups with the default wait=all, it looks like taskgroup.completed is a reference to the first task that completes, even if it completes gracefully and does not stop the task group. If a later task then raises an exception, there does not seem to be a way to get a reference to it; taskgroup.exception does not get set.
So my question is: how do I get a reference to the first exception found by the join?
(One could of course iterate taskgroup.exceptions but there might be multiple non-cancellation exceptions there in arbitrary order.)
Example1:
import asyncio
from aiorpcx import TaskGroup

async def main():
    async def foo():
        await asyncio.sleep(0.5)
        raise Exception("fff")
    async with TaskGroup() as group:
        await group.spawn(asyncio.sleep(0))
        await group.spawn(asyncio.sleep(10))
        await group.spawn(foo())
    print(f"{group.exception=}")
    print(f"{group.completed=}")

asyncio.run(main())
$ python ex1.py
group.exception=None
group.completed=<Task finished name='Task-2' coro=<sleep() done, defined at /usr/lib/python3.8/asyncio/tasks.py:641> result=None>
In particular, ideally something like this would work and the exception would get propagated:
Example2:
import asyncio
from aiorpcx import TaskGroup

class OldTaskGroup(TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        await super().join()
        if self.completed is not None:
            try:
                self.result  # raise exception, if any
            except asyncio.CancelledError:
                pass

async def main():
    async def f():
        raise Exception("fff")
    async def g():
        await group.spawn(f())
    async with OldTaskGroup() as group:
        #await group.spawn(f())  # note: this would get propagated
        await group.spawn(g())  # this does NOT get propagated
    print(f"{group.exception=}")
    print(f"{group.completed=}")

asyncio.run(main())
$ python ex2.py
group.exception=None
group.completed=<Task finished name='Task-2' coro=<main.<locals>.g() done, defined at /media/sf_shared_folder/ex2.py:21> result=None>
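To make the question concrete, here is a minimal standard-library sketch of what "a reference to the first task that failed, in completion order" could look like. It does not use aiorpcx at all; `first_failed_task` and `fail_after` are hypothetical helper names made up for this illustration, and done-callbacks are just one possible way to record completion order.

```python
import asyncio

async def first_failed_task(*coros):
    """Run all coroutines to completion (like wait=all) and return the first
    task that raised a non-cancellation exception, or None."""
    tasks = [asyncio.ensure_future(c) for c in coros]
    failed = []  # tasks that raised, in the order they completed

    def record(task):
        # Done-callbacks are scheduled as each task finishes.
        if not task.cancelled() and task.exception() is not None:
            failed.append(task)

    for task in tasks:
        task.add_done_callback(record)
    await asyncio.wait(tasks)  # default: return when ALL tasks are done
    return failed[0] if failed else None

async def fail_after(delay, message):
    await asyncio.sleep(delay)
    raise RuntimeError(message)

async def demo():
    return await first_failed_task(
        asyncio.sleep(0),           # completes first, gracefully
        fail_after(0.05, "later"),  # fails second
        fail_after(0.01, "first"),  # fails first
    )

first_failed = asyncio.run(demo())
print(repr(first_failed.exception()))
```

The task that completes first and gracefully is ignored here, which is the difference from what `completed` reports in Example1.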
I believe the intended use is something like this: https://github.com/kyuupichan/electrumx/blob/master/electrumx/server/session.py#L615-L617 At least that is what I use and what I gleaned from https://curio.readthedocs.io/en/latest/
async with TaskGroup() as group:
    await group.spawn(self.peer_mgr.discover_peers())
    await group.spawn(self._clear_stale_sessions())
    await group.spawn(self._handle_chain_reorgs())
    await group.spawn(self._recalc_concurrency())
    await group.spawn(self._log_sessions())
    await group.spawn(self._manage_servers())
    async for task in group:
        if not task.cancelled():
            task.result()
Hmm... that means the join sort of becomes a no-op, as by the time it runs next_done will have waited for and popped all tasks from _done, right?
So I guess I could achieve what example2 wants with something like this:
class OldTaskGroup(TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        if self._wait is all:
            try:
                async for task in self:
                    if not task.cancelled():
                        task.result()
            finally:
                await super().join()
        else:
            await super().join()
            if self.completed:
                self.completed.result()
I think .completed and .exception have somewhat unintuitive and useless behaviour for the wait=all case - though the docstrings match their behaviour. I guess it's ok if you consider they are there for the other wait values.
Feel free to close if the above snippet looks reasonable.
I've tweaked it a bit more, atm using:
class OldTaskGroup(aiorpcx.TaskGroup):
    """Automatically raises exceptions on join; as in aiorpcx prior to version 0.20"""
    async def join(self):
        if self._wait is all:
            exc = False
            try:
                async for task in self:
                    if not task.cancelled():
                        task.result()
            except BaseException:  # including asyncio.CancelledError
                exc = True
                raise
            finally:
                if exc:
                    await self.cancel_remaining()
                await super().join()
        else:
            await super().join()
            if self.completed:
                self.completed.result()
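For anyone who wants to try the same control flow without aiorpcx installed, here is a rough standard-library analogue of the wait=all branch above: consume tasks as they finish, re-raise the first exception seen, and cancel whatever is still running before returning. `join_raising` is a hypothetical name for this sketch and is not part of aiorpcx; note also that tasks finishing in the same `asyncio.wait` round are handled in arbitrary order.

```python
import asyncio

async def join_raising(tasks):
    """Wait for all tasks; on the first exception, cancel the rest and re-raise."""
    pending = set(tasks)
    try:
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                if not task.cancelled():
                    task.result()  # re-raises the task's exception, if any
    finally:
        # Reached with a non-empty set only on error or outer cancellation.
        for task in pending:
            task.cancel()
        if pending:
            # Let the cancelled tasks finish unwinding before returning.
            await asyncio.gather(*pending, return_exceptions=True)

async def demo():
    async def fail():
        await asyncio.sleep(0.01)
        raise RuntimeError("fff")

    slow = asyncio.create_task(asyncio.sleep(10))
    tasks = [asyncio.create_task(asyncio.sleep(0)), slow,
             asyncio.create_task(fail())]
    try:
        await join_raising(tasks)
    except RuntimeError as e:
        return str(e), slow.cancelled()
    return None, slow.cancelled()

outcome = asyncio.run(demo())
print(outcome)
```

With this shape the long `sleep(10)` does not hold up the join once another task has failed, which mirrors what the `cancel_remaining()` call is for in the snippet above.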