sqlalchemy-mixins
Add Async Support
Add seamless support for async calls, so that users developing async applications can use this package easily, with little or no custom modification.
Proposal:
- Modify SessionMixin to return the proper query for the provided session, whether it is a scoped session or an async scoped session.
@classmethod
def set_session(cls, session, isAsync=False):
    ...

@classproperty
def query(cls):
    """
    :rtype: Query (sync session) or Select (async session)
    """
    if cls._isAsync or not hasattr(cls.session, "query"):
        return select(cls)
    return cls.session.query(cls)
- Modify smart_query so that it can get the root class from select(cls), e.g.:
# sqlalchemy 2.x
if query.__dict__["_propagate_attrs"]["plugin_subject"].class_:
    return query.__dict__["_propagate_attrs"]["plugin_subject"].class_
- Add an independent ActiveRecord mixin for async support, or extend the existing one, e.g.:
def fill(self, **kwargs):
    # Kept synchronous: it only sets attributes, so create() can chain .save()
    for name in kwargs.keys():
        if name in self.settable_attributes:
            setattr(self, name, kwargs[name])
        else:
            raise KeyError("Attribute '{}' doesn't exist".format(name))
    return self

async def save(self):
    """Saves the updated model to the current entity db.
    """
    async with self.session() as session:
        try:
            session.add(self)
            await session.commit()
            return self
        except:
            # Roll back the same session that failed, then re-raise
            await session.rollback()
            raise

@classmethod
async def create(cls, **kwargs):
    """Create and persist a new record for the model
    :param kwargs: attributes for the record
    :return: the new model instance
    """
    return await cls().fill(**kwargs).save()

...

@classmethod
async def first(cls):
    async with cls.session() as session:
        result = await session.execute(cls.query)
        return result.scalars().first()
- Introduce an async all-features mixin:
from .activerecordasync import ActiveRecordMixinAsync
from .repr import ReprMixin
from .serialize import SerializeMixin
from .smartquery import SmartQueryMixin


class AllFeaturesMixinAsync(ActiveRecordMixinAsync, SmartQueryMixin, ReprMixin, SerializeMixin):
    __abstract__ = True
    __repr__ = ReprMixin.__repr__
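The flag-free dispatch in the query property proposed above can be exercised without a database. This is a stdlib-only sketch: FakeScopedSession, FakeAsyncSessionMaker and fake_select are hypothetical stand-ins for scoped_session, async_sessionmaker and sqlalchemy.select, and classproperty is a minimal re-implementation written for illustration, not the one shipped with sqlalchemy_mixins.

```python
class classproperty:
    """Minimal read-only class-level property (illustrative re-implementation)."""

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, instance, owner):
        return self.fget(owner)


class FakeScopedSession:
    """Hypothetical stand-in for scoped_session: it exposes .query()."""

    def query(self, model):
        return f"Query({model.__name__})"


class FakeAsyncSessionMaker:
    """Hypothetical stand-in for async_sessionmaker: it has no .query()."""

    def __call__(self):
        raise NotImplementedError("not needed for this sketch")


def fake_select(model):
    """Hypothetical stand-in for sqlalchemy.select(model)."""
    return f"Select({model.__name__})"


class SessionMixin:
    _session = None

    @classmethod
    def set_session(cls, session):
        cls._session = session

    @classproperty
    def session(cls):
        return cls._session

    @classproperty
    def query(cls):
        # No flag needed: an async session factory simply has no .query()
        if not hasattr(cls.session, "query"):
            return fake_select(cls)
        return cls.session.query(cls)


class User(SessionMixin):
    pass


User.set_session(FakeScopedSession())
print(User.query)  # Query(User)
User.set_session(FakeAsyncSessionMaker())
print(User.query)  # Select(User)
```

Because the check is hasattr(cls.session, "query"), no isAsync flag has to be passed through set_session.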
Outcomes
These modifications would allow async calls such as:
async def get_all_items():
    return await Item.all()
...
await Item.where(somefilters).all()
...
await Item.create(**data_in)
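The create chaining in the proposal, cls().fill(**kwargs).save(), only works if fill stays synchronous: an async def fill would return a coroutine, which has no save attribute. Below is a stdlib-only sketch of that chaining. FakeAsyncSession is a hypothetical in-memory stand-in for AsyncSession (no database is touched), and Item mirrors only the fill/save/create trio from the proposal.

```python
import asyncio


class FakeAsyncSession:
    """Hypothetical in-memory stand-in for AsyncSession; no database involved."""

    committed = []  # shared "table" of committed objects

    def __init__(self):
        self.pending = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False  # never swallow exceptions

    def add(self, obj):
        self.pending.append(obj)

    async def commit(self):
        FakeAsyncSession.committed.extend(self.pending)
        self.pending = []

    async def rollback(self):
        self.pending = []


class Item:
    # Calling the factory yields a fresh session, like an async_sessionmaker
    session = FakeAsyncSession
    settable_attributes = ("name",)

    def fill(self, **kwargs):
        # Synchronous on purpose: it only sets attributes and returns self
        for name, value in kwargs.items():
            if name not in self.settable_attributes:
                raise KeyError(f"Attribute '{name}' doesn't exist")
            setattr(self, name, value)
        return self

    async def save(self):
        async with self.session() as session:
            try:
                session.add(self)
                await session.commit()
                return self
            except Exception:
                await session.rollback()
                raise

    @classmethod
    async def create(cls, **kwargs):
        # fill() returns the instance; only save() has to be awaited
        return await cls().fill(**kwargs).save()


item = asyncio.run(Item.create(name="widget"))
print(item.name)  # widget
```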
Brilliant suggestion @aurthurm! I was able to use this to add async support using the following steps:
- Create an AsyncRecordMixin that can be mixed in to get _async versions of the original methods.
- Modify smart_query to be able to get the root class from select(cls), as you suggested. I don't like this change because it modifies the source code of another package. It would be nice to inject the logic without touching the original code.
Here are my changes:
async_record_mixin.py
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.future import select
from sqlalchemy_mixins.activerecord import ModelNotFoundError
from sqlalchemy_mixins.inspection import InspectionMixin
from sqlalchemy_mixins.session import SessionMixin
from sqlalchemy_mixins.utils import classproperty


class AsyncRecordMixin(InspectionMixin, SessionMixin):
    """The async active record mixin."""
    __abstract__ = True

    @classmethod
    def _get_primary_key_name(cls) -> str:
        """
        Gets the name of the model's primary key.
        Note: This method can only be used if the model has a single primary key.
        :return: The name of the primary key.
        :raises InvalidRequestError: If the model does not have a primary key or
            has a composite primary key.
        """
        primary_keys = list(cls.__table__.primary_key.columns)
        if not primary_keys:
            raise InvalidRequestError(
                f"Model {cls.__name__} does not have a primary key.")
        if len(primary_keys) > 1:
            raise InvalidRequestError(
                f"Model {cls.__name__} has a composite primary key.")
        return primary_keys[0].name

    # classproperty replaces the @classmethod + @property chain, which is
    # deprecated since Python 3.11 and no longer works in 3.13
    @classproperty
    def query(cls):
        """
        Override the default query property to handle async sessions.
        Async session factories have no query(), so a 2.0-style
        select(cls) statement is returned instead.
        """
        if not hasattr(cls.session, "query"):
            return select(cls)
        return cls.session.query(cls)

    async def save_async(self):
        """
        Async version of :meth:`save` method.
        :see: :meth:`save` method for more information.
        """
        async with self.session() as session:
            try:
                session.add(self)
                await session.commit()
                return self
            except:
                await session.rollback()
                raise

    @classmethod
    async def create_async(cls, **kwargs):
        """
        Async version of :meth:`create` method.
        :see: :meth:`create`
        """
        return await cls().fill(**kwargs).save_async()

    async def update_async(self, **kwargs):
        """
        Async version of :meth:`update` method.
        :see: :meth:`update`
        """
        return await self.fill(**kwargs).save_async()

    async def delete_async(self):
        """
        Async version of :meth:`delete` method.
        :see: :meth:`delete`
        """
        async with self.session() as session:
            try:
                # AsyncSession.delete() is a coroutine
                await session.delete(self)
                await session.commit()
                return self
            except:
                await session.rollback()
                raise

    @classmethod
    async def destroy_async(cls, *ids):
        """
        Async version of :meth:`destroy` method.
        :see: :meth:`destroy`
        """
        primary_key = cls._get_primary_key_name()
        async with cls.session() as session:
            try:
                # Load the rows in this same session so they can be deleted here
                stmt = cls.smart_query(filters={f"{primary_key}__in": ids})
                for row in (await session.execute(stmt)).scalars().all():
                    await session.delete(row)
                await session.commit()
            except:
                await session.rollback()
                raise

    @classmethod
    async def select_async(cls, stmt=None, filters=None, sort_attrs=None, schema=None):
        """
        Execute ``stmt``, or a :meth:`smart_query` built from ``filters``,
        ``sort_attrs`` and ``schema``, and return the scalar results.
        """
        async with cls.session() as session:
            if stmt is None:
                stmt = cls.smart_query(
                    filters=filters, sort_attrs=sort_attrs, schema=schema)
            return (await session.execute(stmt)).scalars()

    @classmethod
    async def where_async(cls, **filters):
        """
        Async version of where method.
        :see: :meth:`where` method for more details.
        """
        return await cls.select_async(filters=filters)

    @classmethod
    async def sort_async(cls, *columns):
        """
        Async version of sort method.
        :see: :meth:`sort` method for more details.
        """
        return await cls.select_async(sort_attrs=columns)

    @classmethod
    async def all_async(cls):
        """
        Async version of all method.
        This is the same as calling ``(await select_async()).all()``.
        :see: :meth:`all` method for more details.
        """
        return (await cls.select_async()).all()

    @classmethod
    async def first_async(cls):
        """
        Async version of first method.
        This is the same as calling ``(await select_async()).first()``.
        :see: :meth:`first` method for more details.
        """
        return (await cls.select_async()).first()

    @classmethod
    async def find_async(cls, id_):
        """
        Async version of find method.
        :see: :meth:`find` method for more details.
        """
        primary_key = cls._get_primary_key_name()
        if primary_key:
            return (await cls.where_async(**{primary_key: id_})).first()
        return None

    @classmethod
    async def find_or_fail_async(cls, id_):
        """
        Async version of find_or_fail method.
        :see: :meth:`find_or_fail` method for more details.
        """
        cursor = await cls.find_async(id_)
        if cursor:
            return cursor
        else:
            raise ModelNotFoundError("{} with id '{}' was not found"
                                     .format(cls.__name__, id_))

    @classmethod
    async def with_async(cls, schema):
        """
        Async version of with method.
        :see: :meth:`with` method for more details.
        """
        return await cls.select_async(cls.with_(schema))

    @classmethod
    async def with_joined_async(cls, *paths):
        """
        Async version of with_joined method.
        :see: :meth:`with_joined` method for more details.
        """
        return await cls.select_async(cls.with_joined(*paths))

    @classmethod
    async def with_subquery_async(cls, *paths):
        """
        Async version of with_subquery method.
        :see: :meth:`with_subquery` method for more details.
        """
        return await cls.select_async(cls.with_subquery(*paths))
Then mix in AsyncRecordMixin as shown below:
from sqlalchemy.ext.asyncio import (
    AsyncAttrs,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy_mixins import AllFeaturesMixin
from sqlalchemy_mixins.timestamp import TimestampsMixin

from async_record_mixin import AsyncRecordMixin


class Base(AsyncAttrs, DeclarativeBase):
    """The base class for all database models."""
    __abstract__ = True


class BaseRecord(Base, AsyncRecordMixin, AllFeaturesMixin, TimestampsMixin):
    __abstract__ = True


engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession
)
BaseRecord.set_session(async_session_maker)
Use the "_async" versions of the methods as shown below:
bob = await User.create_async(name='Bob')
post1 = await Post.create_async(body='Post 1', user=bob, rating=3)
post2 = await Post.create_async(body='long-long-long-long-long body', rating=2,
                                user=await User.create_async(name='Bill'),
                                comments=[await Comment.create_async(body='cool!', user=bob)])
# filter using operators like 'in' and 'contains' and relations like 'user'
# only post2 (rating=2, user 'Bill') matches this filter
print((await Post.where_async(rating__in=[2, 3, 4], user___name__like='%Bi%')).all())
# joinedload post and user
print((await Comment.with_joined_async(Comment.user, Comment.post)).first())
# subqueryload posts
print((await User.with_subquery_async(User.posts)).first())
# sort by rating DESC, user name ASC
print((await Post.sort_async('-rating', 'user___name')).all())
I am not a Python expert, so please let me know if there are any issues with this approach or if you have better suggestions.
@aurthurm I think @raajkumars's changes are better, since we don't have to use a flag to decide whether to enable async mode.
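The save_async, delete_async and destroy_async methods in async_record_mixin.py each repeat the same open-session/commit/rollback sequence. One possible refactoring, not part of the posted code, is an async context manager built with contextlib.asynccontextmanager. The sketch below is stdlib-only: transaction is a hypothetical helper name, and RecordingSession is a hypothetical stand-in that only records which of commit(), rollback() and close() were called (SQLAlchemy's AsyncSession does provide those three coroutines).

```python
import asyncio
from contextlib import asynccontextmanager


class RecordingSession:
    """Hypothetical stand-in for AsyncSession that records calls."""

    def __init__(self):
        self.events = []

    async def commit(self):
        self.events.append("commit")

    async def rollback(self):
        self.events.append("rollback")

    async def close(self):
        self.events.append("close")


@asynccontextmanager
async def transaction(session_factory):
    """Yield a session; commit on success, roll back and re-raise on error."""
    session = session_factory()
    try:
        yield session
        await session.commit()
    except BaseException:
        await session.rollback()
        raise
    finally:
        await session.close()


async def demo():
    ok = RecordingSession()
    async with transaction(lambda: ok) as session:
        session.events.append("work")

    failed = RecordingSession()
    try:
        async with transaction(lambda: failed):
            raise RuntimeError("boom")
    except RuntimeError:
        pass
    return ok.events, failed.events


print(asyncio.run(demo()))
# (['work', 'commit', 'close'], ['rollback', 'close'])
```

With a real async_sessionmaker, a method such as save_async could then shrink to "async with transaction(self.session) as session: session.add(self)", assuming self.session is the session factory set via set_session.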
Thank you for the brilliant suggestions. I will look into it in the coming week and make the recommended changes.
- Modify smart_query to be able to get the root class from select(cls), as you suggested. I don't like this change because it modifies the source code of another package. It would be nice to inject the logic without touching the original code.
What's the best way to achieve this without touching the original source?
Actually, since monkey patching is easy in Python, I managed to fix the original implementation without editing it, as shown below:
import sqlalchemy_mixins.smartquery as smartquery

original_get_root_cls = smartquery._get_root_cls


def my_get_root_cls(query):
    """Monkey patch smartquery to handle async (select()) queries."""
    try:
        return original_get_root_cls(query)
    except ValueError:
        # Handle 2.0-style select(cls) statements (relies on a private attribute)
        plugin_subject = getattr(query, "_propagate_attrs", {}).get("plugin_subject")
        if plugin_subject is not None and plugin_subject.class_:
            return plugin_subject.class_
        raise


# Reassign the module attribute directly; a lambda wrapper is not needed
smartquery._get_root_cls = my_get_root_cls
I put this code block in my async_record_mixin.py file so that everything is contained in one file, and the original behavior is unchanged unless async_record_mixin.py is imported.
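The patch works because smart_query looks _get_root_cls up in its module's globals at call time (as the reported result suggests), so reassigning the module attribute is enough. Once the module is imported the patch is process-wide. The stdlib-only sketch below replaces sqlalchemy_mixins.smartquery with a hypothetical in-memory module and uses plain dicts as stand-in "queries". It also shows unittest.mock.patch.object as a way to confine the patch to a block, which restores the original function on exit.

```python
import types
from unittest import mock

# Hypothetical in-memory stand-in for sqlalchemy_mixins.smartquery
smartquery = types.ModuleType("smartquery_standin")
exec(
    '''
def _get_root_cls(query):
    if "legacy_entity" in query:
        return query["legacy_entity"]
    raise ValueError("Cannot get a root class from {!r}".format(query))


def smart_query(query):
    # _get_root_cls is looked up in this module's globals on every call
    return _get_root_cls(query)
''',
    smartquery.__dict__,
)

original_get_root_cls = smartquery._get_root_cls


def patched_get_root_cls(query):
    try:
        return original_get_root_cls(query)
    except ValueError:
        subject = query.get("plugin_subject")
        if subject is not None:
            return subject
        raise


# Scoped patch: the original function is restored when the block exits
with mock.patch.object(smartquery, "_get_root_cls", patched_get_root_cls):
    print(smartquery.smart_query({"plugin_subject": "Post"}))  # Post

# Global patch, as in the post: stays in effect for the rest of the process
smartquery._get_root_cls = patched_get_root_cls
print(smartquery.smart_query({"legacy_entity": "User"}))  # User
```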
Is there any update on async support in this library overall? It would be good to have an official path.
@epicwhale There is an open PR that is almost done.