flask-sqlalchemy

Fix issue #1312: Type (typehint) error when calling `db.Model` subclass constructor with parameters

Open cainmagi opened this issue 1 year ago • 7 comments

Fix the type-hint inconsistency when initializing a subclass of `db.Model`.

  • fixes #1312
  • Fix an issue where `tox p` fails because tox does not allow running `mypy` (required for the tests to pass).
    • This issue may be caused by the upgrade to tox>=4. My tox version is 4.12.0.
    • See details here: https://stackoverflow.com/a/47716994/8266012

Checklist:

  • [ ] Add tests that demonstrate the correct behavior of the change. Tests should fail without the change.
    • Not needed: this change is only a bug fix.
  • [ ] Add or update relevant docs, in the docs folder and in code.
    • Not needed: this change is only a bug fix.
  • [ ] Add an entry in CHANGES.rst summarizing the change and linking to the issue.
    • Not needed: this change is only a bug fix, so no user-facing changelog entry is required.
  • [ ] Add .. versionchanged:: entries in any relevant code docs.
    • Not needed: this change is only a bug fix.
  • [x] Run pre-commit hooks and fix any issues.
  • [x] Run pytest and tox, no tests failed.

Appendices

Appendix A: The idea of this PR.

This example shows the core idea of how this PR works. `Test(prop: _Mul)` makes the type of `prop` a type parameter of `Test`. The descriptor then resolves the type of the property `Test().prop` from the parameter `_Mul` of the parameterized type `Test[_Mul]`. For example, if a new instance is typed as `test: Test[int]`, then `IntOrStr` resolves the type of `test.prop` as `int`.

from typing import Union, Any
from typing import Generic, TypeVar
from typing_extensions import overload, reveal_type


_Mul = TypeVar("_Mul", bound=Union[int, str])


class IntOrStr:
    """Descriptor whose static type follows the type parameter of ``Test``.

    The overloads let a type checker pick the return type from how the owning
    ``Test`` instance is parameterized: ``Test[int]`` gives ``int``,
    ``Test[str]`` gives ``str``, and ``Test[Union[int, str]]`` gives the union.
    At run time all overloads collapse into the single implementation, which
    returns the value stored on the instance as ``_prop``.
    """

    @overload
    def __get__(self, obj: None, obj_cls: Any = None) -> "IntOrStr": ...

    @overload
    def __get__(self, obj: "Test[int]", obj_cls: Any = None) -> int: ...

    @overload
    def __get__(self, obj: "Test[str]", obj_cls: Any = None) -> str: ...

    @overload
    def __get__(
        self, obj: "Test[Union[int, str]]", obj_cls: Any = None
    ) -> Union[int, str]: ...

    def __get__(self, obj: "Union[Test[Any], None]", obj_cls: Any = None) -> Any:
        """Return the instance's stored value, or the descriptor on class access.

        :param obj: The owning instance, or ``None`` when the attribute is read
            from the class itself (e.g. ``Test.prop``).
        :param obj_cls: The owner class (unused).
        :return: ``obj._prop`` for instance access, ``self`` for class access.
        """
        if obj is None:
            # Class-level access: follow the descriptor convention instead of
            # failing with AttributeError on ``None._prop``.
            return self
        return obj._prop


class Test(Generic[_Mul]):
    """Generic container whose ``prop`` attribute is typed by its constructor.

    ``Test(1)`` is expected to be inferred as ``Test[int]`` and ``Test("ss")``
    as ``Test[str]``. The ``IntOrStr`` descriptor uses that type parameter to
    choose the static type of ``prop``.
    """

    # Reads of ``prop`` are routed through the descriptor, which returns ``_prop``.
    prop = IntOrStr()

    def __init__(self, prop: _Mul) -> None:
        """Keep *prop* under a private name so it is exposed via the descriptor."""
        self._prop = prop


# Module-level demo: the trailing comment on each line is the type a static
# checker is expected to reveal for ``Test(...).prop``.
reveal_type(Test(1).prop)  # int
reveal_type(Test("ss").prop)  # str


def test_union(val: Union[int, str]) -> None:
    """Reveal ``prop`` when the constructor argument is ``Union[int, str]``.

    ``Test(val)`` is expected to be typed ``Test[Union[int, str]]``, which
    matches the third (union) overload of ``IntOrStr.__get__``.
    """
    reveal_type(Test(val).prop)  # int | str


