asgiref

sync_to_async for converting generators to async generators

Open • devxpy opened this issue 4 years ago • 6 comments

Moved from https://github.com/django/channels/issues/1411. Related to https://github.com/django/asgiref/issues/38.

Hello, and thanks for the amazing sync_to_async (and its converse, async_to_sync) functions, which make everyone's life easier as a Python developer :)

I'd like to propose a small update to sync_to_async (and probably database_sync_to_async too) so that it works with generators.

Current situation

Running this:

import asyncio
from time import sleep

from asgiref.sync import sync_to_async


@sync_to_async
def gen():
    for i in range(10):
        sleep(1)
        yield i


async def main():
    async for i in gen():
        print(i)


asyncio.run(main())

fails with:

  async for i in gen():
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
Traceback (most recent call last):
  File "./manage.py", line 21, in <module>
    main()
  File "./manage.py", line 17, in main
    execute_from_command_line(sys.argv)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django/core/management/__init__.py", line 401, in execute_from_command_line
    utility.execute()
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django/core/management/__init__.py", line 395, in execute
    self.fetch_command(subcommand).run_from_argv(self.argv)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/email_notifications.py", line 65, in run_from_argv
    super(EmailNotificationCommand, self).run_from_argv(argv)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django/core/management/base.py", line 328, in run_from_argv
    self.execute(*args, **cmd_options)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/email_notifications.py", line 77, in execute
    super(EmailNotificationCommand, self).execute(*args, **options)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django/core/management/base.py", line 369, in execute
    output = self.handle(*args, **options)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/utils.py", line 62, in inner
    ret = func(self, *args, **kwargs)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/commands/runscript.py", line 233, in handle
    modules = find_modules_for_script(script)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/commands/runscript.py", line 216, in find_modules_for_script
    mod = my_import(parent, mod_name)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/site-packages/django_extensions/management/commands/runscript.py", line 169, in my_import
    importlib.import_module(parent_package)
  File "/Users/dev/.virtualenvs/server-99338def/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 783, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/Users/dev/Projects/dara/server/test.py", line 19, in <module>
    asyncio.run(main())
  File "/Users/dev/.pyenv/versions/3.8.1/lib/python3.8/asyncio/runners.py", line 43, in run
    return loop.run_until_complete(main)
  File "/Users/dev/.pyenv/versions/3.8.1/lib/python3.8/asyncio/base_events.py", line 612, in run_until_complete
    return future.result()
  File "/Users/dev/Projects/dara/server/test.py", line 15, in main
    async for i in gen():
TypeError: 'async for' requires an object with __aiter__ method, got coroutine
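The TypeError follows from the shape of the decorator: sync_to_async turns a function into a coroutine function, so `gen()` returns a coroutine, and awaiting that coroutine only runs `gen()` in a thread, which builds a plain generator object without executing its body. The sketch below reproduces this with a hypothetical stand-in, `naive_sync_to_async`, built on `loop.run_in_executor` rather than asgiref, so it illustrates the shape of the problem, not asgiref's internals.

```python
import asyncio
import inspect
from functools import wraps


def naive_sync_to_async(fn):
    """Hypothetical stand-in for sync_to_async: run fn in the default thread pool."""

    @wraps(fn)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: fn(*args, **kwargs))

    return wrapper


@naive_sync_to_async
def gen():
    for i in range(3):
        yield i


async def inspect_result():
    call = gen()  # a coroutine object, not an async iterator
    is_coro = inspect.iscoroutine(call)
    has_aiter = hasattr(call, "__aiter__")
    result = await call  # runs gen() in a thread, which only creates the generator
    return is_coro, has_aiter, inspect.isgenerator(result)


print(asyncio.run(inspect_result()))  # (True, False, True)
```

So `async for` receives a coroutine (no `__aiter__`), and even awaiting it first would only hand back a sync generator that would then be iterated on the event loop thread.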

Proposed solution

The following wrapper seems to make the example above work:

import inspect
from functools import wraps

from asgiref.sync import sync_to_async as _sync_to_async


def sync_to_async(sync_fn):
    is_gen = inspect.isgeneratorfunction(sync_fn)
    async_fn = _sync_to_async(sync_fn)

    if is_gen:

        @wraps(sync_fn)
        async def wrapper(*args, **kwargs):
            sync_iterable = await async_fn(*args, **kwargs)
            sync_iterator = await iter_async(sync_iterable)

            while True:
                try:
                    yield await next_async(sync_iterator)
                except StopAsyncIteration:
                    return

    else:

        @wraps(sync_fn)
        async def wrapper(*args, **kwargs):
            return await async_fn(*args, **kwargs)

    return wrapper


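# Helpers used by the generator branch above. They are looked up at call time,
# so defining them after sync_to_async is fine.
iter_async = sync_to_async(iter)


@sync_to_async
def next_async(it):
    try:
        return next(it)
    except StopIteration:
        # asyncio futures reject StopIteration, so signal exhaustion
        # with StopAsyncIteration instead.
        raise StopAsyncIteration from None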

This also adds functools.wraps(), which is always nice to have.
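The same idea can be sketched without asgiref, using `loop.run_in_executor` as a swapped-in stand-in for sync_to_async. Two details differ from the proposal and are assumptions of this sketch: exhaustion is signalled by `next(it, default)` returning a private sentinel, so no exception has to cross the thread boundary, and the sync generator is closed in a `finally` block, so closing the async generator (for example with `aclose()` after consuming only part of it) also closes the sync one. The names `gen_to_async` and `_DONE` are hypothetical.

```python
import asyncio
import inspect
from functools import wraps
from time import sleep

_DONE = object()  # sentinel returned by next() once the sync iterator is exhausted


def gen_to_async(sync_gen_fn):
    """Hypothetical decorator: expose a sync generator function as an async generator."""
    if not inspect.isgeneratorfunction(sync_gen_fn):
        raise TypeError(f"{sync_gen_fn!r} is not a generator function")

    @wraps(sync_gen_fn)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # Calling a generator function only builds the generator; its body runs in next().
        sync_gen = sync_gen_fn(*args, **kwargs)
        try:
            while True:
                # Each step of the sync generator runs in the default thread pool.
                item = await loop.run_in_executor(None, next, sync_gen, _DONE)
                if item is _DONE:
                    return
                yield item
        finally:
            # Runs on exhaustion, on error, and on aclose(); close() lets the
            # sync generator's own finally/with blocks execute.
            await loop.run_in_executor(None, sync_gen.close)

    return wrapper


@gen_to_async
def numbers(n):
    for i in range(n):
        sleep(0.01)  # blocking call, now executed off the event loop
        yield i


async def main():
    return [i async for i in numbers(5)]


print(asyncio.run(main()))  # [0, 1, 2, 3, 4]
```

Rejecting non-generator functions at decoration time keeps this helper narrow; a combined decorator would branch on `inspect.isgeneratorfunction` instead of raising.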

Sorry if this is too hacky or unsuitable for this repo, but I think it would be a great addition to an already great function.

devxpy • Feb 18 '20

We could also add a new function, sync_to_async_iterable, for users who want to convert existing sync iterables (QuerySets!) into async iterables:

def sync_to_async(sync_fn):
    is_gen = inspect.isgeneratorfunction(sync_fn)
    async_fn = _sync_to_async(sync_fn)

    if is_gen:

        @wraps(sync_fn)
        async def wrapper(*args, **kwargs):
            sync_iterable = await async_fn(*args, **kwargs)
            async_iterable = sync_to_async_iterable(sync_iterable)
            async for item in async_iterable:
                yield item

    else:

        @wraps(sync_fn)
        async def wrapper(*args, **kwargs):
            return await async_fn(*args, **kwargs)

    return wrapper


# Uses the iter_async / next_async helpers defined in the previous snippet.
async def sync_to_async_iterable(sync_iterable):
    sync_iterator = await iter_async(sync_iterable)
    while True:
        try:
            yield await next_async(sync_iterator)
        except StopAsyncIteration:
            return

This works!

async def main():
    async for item in sync_to_async_iterable(MyModel.objects.all()):
        print(item)
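The TypeError in the original report asks for an object with `__aiter__`. An async generator, as above, is one way to provide that; an explicit class implementing `__aiter__` and `__anext__` is another. The hypothetical `SyncIterableToAsync` class below wraps any sync iterable (a list stands in for `MyModel.objects.all()`, so no Django setup is needed) and again uses `loop.run_in_executor` in place of asgiref.

```python
import asyncio
from typing import Any, Iterable

_DONE = object()  # sentinel: avoids passing StopIteration through an asyncio Future


class SyncIterableToAsync:
    """Hypothetical adapter: async-iterate a sync iterable, one next() per thread-pool call."""

    def __init__(self, iterable: Iterable[Any]) -> None:
        self._iterable = iterable
        self._iterator = None

    def __aiter__(self) -> "SyncIterableToAsync":
        return self

    async def __anext__(self) -> Any:
        loop = asyncio.get_running_loop()
        if self._iterator is None:
            # iter() may itself block (e.g. a lazy query), so run it in the pool too.
            self._iterator = await loop.run_in_executor(None, iter, self._iterable)
        item = await loop.run_in_executor(None, next, self._iterator, _DONE)
        if item is _DONE:
            raise StopAsyncIteration
        return item


async def main() -> list:
    rows = ["row-1", "row-2", "row-3"]  # stand-in for MyModel.objects.all()
    return [row async for row in SyncIterableToAsync(rows)]


print(asyncio.run(main()))  # ['row-1', 'row-2', 'row-3']
```

A class like this has no suspended frame to clean up, which makes it simpler to reason about than an async generator, at the cost of a little more boilerplate.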

devxpy • Feb 18 '20

Another possibility is to extend sync_to_async to handle iterables too, but I am not sure how reliable the __iter__ and __getitem__ check is.

def sync_to_async(sync_fn):
    if hasattr(sync_fn, "__iter__") or hasattr(sync_fn, "__getitem__"):
        return sync_to_async_iterable(sync_fn)
    ....
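The doubt about the `__iter__`/`__getitem__` check seems justified: classes such as `list` or `dict` are callables one might reasonably pass to sync_to_async, yet the class objects themselves have `__iter__` and `__getitem__` attributes, so a pure `hasattr` test would route them to the iterable path. The sketch below shows this and one possible ordering (generator functions first, then any callable, then iterables); `naive_is_iterable` and `classify` are hypothetical helpers, not part of asgiref.

```python
import inspect


def naive_is_iterable(obj) -> bool:
    # The hasattr-based check from the snippet above.
    return hasattr(obj, "__iter__") or hasattr(obj, "__getitem__")


def classify(obj) -> str:
    """Hypothetical dispatch order: generator functions, then callables, then iterables."""
    if inspect.isgeneratorfunction(obj):
        return "generator function"
    if callable(obj):
        return "callable"
    if naive_is_iterable(obj):
        return "iterable"
    raise TypeError(f"cannot adapt {obj!r}")


def gen():
    yield 1


print(naive_is_iterable(list))  # True: the class object itself has __iter__
print(classify(list), classify(gen), classify([1, 2]))
# callable generator function iterable
```

An object that is both callable and iterable still needs an explicit choice, which is one argument for exposing the iterable conversion as a separate function.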

devxpy • Feb 18 '20

I would suggest:

  • Add a sync_iterable_to_async function that does as you suggest for any iterable (including generator)
  • Modify sync_to_async to detect the case you mentioned and then hand off to the iterable function

Assuming we can get this to pass all the tests and look reasonable, I have no problems pulling it in. If it gets a little tough to do that, we at least need to add a better error message than the current one.
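For the fallback mentioned above (a better error message), one option is to reject generator functions when the decorator is applied, so the error points at the definition rather than at a later `async for`. The sketch below assumes a hypothetical `strict_sync_to_async` built on `loop.run_in_executor`; it is not asgiref's behaviour or wording.

```python
import asyncio
import inspect
from functools import partial, wraps


def strict_sync_to_async(fn):
    """Hypothetical wrapper: runs fn in a thread, but refuses generator functions
    up front with an explicit message instead of failing later in `async for`."""
    if inspect.isgeneratorfunction(fn):
        raise TypeError(
            f"{fn.__qualname__} is a generator function; wrapping it would return a "
            "coroutine that produces a sync generator, which cannot be used with "
            "'async for'. Convert it to an async iterable instead."
        )

    @wraps(fn)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(fn, *args, **kwargs))

    return wrapper


def gen():
    yield 1


try:
    strict_sync_to_async(gen)
except TypeError as exc:
    print(exc)

add = strict_sync_to_async(lambda a, b: a + b)
print(asyncio.run(add(2, 3)))  # 5
```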

andrewgodwin • Feb 18 '20

Thank you for the follow-up, Andrew.

Should I open a similar pull request on django/channels for database_sync_to_async?

devxpy • Feb 19 '20

I imagine they would like it too, but you'll need to have the changes released here with a version number before you can depend on them downstream!

andrewgodwin • Feb 19 '20

Sorry for the delay; I hope this is still open for merging.

devxpy • May 11 '20

Closing as per #159.

carltongibson • Dec 14 '22