sqlalchemy-mixins

Add Support Async

Status: Open. aurthurm opened this issue 1 year ago • 7 comments

Add seamless support for async calls so that users developing async applications can use this package with little or no custom modification.

Proposal:

  1. Modify the session mixin to return a proper query based on the provided session, whether it is a scoped session or an async scoped session.
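from sqlalchemy import select
from sqlalchemy_mixins.utils import classproperty

@classmethod
def set_session(cls, session, isAsync=False):
    ...

@classproperty
def query(cls):
    """
    :rtype: Query or Select
    """
    # async sessions have no legacy .query(), so build a 2.0-style select()
    if cls._isAsync or not hasattr(cls.session, "query"):
        return select(cls)
    return cls.session.query(cls)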
  2. Modify smart_query so it can get the root class from select(cls), e.g. with:
# sqlalchemy 2.x (relies on private SQLAlchemy internals)
if query.__dict__["_propagate_attrs"]["plugin_subject"].class_:
    return query.__dict__["_propagate_attrs"]["plugin_subject"].class_
  3. Add an independent ActiveRecord mixin for async support, or extend the existing one, e.g.:
    def fill(self, **kwargs):
        # kept synchronous: it does no I/O, and create() chains .save() on
        # the returned instance, which would fail on a coroutine
        for name in kwargs.keys():
            if name in self.settable_attributes:
                setattr(self, name, kwargs[name])
            else:
                raise KeyError("Attribute '{}' doesn't exist".format(name))

        return self

    async def save(self):
        """Saves the updated model to the current entity db.
        """
        async with self.session() as session:
            try:
                session.add(self)
                await session.commit()
                return self
            except Exception:
                # roll back the same session whose commit failed
                await session.rollback()
                raise

    @classmethod
    async def create(cls, **kwargs):
        """Create and persist a new record for the model
        :param kwargs: attributes for the record
        :return: the new model instance
        """
        return await cls().fill(**kwargs).save()

    ...

    @classmethod
    async def first(cls):
        async with cls.session() as session:
            result = await session.execute(cls.query)
            return result.scalars().first()
  4. Introduce an async all-features mixin:
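from .activerecordasync import ActiveRecordMixinAsync
from .repr import ReprMixin
from .serialize import SerializeMixin
from .smartquery import SmartQueryMixin

class AllFeaturesMixinAsync(ActiveRecordMixinAsync, SmartQueryMixin, ReprMixin, SerializeMixin):
    __abstract__ = True
    __repr__ = ReprMixin.__repr__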

Outcomes

These modifications would allow async calls such as:


async def get_all_items():
    return await Item.all()

...

await Item.where(somefilters).all()

...

await Item.create(**data_in)
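The proposed `create()` chains `cls().fill(**kwargs).save()`, which only works if `fill()` stays synchronous: an `async def fill()` returns a coroutine, and a coroutine has no `.save()` attribute. A minimal, database-free sketch of the chaining, where `Record` is a hypothetical stand-in and `save()` only simulates the commit:

```python
import asyncio


class Record:
    """Hypothetical stand-in for a model using the proposed async mixin."""

    def __init__(self):
        self.saved = False

    def fill(self, **kwargs):
        # synchronous: only sets attributes and returns the instance itself
        for name, value in kwargs.items():
            setattr(self, name, value)
        return self

    async def save(self):
        # simulates session.add() + await session.commit(); no database here
        await asyncio.sleep(0)
        self.saved = True
        return self

    @classmethod
    async def create(cls, **kwargs):
        # fill() returns the instance, so .save() can be chained and awaited
        return await cls().fill(**kwargs).save()


async def main():
    data_in = {"name": "widget"}
    item = await Record.create(**data_in)
    print(item.name, item.saved)  # widget True


asyncio.run(main())
```

Only the method that actually performs I/O (`save()`) needs to be a coroutine; keeping attribute assignment synchronous keeps the fluent chain intact.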


aurthurm, Apr 30 '23 14:04

Brilliant suggestion @aurthurm! I was able to use this to add async support using the following steps:

  • Create an AsyncRecordMixin that can be mixed in to get _async versions of the original methods.
  • Modify smart_query to be able to get the root class from select(cls), as you suggested. I don't like this change because it modifies the source code of another package. It would be nice to be able to inject the logic without touching the original code.

Here are my changes:

async_record_mixin.py

from sqlalchemy import select
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy_mixins.activerecord import ModelNotFoundError
from sqlalchemy_mixins.inspection import InspectionMixin
from sqlalchemy_mixins.session import SessionMixin
from sqlalchemy_mixins.utils import classproperty


class AsyncRecordMixin(InspectionMixin, SessionMixin):
    """The async active record mixin."""
    __abstract__ = True

    @classmethod
    def _get_primary_key_name(cls) -> str:
        """
        Gets the primary key of the model.

        Note: This method can only be used if the model has a single primary key.
        :return: The name of the primary key.
        :raises InvalidRequestError: If the model does not have a primary key or
        has a composite primary key.
        """
        primary_keys = cls.__table__.primary_key.columns
        # the collection is empty (never None) when there is no primary key
        if len(primary_keys) == 0:
            raise InvalidRequestError(
                f"Model {cls.__name__} does not have a primary key.")
        if len(primary_keys) > 1:
            raise InvalidRequestError(
                f"Model {cls.__name__} has a composite primary key.")

        return primary_keys[0].name

    # stacking @classmethod over @property is deprecated in Python 3.11 and
    # no longer works in 3.13, so use the library's classproperty instead
    @classproperty
    def query(cls):
        """
        Override the default query property to handle async session.
        """
        # an async session maker has no legacy .query(), so use select(cls)
        if not hasattr(cls.session, "query"):
            return select(cls)

        return cls.session.query(cls)

    async def save_async(self):
        """
        Async version of :meth:`save` method.

        :see: :meth:`save` method for more information.
        """
        async with self.session() as session:
            try:
                session.add(self)
                await session.commit()
                return self
            except:
                await session.rollback()
                raise

    @classmethod
    async def create_async(cls, **kwargs):
        """
        Async version of :meth:`create` method.

        :see: :meth:`create`
        """
        return await cls().fill(**kwargs).save_async()

    async def update_async(self, **kwargs):
        """
        Async version of :meth:`update` method.

        :see: :meth:`update`
        """
        return await self.fill(**kwargs).save_async()

    async def delete_async(self):
        """
        Async version of :meth:`delete` method.

        :see: :meth:`delete`
        """
        async with self.session() as session:
            try:
                # AsyncSession.delete() is a coroutine, so await it directly
                await session.delete(self)
                await session.commit()
                return self
            except:
                await session.rollback()
                raise

    @classmethod
    async def destroy_async(cls, *ids):
        """
        Async version of :meth:`destroy` method.

        :see: :meth:`destroy`
        """
        primary_key = cls._get_primary_key_name()
        if primary_key:
            async with cls.session() as session:
                try:
                    for row in await cls.where_async(**{f"{primary_key}__in": ids}):
                        await session.delete(row)
                    await session.commit()
                except:
                    await session.rollback()
                    raise

    @classmethod
    async def select_async(cls, stmt=None, filters=None, sort_attrs=None, schema=None):
        async with cls.session() as session:
            if stmt is None:
                stmt = cls.smart_query(
                    filters=filters, sort_attrs=sort_attrs, schema=schema)
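            result = await session.execute(stmt)
            # SQLAlchemy 2.0 requires unique() when joined eager loads target
            # collections (e.g. with_joined_async(User.posts))
            return result.unique().scalars()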

    @classmethod
    async def where_async(cls, **filters):
        """
        Async version of where method.

        :see: :meth:`where` method for more details.
        """
        return await cls.select_async(filters=filters)

    @classmethod
    async def sort_async(cls, *columns):
        """
        Async version of sort method.

        :see: :meth:`sort` method for more details.
        """
        return await cls.select_async(sort_attrs=columns)

    @classmethod
    async def all_async(cls):
        """
        Async version of all method.
        This is same as calling ``(await select_async()).all()``.

        :see: :meth:`all` method for more details.
        """
        return (await cls.select_async()).all()

    @classmethod
    async def first_async(cls):
        """
        Async version of first method.
        This is same as calling ``(await select_async()).first()``.

        :see: :meth:`first` method for more details.
        """
        return (await cls.select_async()).first()

    @classmethod
    async def find_async(cls, id_):
        """
        Async version of find method.

        :see: :meth:`find` method for more details.
        """
        primary_key = cls._get_primary_key_name()
        if primary_key:
            return (await cls.where_async(**{primary_key: id_})).first()
        return None

    @classmethod
    async def find_or_fail_async(cls, id_):
        """
        Async version of find_or_fail method.

        :see: :meth:`find_or_fail` method for more details.
        """
        cursor = await cls.find_async(id_)
        if cursor:
            return cursor
        else:
            raise ModelNotFoundError("{} with id '{}' was not found"
                                     .format(cls.__name__, id_))

    @classmethod
    async def with_async(cls, schema):
        """
        Async version of with method.

        :see: :meth:`with` method for more details.
        """
        return await cls.select_async(cls.with_(schema))

    @classmethod
    async def with_joined_async(cls, *paths):
        """
        Async version of with_joined method.

        :see: :meth:`with_joined` method for more details.
        """
        return await cls.select_async(cls.with_joined(*paths))

    @classmethod
    async def with_subquery_async(cls, *paths):
        """
        Async version of with_subquery method.

        :see: :meth:`with_subquery` method for more details.
        """
        return await cls.select_async(cls.with_subquery(*paths))
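The primary-key check in `_get_primary_key_name()` relies on `__table__.primary_key.columns`. In SQLAlchemy that is a possibly empty collection, never `None`, so an emptiness test is what catches a table without a primary key. A minimal sketch of the same check on plain `Table` objects; `primary_key_name` and the three tables are hypothetical:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.exc import InvalidRequestError


def primary_key_name(table):
    """Return the single primary-key column name of ``table`` (hypothetical helper)."""
    # primary_key.columns is a (possibly empty) collection, never None
    columns = list(table.primary_key.columns)
    if not columns:
        raise InvalidRequestError(f"Table {table.name} does not have a primary key.")
    if len(columns) > 1:
        raise InvalidRequestError(f"Table {table.name} has a composite primary key.")
    return columns[0].name


metadata = MetaData()
users = Table("users", metadata, Column("id", Integer, primary_key=True), Column("name", String))
memberships = Table(
    "memberships",
    metadata,
    Column("user_id", Integer, primary_key=True),
    Column("group_id", Integer, primary_key=True),
)
log = Table("log", metadata, Column("message", String))

print(primary_key_name(users))  # id
```

`primary_key_name(memberships)` and `primary_key_name(log)` both raise `InvalidRequestError`, which mirrors the two failure cases the mixin's docstring describes.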

Then mix in AsyncRecordMixin as shown below:

from sqlalchemy.ext.asyncio import (
    AsyncAttrs,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy_mixins import AllFeaturesMixin
from sqlalchemy_mixins.timestamp import TimestampsMixin

from async_record_mixin import AsyncRecordMixin


class Base(AsyncAttrs, DeclarativeBase):
    """The base class for all database models."""
    __abstract__ = True


class BaseRecord(Base, AsyncRecordMixin, AllFeaturesMixin, TimestampsMixin):
    __abstract__ = True


engine = create_async_engine("sqlite+aiosqlite:///:memory:")
# expire_on_commit=False keeps attributes loaded after each session closes
async_session_maker = async_sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession
)

BaseRecord.set_session(async_session_maker)

Use the "_async" versions of the methods as shown below:

bob = await User.create_async(name='Bob')
post1 = await Post.create_async(body='Post 1', user=bob, rating=3)
post2 = await Post.create_async(body='long-long-long-long-long body', rating=2,
                    user=await User.create_async(name='Bill'),
                    comments=[await Comment.create_async(body='cool!', user=bob)])

# filter using operators like 'in' and 'contains' and relations like 'user'
# should print a list containing only post2: rating 2, written by 'Bill'
print((await Post.where_async(rating__in=[2, 3, 4], user___name__like='%Bi%')).all())
# joinedload post and user
print((await Comment.with_joined_async(Comment.user, Comment.post)).first())
# subqueryload posts
print((await User.with_subquery_async(User.posts)).first())
# sort by rating DESC, user name ASC
print((await Post.sort_async('-rating', 'user___name')).all())

I am not a Python expert, so please let me know if there are any issues with this approach or if you have better suggestions.

raajkumars, Jun 18 '23 14:06

@aurthurm I think @raajkumars's changes are better, since they don't need a flag to decide whether to enable async mode.

michaelbukachi, Jun 18 '23 17:06

Thank you for the brilliant suggestions. I will look into it in the coming week and make the recommended changes.

aurthurm, Aug 02 '23 21:08

  • Modify smart_query to be able to get the root class from select(cls), as you suggested. I don't like this change because it modifies the source code of another package. It would be nice to be able to inject the logic without touching the original code.

What's the best way to achieve this without touching the original source?

aurthurm, Aug 02 '23 21:08

  • Modify smart_query to be able to get the root class from select(cls), as you suggested. I don't like this change because it modifies the source code of another package. It would be nice to be able to inject the logic without touching the original code.

What's the best way to achieve this without touching the original source?

Since monkey patching is easy in Python, I managed to fix the original implementation without editing its source, as shown below:

import sqlalchemy_mixins.smartquery as smartquery

original_get_root_cls = smartquery._get_root_cls


def my_get_root_cls(query):
    """Monkey patch smartquery._get_root_cls to handle async (select()) queries."""
    try:
        return original_get_root_cls(query)
    except ValueError:
        # select(cls) has no legacy Query attributes; read the entity from
        # its plugin_subject instead (private SQLAlchemy internals)
        plugin_subject = query.__dict__.get("_propagate_attrs", {}).get("plugin_subject")
        if plugin_subject is not None and plugin_subject.class_:
            return plugin_subject.class_
        raise


# rebind the module attribute; no lambda wrapper is needed
smartquery._get_root_cls = my_get_root_cls

I put this code in my async_record_mixin.py file, so everything is contained in one file and the original behavior does not change unless you import async_record_mixin.py.
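The patch should work because a call like `_get_root_cls(query)` inside `smart_query()` is resolved through the module's globals each time it runs, so rebinding the module attribute changes the behavior of every later call. A self-contained sketch of the same wrap-and-fall-back pattern, using hypothetical stand-ins (dicts play the role of query objects) rather than the real `sqlalchemy_mixins.smartquery` module:

```python
def _get_root_cls(query):
    # stand-in for the library function: only understands "legacy" queries
    if isinstance(query, dict) and "legacy_entity" in query:
        return query["legacy_entity"]
    raise ValueError(f"Cannot get a root class from {query!r}")


def smart_query(query):
    # resolves _get_root_cls from module globals at call time
    return _get_root_cls(query)


original_get_root_cls = _get_root_cls


def patched_get_root_cls(query):
    try:
        return original_get_root_cls(query)
    except ValueError:
        # fallback for the new query shape; anything else re-raises
        if isinstance(query, dict) and "select_entity" in query:
            return query["select_entity"]
        raise


# rebind the module-level name; smart_query() now picks up the patch
_get_root_cls = patched_get_root_cls

print(smart_query({"legacy_entity": "User"}))  # User
print(smart_query({"select_entity": "Post"}))  # Post
```

Because the patch calls the saved original first, existing (sync) queries keep their old behavior and only queries the original rejects reach the fallback.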

raajkumars, Aug 03 '23 15:08

Is there any update on async support in this library? It would be good to have an official path.

epicwhale, Feb 29 '24 07:02

@epicwhale There is an open PR that is almost done.

michaelbukachi, Feb 29 '24 09:02