def test_any(val: Any) -> None:
    """Reveal ``prop`` when the constructor argument is ``Any``.

    ``Any`` is compatible with every overload of ``IntOrStr.__get__``; the
    author observed the first overload (``int``) being selected.

    NOTE(review): some checkers (e.g. mypy) infer ``Any`` when several overloads
    match an ``Any`` argument with different return types. Confirm the expected
    result against the checker in use.
    """
    reveal_type(Test(val).prop)  # int

Appendix B: Validating the behavior of this PR

This PR only changes behavior during static type checking. In other words, at run time the code works exactly the same as the original version. I made this change because I think the original code only has type-checking issues and works correctly at run time. That is why I did not add a changelog entry: I did not change any functionality or run-time behavior.

The following code can be used to test the behavior of the type hints provided by this PR. It works as expected in most cases. However, there are two known cases where the type checker produces false negatives (it accepts code that should be rejected). See `test_not_as_expectation`.

It seems that static type-checking code cannot be added to pytest. That is why I did not add more tests for this PR.

from typing import Type, Any, TYPE_CHECKING
from typing_extensions import Never, reveal_type

import dataclasses
import sqlalchemy.orm as sa_orm

from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.model import DefaultMeta, Model


def test_default() -> None:
    db = SQLAlchemy()
    reveal_type(db)  # SQLAlchemy[type[Model]]
    reveal_type(db.Model)  # type[_FSAModel_KW]


def test_unknown_class(model_class: Any) -> None:
    """Reveal the types produced when ``model_class`` is typed as ``Any``.

    The ``SQLAlchemy`` type parameter is expected to stay ``Any``, while
    ``db.Model`` is still expected to reveal as ``type[_FSAModel_KW]``.
    """
    db = SQLAlchemy(model_class=model_class)
    reveal_type(db)  # SQLAlchemy[Any]
    reveal_type(db.Model)  # type[_FSAModel_KW]


def test_meta_class_v1(meta_class: Type[sa_orm.DeclarativeMeta]) -> None:
    """Reveal types for model classes built from a ``DeclarativeMeta`` metaclass.

    The body only runs under ``TYPE_CHECKING``, so it is analyzed statically
    and skipped at run time. Presumably this is because ``meta_class(type)`` is
    not a meaningful run-time call; it only produces a value of the metaclass's
    instance type for the checker.
    """
    if TYPE_CHECKING:
        # A class object created by calling the metaclass directly.
        db = SQLAlchemy(model_class=meta_class(type))
        reveal_type(db)  # SQLAlchemy[type[__class_type]]
        reveal_type(db.Model)  # type[_FSAModel_KW]

        # A class object created with ``metaclass=`` syntax.
        class TestClass(metaclass=meta_class):
            pass

        db_v2 = SQLAlchemy(model_class=TestClass)
        reveal_type(db_v2)  # SQLAlchemy[type[TestClass]]
        reveal_type(db_v2.Model)  # type[_FSAModel_KW]


def test_meta_class_v2(meta_class: Type[DefaultMeta]) -> None:
    """Same as ``test_meta_class_v1`` but with Flask-SQLAlchemy's ``DefaultMeta``.

    The body only runs under ``TYPE_CHECKING``, so it is analyzed statically
    and skipped at run time.
    """

    if TYPE_CHECKING:
        # A class object created by calling the metaclass directly.
        db = SQLAlchemy(model_class=meta_class(type))
        reveal_type(db)  # SQLAlchemy[type[__class_type]]
        reveal_type(db.Model)  # type[_FSAModel_KW]

        # A class object created with ``metaclass=`` syntax.
        class TestClass(metaclass=meta_class):
            pass

        db_v2 = SQLAlchemy(model_class=TestClass)
        reveal_type(db_v2)  # SQLAlchemy[type[TestClass]]
        reveal_type(db_v2.Model)  # type[_FSAModel_KW]


def test_sqlalchemy2_base(model_class: Type[sa_orm.DeclarativeBase]) -> None:
    """Reveal types for a SQLAlchemy 2.x ``DeclarativeBase`` subclass.

    A non-dataclass base is expected to give ``db.Model`` the keyword-argument
    model type ``_FSAModel_KW``.
    """
    db = SQLAlchemy(model_class=model_class)
    reveal_type(db)  # SQLAlchemy[type[DeclarativeBase]]
    reveal_type(db.Model)  # type[_FSAModel_KW]


