sqlalchemy-mixins
Add Async Support
Linked Issue: Add Async Support
Modifications/Additions:
- Modification of smart_query's `_get_root_cls` so that it works with a SQLAlchemy 2.0 `select(cls)` statement as the query
- Modification of SessionMixin to support an async scoped session
- Addition of a standalone ActiveRecordAsync mixin that is fully async
- Addition of the AllFeaturesMixinAsync
These modifications enable the following possibilities:
# create async scoped session
from asyncio import current_task
from typing import Any, List, Union

from sqlalchemy import Boolean, Column, String, create_engine, func, select, update
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_scoped_session,
    create_async_engine,
)
from sqlalchemy.orm import declared_attr, sessionmaker
from sqlalchemy_mixins import AllFeaturesMixinAsync, smart_query

# NOTE: InDBSchemaType and get_flake_uid (used below) are not defined in this
# example; import them from your own project.
# Async engine: replace the placeholder URL with a DSN that uses an async driver.
async_engine = create_async_engine(
    "your connection string here with async driver",
    echo=False,
    future=True,
    pool_pre_ping=True,
)

# Factory producing AsyncSession objects bound to the engine above.
_session_factory = sessionmaker(
    class_=AsyncSession,
    bind=async_engine,
    autoflush=False,
    expire_on_commit=False,
)

# Registry of sessions scoped by scopefunc=current_task, i.e. keyed on the
# currently running asyncio task.
session = async_scoped_session(_session_factory, scopefunc=current_task)
class DBModel(AllFeaturesMixinAsync):
    """Abstract async base model with a string ``uid`` primary key.

    Concrete subclasses get a table name derived from the class name and the
    async query/update helpers below. Each database helper opens its own
    session with ``async with cls.session() as session``.
    """

    __name__: str
    __abstract__ = True
    __mapper_args__ = {"eager_defaults": True}
    # __allow_unmapped__ = True

    uid = Column(
        String,
        primary_key=True,
        index=True,
        nullable=False,
        default=get_flake_uid,
    )

    @declared_attr
    def __tablename__(cls) -> str:
        """Generate the table name from the lower-cased class name."""
        return cls.__name__.lower()

    @classmethod
    def _import(cls, schema_in: Union[InDBSchemaType, dict]):
        """Return *schema_in* as a dict of column values.

        Dicts are returned unchanged (not copied). Anything else is treated as
        a Pydantic schema and converted with ``.dict(exclude_unset=True)``, so
        only explicitly-set fields are kept.
        """
        if isinstance(schema_in, dict):
            return schema_in
        return schema_in.dict(exclude_unset=True)

    @classmethod
    async def get(cls, **kwargs):
        """Return the first row matching the given filters, or ``None``.

        Example:
            await User.get(id=5)
        """
        stmt = cls.where(**kwargs)
        async with cls.session() as session:
            results = await session.execute(stmt)
            return results.scalars().first()

    @classmethod
    async def get_all(cls, **kwargs):
        """Return all rows matching the given filters, de-duplicated."""
        stmt = cls.where(**kwargs)
        async with cls.session() as session:
            results = await session.execute(stmt)
            return results.unique().scalars().all()

    @classmethod
    async def bulk_update_where(cls, update_data: List, filters: dict):
        """Apply value updates to every row matching ``filters`` and commit.

        :param update_data: list of dicts (or Pydantic schemas) of column
            values to SET. Each non-empty entry is applied, in order, as its
            own UPDATE against the rows selected by ``filters``.
        :param filters: smart_query-style filter dict, e.g. {"active": True}.
        :return: sum of the driver-reported ``rowcount`` of every UPDATE.

        NOTE(review): the previous version tried to return
        ``results.scalars().all()``. A plain UPDATE (no RETURNING) has no rows,
        so that always raised. The row count is returned instead.
        """
        to_update = [cls._import(data) for data in update_data]
        if not to_update:
            return 0
        query = smart_query(query=update(cls), filters=filters)
        total = 0
        async with cls.session() as session:
            for values in to_update:
                if not values:
                    # An UPDATE with an empty SET clause cannot be compiled.
                    continue
                # UPDATE does not accept a list of parameter sets in
                # .values(), so each dict becomes its own statement.
                stmt = query.values(values).execution_options(
                    synchronize_session="fetch"
                )
                results = await session.execute(stmt)
                total += results.rowcount
            # Without an explicit commit the changes are discarded when the
            # session is closed (matches bulk_update_with_mappings below).
            await session.commit()
        return total

    @classmethod
    async def bulk_update_with_mappings(cls, mappings: List) -> None:
        """Update many rows in a single executemany, matched on ``uid``.

        :param mappings: list of dicts (or Pydantic schemas) of column values.
            Each must contain ``uid``, e.g. [{"uid": "34", "name": "x"}, ...].
            The SET clause is built from the keys of the first mapping, so all
            mappings should carry the same keys.

        Per the original author's note, the model must have no many-to-many
        relations. Returns nothing.
        """
        if not mappings:
            return
        from sqlalchemy.sql.expression import bindparam

        # Copy each mapping so the helper "_uid" key is not written into the
        # caller's dicts (_import returns dict inputs unchanged).
        to_update = [dict(cls._import(data)) for data in mappings]
        for item in to_update:
            # The WHERE clause uses a bind name other than "uid", because
            # "uid" is already used as a SET parameter name below.
            item["_uid"] = item["uid"]
        query = update(cls).where(cls.uid == bindparam("_uid"))
        binds = {key: bindparam(key) for key in to_update[0] if key != "_uid"}
        stmt = query.values(binds).execution_options(
            synchronize_session=None  # "fetch" not available
        )
        async with cls.session() as session:
            await session.execute(stmt, to_update)
            await session.flush()
            await session.commit()

    @classmethod
    async def count_where(cls, filters):
        """Return the number of rows matching ``filters``.

        :param filters: smart_query-style filter dict.
        :return: int
        """
        # Wrap the filtered SELECT in one explicit subquery so that the COUNT
        # column and the FROM clause refer to the same derived table.
        # Select.c and implicit SELECT-as-FROM coercion are deprecated since
        # SQLAlchemy 1.4.
        filtered = smart_query(query=select(cls), filters=filters).subquery()
        count_stmt = select(func.count(filtered.c.uid)).select_from(filtered)
        async with cls.session() as session:
            res = await session.execute(count_stmt)
            return res.scalar_one()

    @classmethod
    async def get_by_uids(cls, uids: List[Any]):
        """Return all rows whose ``uid`` is in *uids*.

        Results are ordered by ``name`` when the model defines that attribute,
        otherwise by ``uid``. An empty *uids* returns ``[]`` without querying.
        """
        if not uids:
            return []
        # DBModel itself declares no ``name`` column; ordering by it
        # unconditionally fails on models that lack one.
        order_col = getattr(cls, "name", cls.uid)
        stmt = select(cls).where(cls.uid.in_(uids)).order_by(order_col)  # type: ignore
        async with cls.session() as session:
            results = await session.execute(stmt)
            return results.scalars().all()
