Incorporating CrudRepository Interface with Native Query Decorator
Description:
This pull request introduces a `CrudRepository` interface inspired by the repository conventions of Spring Data (as used in Spring Boot), providing a convenient abstraction layer for database operations. The goal is to streamline database access with a standardized interface, while still allowing custom queries where needed.
Key Features:
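- **CrudRepository interface**
  - `CrudRepository` abstracts common CRUD operations: `save`, `save_all`, `delete`, `delete_all`, `find_by_id`, `find_all_by_ids`, and `find_all`.
  - Developers define their own repositories by subclassing `CrudRepository` with concrete type arguments (for example, `CrudRepository[int, Hero]`). They inherit implementations of these operations automatically.
  - The `find_*` methods return the open `Session` together with their results. Loaded entities can then be modified and passed back to `save`/`delete` with the same session. The caller is responsible for closing it.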
- **Automatic table creation**
  - The `create_all_tables` classmethod creates an engine for a database URL, creates every table registered in `SQLModel.metadata`, and returns the engine.
- **Native query decorator**
  - The `native_query` decorator lets developers run custom SQL from a repository method while keeping the repository interface.
  - Queries are executed on a SQLAlchemy connection. Each result row is validated into the declared return type: a single model, or a list of models when the return type is `Iterable[Model]`.
- **Type annotations and generics**
  - Type annotations and generics (`CrudRepository[ID, T]`) are used throughout, improving readability and maintainability.
  - The entity type parameter accepts Pydantic models as well as SQLModel classes.
- **Error handling and logging**
  - `save`, `save_all`, and `delete` catch exceptions raised while executing or committing, log them, and do not propagate them. `save_all` and `delete` return `False` on failure.
  - Errors are recorded through the module's `logging` logger, giving visibility into failed operations.
Side Note:
Feedback on the design and implementation of these concepts is welcomed. While the current implementation offers a basic foundation, detailed implementations of these concepts are still subject to change based on community feedback and evolving requirements. As we strive to improve the framework, any input is valuable.
Current Implementation:
import collections.abc
import functools
import inspect
import re
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args, get_origin
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import Engine, create_engine
from sqlalchemy.sql import text
from sqlmodel import Session, SQLModel, select
# Entity type managed by a repository: any pydantic model, which includes
# SQLModel table classes. A bound (instead of the constraints SQLModel/BaseModel)
# lets type checkers accept concrete models such as ``CrudRepository[int, Hero]``
# and keep ``Hero`` -- not ``SQLModel`` -- as the result type.
T = TypeVar("T", bound=BaseModel)
# Type of the entity's ``id`` column, as used by ``find_by_id``/``find_all_by_ids``.
ID = TypeVar("ID", UUID, int)
import logging

# Module logger; repository write failures are reported here.
logger = logging.getLogger(__name__)
class CrudRepository(Generic[ID, T]):
    """Generic CRUD repository for a single SQLModel table class.

    Subclass it with concrete type arguments, e.g.
    ``class HeroRepository(CrudRepository[int, Hero])``. The id type and the
    model class are read back from that declaration when the repository is
    constructed.

    The ``find_*`` methods return the open ``Session`` together with their
    results, so loaded entities can be changed and handed back to
    ``save``/``delete`` with that same session. The caller must close it.
    Write methods called without a session open their own and close it
    before returning.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.id_type, self.model_class = self._get_model_id_type_with_class()

    @classmethod
    def create_all_tables(cls, url: str) -> Engine:
        """Create an engine for *url* and all tables registered on ``SQLModel.metadata``.

        Returns the engine so it can be passed to a repository's constructor.
        """
        engine = create_engine(url, echo=False)
        SQLModel.metadata.create_all(engine)
        return engine

    @classmethod
    def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
        """Return ``(id_type, model_class)`` from the ``CrudRepository[ID, T]`` base.

        The MRO is searched, so indirect subclasses and subclasses that list
        other bases (e.g. mixins) before ``CrudRepository[...]`` also resolve.
        """
        for klass in cls.__mro__:
            # Only look at bases declared on this class itself, not inherited ones.
            for base in vars(klass).get("__orig_bases__", ()):
                if getattr(base, "__origin__", None) is CrudRepository:
                    id_type, model_class = get_args(base)
                    return id_type, model_class
        # No parameterized base (e.g. the bare CrudRepository): previous lookup.
        return get_args(cls.__orig_bases__[0])  # type: ignore[return-value]

    def _commit_operation_in_session(
        self, session_operation: Callable[[Session], None], session: Session
    ) -> bool:
        """Apply *session_operation* to *session* and commit.

        Returns True on success. On any exception, the error is logged with
        its traceback and the session is rolled back so it stays usable.
        In that case False is returned instead of raising.
        """
        try:
            session_operation(session)
            session.commit()
        except Exception:
            logger.exception("Repository operation failed; rolling back")
            session.rollback()
            return False
        return True

    def _commit_in_session_or_new(
        self,
        session_operation: Callable[[Session], None],
        session: Optional[Session],
        refresh: Iterable[T] = (),
    ) -> bool:
        """Commit *session_operation* in *session*, or in a new session that is closed afterwards.

        When the repository owns the session, the entities in *refresh* are
        reloaded after a successful commit. Otherwise ``expire_on_commit``
        would leave them expired, and once the session is closed they could
        no longer be read.
        """
        owns_session = session is None
        active = self._create_session() if session is None else session
        try:
            committed = self._commit_operation_in_session(session_operation, active)
            if committed and owns_session:
                for entity in refresh:
                    active.refresh(entity)
            return committed
        finally:
            if owns_session:
                active.close()

    def _create_session(self) -> Session:
        """Open a new session on the repository's engine."""
        return Session(self.engine, expire_on_commit=True)

    def _exec_in_new_session(
        self, statement: Any, fetch: Callable[[Any], Any]
    ) -> tuple[Any, Session]:
        """Execute *statement* in a new session and return ``(fetch(result), session)``.

        If executing or fetching raises, the session is closed before the
        exception propagates, so a failed lookup does not leak it.
        """
        session = self._create_session()
        try:
            return fetch(session.exec(statement)), session
        except Exception:
            session.close()
            raise

    def find_by_id(self, id: ID) -> tuple[T, Session]:
        """Return the entity whose ``id`` equals *id*, plus the open session.

        Raises whatever ``Result.one()`` raises when there is not exactly one
        matching row.
        """
        statement = select(self.model_class).where(self.model_class.id == id)  # type: ignore
        return self._exec_in_new_session(statement, lambda result: result.one())

    def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
        """Return all entities whose ``id`` is in *ids*, plus the open session."""
        # ``column in ids`` would be evaluated by Python (yielding a constant
        # False); ``.in_`` renders a SQL ``IN`` clause instead.
        statement = select(self.model_class).where(self.model_class.id.in_(list(ids)))  # type: ignore
        return self._exec_in_new_session(statement, lambda result: result.all())

    def find_all(self) -> tuple[Iterable[T], Session]:
        """Return every row of the model's table, plus the open session."""
        statement = select(self.model_class)  # type: ignore
        return self._exec_in_new_session(statement, lambda result: result.all())

    def save(self, entity: T, session: Optional[Session] = None) -> T:
        """Add *entity* and commit; return *entity*.

        Failures are logged, not raised (see ``_commit_operation_in_session``).
        """
        self._commit_in_session_or_new(lambda s: s.add(entity), session, refresh=(entity,))
        return entity

    def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
        """Add all *entities* and commit in one transaction; return True on success."""
        batch = list(entities)  # materialize: iterated once to add, again to refresh
        return self._commit_in_session_or_new(lambda s: s.add_all(batch), session, refresh=batch)

    def delete(self, entity: T, session: Optional[Session] = None) -> bool:
        """Delete *entity* and commit; return True on success."""
        return self._commit_in_session_or_new(lambda s: s.delete(entity), session)

    def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
        """Delete all *entities* and commit in one transaction; return True on success.

        Uses the same error handling as ``delete``: failures are logged and
        reported as False instead of propagating.
        """

        def delete_each(s: Session) -> None:
            for entity in entities:
                s.delete(entity)

        return self._commit_in_session_or_new(delete_each, session)
# A ``{name}`` placeholder, optionally wrapped in matching single or double
# quotes (e.g. ``'{name}'``), to be replaced by a bind parameter.
_PLACEHOLDER = re.compile(r"""(['"]?)\{(\w+)\}\1""")


def _to_bind_params(query: str) -> tuple[str, list[str]]:
    """Rewrite the ``{name}`` placeholders in *query* as ``:name`` bind parameters.

    Returns the rewritten SQL and the placeholder names in order of appearance.
    Quotes around a placeholder are dropped because the driver quotes a bound
    value itself.
    """
    names: list[str] = []

    def _as_bind_param(match: "re.Match[str]") -> str:
        names.append(match.group(2))
        return f":{match.group(2)}"

    return _PLACEHOLDER.sub(_as_bind_param, query), names


def _call_arguments(signature: inspect.Signature, args: tuple, kwargs: dict) -> dict[str, Any]:
    """Map positional *args* and declared defaults to parameter names, then overlay *kwargs*."""
    bound = signature.bind_partial(*args)
    bound.apply_defaults()
    return {**bound.arguments, **kwargs}


def native_query(query: str, return_type: Type[T]) -> Any:
    """Decorate a repository method so that calling it runs *query* and maps the rows.

    ``{name}`` placeholders in *query* (optionally wrapped in quotes) are sent
    as bound SQL parameters. Their values are taken from the decorated
    method's arguments by name, whether passed positionally or by keyword,
    so they are never spliced into the SQL text. Only value placeholders are
    supported; table or column names cannot be bound. A placeholder with no
    matching argument raises ``KeyError``.

    If *return_type* is a parameterized iterable such as ``Iterable[Model]``
    or ``list[Model]``, every row is validated into ``Model`` and a list is
    returned. Otherwise the last row is validated into *return_type*, and
    ``IndexError`` is raised if the query returned no rows.
    """
    sql, param_names = _to_bind_params(query)
    origin = get_origin(return_type)
    returns_many = isinstance(origin, type) and issubclass(origin, collections.abc.Iterable)

    def decorated(func: Callable[..., T]) -> Callable[..., T]:
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(self: CrudRepository, *args: Any, **kwargs: Any) -> T:
            values = _call_arguments(signature, (self, *args), kwargs)
            params = {name: values[name] for name in param_names}
            with self.engine.connect() as connection:
                # .all() buffers the rows, so they stay usable after the connection closes.
                rows = connection.execute(text(sql), params).mappings().all()
            if returns_many:
                item_type = get_args(return_type)[0]
                return [item_type.model_validate(row) for row in rows]  # type: ignore
            if not rows:
                raise IndexError(f"native query returned no rows: {query}")
            # Same contract as before: with several rows, the last one is used.
            return return_type.model_validate(rows[-1])  # type: ignore

        return wrapper

    return decorated  # type: ignore
Example Usage:
from sqlmodel import Field, SQLModel
from sqlmodel.repository.crud_repository import CrudRepository, native_query


class Hero(SQLModel, table=True):
    """Example table model."""

    id: int | None = Field(default=None, primary_key=True)
    name: str
    secret_name: str
    age: int | None = None


class HeroRepository(CrudRepository[int, Hero]):
    """Repository for ``Hero`` with one custom native query."""

    @native_query("SELECT * FROM hero WHERE name = '{name}'", Hero)
    def get_hero_by_name(self, name: str) -> Hero:
        ...


# Usage
if __name__ == "__main__":
    sqlite_file_name = "database.db"
    sqlite_url = f"sqlite:///{sqlite_file_name}"
    engine = CrudRepository.create_all_tables(sqlite_url)
    hero_repo = HeroRepository(engine)
    deadpond = Hero(name="Deadpond", secret_name="Dive Wilson")
    hero_repo.save(deadpond)
    # find_all returns (results, open session); close the session when done.
    heroes, session = hero_repo.find_all()
    with session:
        print(heroes)
    print(hero_repo.get_hero_by_name(name="Deadpond"))