def test_sqlalchemy2_nometa(model_class: Type[sa_orm.DeclarativeBaseNoMeta]) -> None:
    """Reveal types for a SQLAlchemy 2.x ``DeclarativeBaseNoMeta`` subclass.

    Expected to behave like ``DeclarativeBase``: ``db.Model`` reveals as
    ``type[_FSAModel_KW]``.
    """
    db = SQLAlchemy(model_class=model_class)
    reveal_type(db)  # SQLAlchemy[type[DeclarativeBaseNoMeta]]
    reveal_type(db.Model)  # type[_FSAModel_KW]


def test_sqlalchemy2_dataclass(model_class: Type[sa_orm.MappedAsDataclass]) -> None:
    """Reveal types for a ``MappedAsDataclass`` base.

    A dataclass-style base is expected to switch ``db.Model`` to
    ``type[_FSAModel_DataClass]``.
    """
    db = SQLAlchemy(model_class=model_class)
    reveal_type(db)  # SQLAlchemy[type[MappedAsDataclass]]
    reveal_type(db.Model)  # type[_FSAModel_DataClass]


def test_sqlalchemy2_hybird() -> None:
    """Reveal types for bases that mix ``DeclarativeBase`` with other classes.

    (``hybird`` is a typo for ``hybrid``. The name is kept because the
    ``__main__`` driver calls the function by this name.)
    """

    # Dataclass mixin listed first.
    class Base1(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
        pass

    db1 = SQLAlchemy(model_class=Base1)
    reveal_type(db1)  # SQLAlchemy[type[Base1]]
    reveal_type(db1.Model)  # type[_FSAModel_DataClass]

    # Same mixins in the reversed order: still expected to be a dataclass model.
    class Base2(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass):
        pass

    db2 = SQLAlchemy(model_class=Base2)
    reveal_type(db2)  # SQLAlchemy[type[Base2]]
    reveal_type(db2.Model)  # type[_FSAModel_DataClass]

    class AnyClass:
        pass

    # An unrelated mixin does not make the base a dataclass: keyword model expected.
    class Base3(sa_orm.DeclarativeBase, AnyClass):
        pass

    db3 = SQLAlchemy(model_class=Base3)
    reveal_type(db3)  # SQLAlchemy[type[Base3]]
    reveal_type(db3.Model)  # type[_FSAModel_KW]


def test_class_init_kw() -> None:
    """Compare ``__init__`` of a plain SQLAlchemy 2.x model and a ``db.Model`` one.

    Both models derive from the same ``DeclarativeBase`` subclass, so both
    constructors are expected to accept arbitrary keyword arguments.
    """
    class BaseKW(sa_orm.DeclarativeBase):
        pass

    db = SQLAlchemy(model_class=BaseKW)
    reveal_type(db)  # SQLAlchemy[type[BaseKW]]
    reveal_type(db.Model)  # type[_FSAModel_KW]

    class ModelSa(BaseKW):
        __tablename__ = "modelsas"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)
        name: sa_orm.Mapped[str]

    # No ``__tablename__`` here; presumably Flask-SQLAlchemy derives one from
    # the class name. TODO confirm.
    class ModelDb(db.Model):
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)
        name: sa_orm.Mapped[str]

    # With this PR, ``db.Model`` subclasses get the same ``__init__`` signature
    # as subclasses of the original base class.
    reveal_type(ModelSa.__init__)  # (self: ModelSa, **kw: Any) -> None
    reveal_type(ModelDb.__init__)  # (self: ModelDb, **kw: Any) -> None