class Item(DBModel):
    """Example concrete model: a named, described item with an active flag."""

    # Required text columns.
    name = Column(String(), nullable=False)
    description = Column(String(), nullable=False)
    # Defaults to False when not supplied on insert.
    active = Column(Boolean, default=False)
async def _demo():
    """Illustrate the async query helpers against the ``Item`` model."""
    # All items. Assumes AllFeaturesMixinAsync provides an async ``all``;
    # it is not defined in DBModel itself.
    items = await Item.all()
    # One item by primary key; ``uid`` is a String column, so pass a string.
    item = await Item.get(uid="2")
    # Several items by other filters; get_all returns a list.
    matching = await Item.get_all(name__in=["Amanda", "Mathew"], active=True)
    return items, item, matching


if __name__ == "__main__":
    # Top-level ``await`` is a SyntaxError in a regular module, so the
    # examples run inside a coroutine driven by asyncio.run().
    import asyncio

    asyncio.run(_demo())
Hey @aurthurm Do you mind resolving the conflicts?
@michaelbukachi I have updated the pull request according to suggestions from @raajkumars
@aurthurm Type checks are failing
@michaelbukachi I have resolved the typing issues. Thank you for adding this GitHub workflow. I have learnt something through this contribution.
Hey @aurthurm Nice work! We need to add tests for the new functionality though.
Hey! Great work! Any chance this will be released soon?
+1 eagerly looking forward to having robust support for async! :)
@aurthurm Any progress on the tests?
Good job, @aurthurm! Would it be possible to release this feature at your earliest convenience? I would greatly appreciate having it available for my company's upcoming new project.
thx!
I've been a bit tied up at work. I'll try my best to write the required tests and examples soon.