def test_class_init_kw_v2() -> None:
    """Compare ``__init__`` for a legacy ``declarative_base()`` model class.

    ``declarative_base()`` is revealed as ``Any``. The ``assert`` statements are
    run-time sanity checks on what kind of class it returns; they are skipped
    under ``python -O``.
    """
    BaseKWMeta = sa_orm.declarative_base()
    reveal_type(BaseKWMeta)  # Any
    assert isinstance(BaseKWMeta, sa_orm.DeclarativeMeta)
    assert not issubclass(BaseKWMeta, sa_orm.DeclarativeBase)
    assert not issubclass(BaseKWMeta, sa_orm.DeclarativeBaseNoMeta)
    assert not issubclass(BaseKWMeta, sa_orm.MappedAsDataclass)

    db = SQLAlchemy(model_class=BaseKWMeta)
    # NOTE(review): ``BaseKWMeta`` is revealed as ``Any`` above, and
    # ``test_unknown_class`` expects ``SQLAlchemy[Any]`` for an ``Any``
    # argument. Confirm which type the checker actually reveals here.
    reveal_type(db)  # SQLAlchemy[DeclarativeMeta]
    reveal_type(db.Model)  # type[_FSAModel_KW]

    class ModelSa(BaseKWMeta):
        __tablename__ = "modelsas"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)
        name: sa_orm.Mapped[str]

    class ModelDb(db.Model):
        __tablename__ = "modeldbs"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)
        name: sa_orm.Mapped[str]

    # The static hint for ``ModelSa.__init__`` is wrong: it takes no keyword
    # arguments, yet at run time both ModelSa(name="name") and
    # ModelDb(name="name") work. In other words, only ``ModelDb.__init__``
    # matches the run-time behavior.
    reveal_type(ModelSa.__init__)  # (self: ModelSa) -> None
    reveal_type(ModelDb.__init__)  # (self: ModelDb, **kw: Any) -> None


def test_class_init_dataclass() -> None:
    """Compare ``__init__`` for models built on a ``MappedAsDataclass`` base.

    Both constructors are expected to take the mapped fields as parameters:
    here only ``name``, because ``id`` is declared with ``init=False``.
    """
    class BaseDataClass(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass):
        pass

    db = SQLAlchemy(model_class=BaseDataClass)
    reveal_type(db)  # SQLAlchemy[type[BaseDataClass]]
    reveal_type(db.Model)  # type[_FSAModel_DataClass]

    class ModelSa(BaseDataClass):
        __tablename__ = "modelsas"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True, init=False)
        name: sa_orm.Mapped[str]

    class ModelDb(db.Model):
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True, init=False)
        name: sa_orm.Mapped[str]

    # With this PR, ``db.Model`` subclasses get the same dataclass-style
    # ``__init__`` signature as subclasses of the original base class.
    reveal_type(
        ModelSa.__init__
    )  # (self: ModelSa, name: SQLCoreOperations[str] | str) -> None
    reveal_type(
        ModelDb.__init__
    )  # (self: ModelDb, name: SQLCoreOperations[str] | str) -> None


def test_not_allowed() -> None:
    """Check which ``model_class`` arguments the type hints accept or reject.

    The ``Pass*`` classes are expected to be accepted. The ``NotPass*`` calls are
    expected to be flagged by a static checker; the comments above each one
    reproduce the expected diagnostic.

    NOTE(review): the ``NotPass*`` calls also execute when this script runs.
    Whether Flask-SQLAlchemy rejects them at run time is not shown here.
    """

    Pass1 = sa_orm.declarative_base()
    reveal_type(Pass1)  # Any
    assert isinstance(Pass1, sa_orm.DeclarativeMeta)
    assert not issubclass(Pass1, sa_orm.DeclarativeBase)
    assert not issubclass(Pass1, sa_orm.DeclarativeBaseNoMeta)
    assert not issubclass(Pass1, sa_orm.MappedAsDataclass)

    # NOTE(review): ``Pass1`` is revealed as ``Any`` above; confirm that the
    # checker really reveals ``type[DeclarativeBase]`` here rather than ``Any``.
    reveal_type(SQLAlchemy(model_class=Pass1))  # SQLAlchemy[type[DeclarativeBase]]

    # A hand-built model class using Flask-SQLAlchemy's ``DefaultMeta``.
    class Pass2(metaclass=DefaultMeta):
        __fsa__ = SQLAlchemy()
        registry = sa_orm.registry()
        __tablename__ = "pass2s"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)

    reveal_type(SQLAlchemy(model_class=Pass2))  # SQLAlchemy[type[Pass2]]

    class Pass3(Model):
        pass

    reveal_type(SQLAlchemy(model_class=Pass3))  # SQLAlchemy[type[Pass3]]

    class Pass4(sa_orm.DeclarativeBase):
        pass

    reveal_type(SQLAlchemy(model_class=Pass4))  # SQLAlchemy[type[Pass4]]

    class Pass5(sa_orm.DeclarativeBaseNoMeta):
        pass

    reveal_type(SQLAlchemy(model_class=Pass5))  # SQLAlchemy[type[Pass5]]

    class Pass6(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass):
        pass

    reveal_type(SQLAlchemy(model_class=Pass6))  # SQLAlchemy[type[Pass6]]

    # Plain class with no SQLAlchemy base.
    class NotPass1:
        pass

    # As expectation
    # Argument of type "type[NotPass1]" cannot be assigned to parameter "model_class"
    # of type "_FSA_MCT_T@SQLAlchemy" in function "__init__"
    SQLAlchemy(model_class=NotPass1)

    # A metaclass, not a model class.
    class NotPass2(sa_orm.DeclarativeMeta):
        pass

    # As expectation
    # Argument of type "type[NotPass2]" cannot be assigned to parameter "model_class"
    # of type "_FSA_MCT_T@SQLAlchemy" in function "__init__"
    SQLAlchemy(model_class=NotPass2)

    # A standard-library dataclass with no SQLAlchemy base.
    @dataclasses.dataclass
    class NotPass3:
        a: int = dataclasses.field(default=1)

    # As expectation
    # Argument of type "type[NotPass3]" cannot be assigned to parameter "model_class"
    # of type "_FSA_MCT_T@SQLAlchemy" in function "__init__"
    SQLAlchemy(model_class=NotPass3)


def test_not_as_expectation() -> None:
    """Show misuses that a static checker accepts but that should be rejected.

    Neither class below should be used as ``model_class``. ``Unexpected1`` is
    already a mapped model, and ``Unexpected2`` does not inherit
    ``DeclarativeBase``. The current type hints cannot flag either case, and
    I do not have a good idea how to solve them. These cases may be left as
    they are.

    The author expects these usages to raise at run time. The function
    therefore returns normally only if *none* of them raised, which lets the
    ``__main__`` driver report "Not as expectation." in that case.
    """

    class Base(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass):
        pass

    class Unexpected1(Base):
        __tablename__ = "unexpecteds1"
        id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True, init=False)
        a: sa_orm.Mapped[int]

    # `Unexpected1` should not be used as the model_class. But the type checker allows
    # this usage.
    reveal_type(SQLAlchemy(model_class=Unexpected1))  # SQLAlchemy[type[Unexpected1]]

    class Unexpected2(sa_orm.MappedAsDataclass):
        pass

    # `Unexpected2` does not inherit `DeclarativeBase`. It should not be used as the
    # `model_class`. However, currently, we allow it. That's because Python does not
    # support typing like `Intersection[ParentClass1, ParentClass2]` yet. If we want
    # `SQLAlchemy` to be aware of `MappedAsDataclass`, this class has to be accepted
    # as the input candidate of `model_class`.
    reveal_type(SQLAlchemy(model_class=Unexpected2))  # SQLAlchemy[type[Unexpected2]]

    # Deliberately no trailing unconditional ``raise``: it would make the
    # ``__main__`` check report "As expectation." even when nothing above failed.


if __name__ == "__main__":
    # Run every demo at run time. Parameterized demos receive a concrete object
    # compatible with their annotated parameter; each generated base class is
    # built immediately before the demo that consumes it.
    test_default()
    test_unknown_class(type)
    test_meta_class_v1(DefaultMeta)
    test_meta_class_v2(DefaultMeta)

    declarative_base_cls = type("new", (sa_orm.DeclarativeBase,), dict())
    test_sqlalchemy2_base(declarative_base_cls)

    nometa_base_cls = type("new", (sa_orm.DeclarativeBaseNoMeta,), dict())
    test_sqlalchemy2_nometa(nometa_base_cls)

    dataclass_base_cls = type(
        "new", (sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase), dict()
    )
    test_sqlalchemy2_dataclass(dataclass_base_cls)

    test_sqlalchemy2_hybird()
    test_class_init_kw()
    test_class_init_kw_v2()
    test_class_init_dataclass()
    test_not_allowed()

    try:
        test_not_as_expectation()
    except Exception:
        # Any run-time error counts as the misuse being rejected.
        print("As expectation.")
    else:
        print("Not as expectation.")

cainmagi avatar Mar 27 '24 22:03 cainmagi

Would love this merged!

coessiane avatar Jun 18 '24 09:06 coessiane