
[ENG-4003] Support DeclarativeBase in core

Open masenf opened this issue 1 year ago • 9 comments

Add serializers and Var definitions for SQLAlchemy DeclarativeBase models.
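
For context, a minimal illustrative sketch of such a serializer registration (not the planned implementation; it only emits column values that are already loaded, using the existing @serializer hook):

# Illustrative sketch only: register a serializer for any DeclarativeBase subclass
# that emits already-loaded column values and never triggers lazy loads.
from sqlalchemy import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.base import NO_VALUE

from reflex.utils.serializers import serializer


@serializer
def serialize_declarative_base(model: DeclarativeBase) -> dict:
    state = inspect(model)
    return {
        key: state.attrs[key].loaded_value
        for key in state.mapper.columns.keys()
        if state.attrs[key].loaded_value is not NO_VALUE
    }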

masenf avatar Nov 07 '24 00:11 masenf

@masenf I currently have something like this; however, it needs some love.
It tries to prevent recursion errors, which works, but IIRC it still has a bug.

# Imports assumed for this snippet (omitted in the original comment; paths per
# SQLAlchemy 2.x and current reflex).
from __future__ import annotations

import logging
from collections.abc import Iterable
from inspect import isgenerator, isgeneratorfunction
from typing import Any

from sqlalchemy import inspect
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import DeclarativeBase, Mapper
from sqlalchemy.orm.base import NO_VALUE, LoaderCallableStatus
from sqlalchemy.orm.state import InstanceState

from reflex.state import MutableProxy
from reflex.utils.serializers import has_serializer, serialize

log = logging.getLogger(__name__)


def is_generator(value: Any) -> bool:
    """Return whether `value` is a generator or generator function."""
    return isgeneratorfunction(value) or isgenerator(value)


def is_iterable_but_not_string(value: Any) -> bool:
    """Return whether `value` is an iterable but not string/bytes."""
    return (
        hasattr(value, "__iter__") and not isinstance(value, str | bytes)
    ) or is_generator(value)


class ModelSerializer:
    """Recursively convert SQLAlchemy models (and nested containers) into plain dicts."""

    exclude_relationships: bool
    lazyload: bool
    ignore: list[str]
    fields: list[str]

    def __init__(
        self,
        *,
        exclude_relationships: bool = False,
        lazyload: bool = False,
        ignore: list[str] | None = None,
        fields: list[str] | None = None,
    ) -> None:
        self.exclude_relationships = exclude_relationships
        self.lazyload = lazyload
        self.ignore = ignore or []
        self.fields = fields or []

    def to_dict(self, model: DataclassMixin) -> dict[str, Any]:
        ctx: dict[str, Any] = {"seen": set()}
        return self.from_value(ctx, model)

    def from_value(self, ctx: dict, value: Any) -> Any:
        """Serialize a single value, dispatching on its type."""
        if value is None:
            return None

        if isinstance(value, DeclarativeBase):
            if isinstance(value, MutableProxy):
                value = value.__wrapped__
            return self.from_model(ctx, value)

        has_custom_serializer = has_serializer(type(value))
        if isinstance(value, dict):
            if has_custom_serializer:
                return serialize(value)
            return self.from_dict(ctx, value)
        if is_iterable_but_not_string(value):
            return self.from_iterable(ctx, value)
        if has_custom_serializer and type(value) not in (str, int, bool):
            return serialize(value)
        return value

    def from_model(self, ctx: dict, value: Any) -> dict[str, Any]:
        """Serialize a mapped model instance; unloaded attributes are skipped unless lazyload is set."""
        ctx.setdefault("seen", set())
        ctx.setdefault("seen_assoc", set())
        ctx["seen"].add(value)

        state: InstanceState = inspect(value)
        mapper: Mapper = inspect(type(value))

        fields = [*mapper.columns.keys(), *self.fields]
        if not self.exclude_relationships:
            for relationship in mapper.relationships.keys():
                if relationship in self.ignore:
                    continue
                if state.attrs[relationship].loaded_value is NO_VALUE:
                    continue
                fields.append(relationship)

        data = {}

        # association proxies are not included in the relationships
        # so we need to add them manually
        association_proxies = []
        for key, descriptor in mapper.all_orm_descriptors.items():
            if descriptor in ctx["seen_assoc"]:
                continue
            if not isinstance(descriptor, AssociationProxy):
                continue
            if (
                target_collection := state.attrs[
                    descriptor.target_collection
                ].loaded_value
            ) is NO_VALUE:
                log.debug(f"{target_collection=} for {key=} is NO_VALUE")
                continue
            # TODO: improve the check for whether value_attr is loaded on target_collection.
            # Maybe just access the attribute and catch the exception? That should be faster
            # and more reliable than this rough approach.
            if not target_collection:
                association_proxies.append(key)
                continue
            first_target = next(iter(target_collection))
            first_target_state = inspect(first_target)
            if first_target_state.attrs[descriptor.value_attr].loaded_value is NO_VALUE:
                log.debug(f"value_attr is NO_VALUE, not loading {key=}")
                continue
            association_proxies.append(key)
            ctx["seen_assoc"].add(descriptor)

        fields.extend(association_proxies)

        for key in fields:
            if key in self.ignore:
                log.debug(f"ignoring {key=}")
                continue
            if key in self.fields or key in association_proxies:
                log.debug(
                    f"adding {key=} as additional field to {type(value).__name__}"
                )
                loaded_value = getattr(value, key, None)
            else:
                loaded_value = state.attrs[key].loaded_value

            if (
                loaded_value is NO_VALUE
                or loaded_value is LoaderCallableStatus.NO_VALUE
            ) and self.lazyload:
                loaded_value = state.attrs[key].value

            if (
                loaded_value is NO_VALUE
                or loaded_value is LoaderCallableStatus.NO_VALUE
            ) or (
                isinstance(loaded_value, DeclarativeBase)
                and loaded_value in ctx["seen"]
            ):
                continue

            data[key] = self.from_value(ctx, loaded_value)

            # avoid null access errors in frontend to allow auto model creation in dynamic form setters
            #  if data[key] is None:
            #      data[key] = {}

        return data

    def from_dict(self, ctx: dict, value: dict) -> dict:
        """Serialize each value of a dict, keeping the keys."""
        return {k: self.from_value(ctx, v) for k, v in value.items()}

    def from_iterable(self, ctx: dict, value: Iterable) -> list:
        """Serialize each element of an iterable into a list."""
        return [self.from_value(ctx, v) for v in value]

benedikt-bartscher avatar Nov 07 '24 00:11 benedikt-bartscher

Sorry, I forgot half of the code

# Additional imports assumed for this half of the code.
import json
from collections.abc import Generator

from pydantic.dataclasses import dataclass as pydantic_dataclass
from sqlalchemy.orm import MappedAsDataclass

from reflex.utils.serializers import serializer


class DataclassMixin(
    MappedAsDataclass,
    dataclass_callable=pydantic_dataclass,
    kw_only=True,
    eq=False,
    repr=False,  # disable because it's unreadable and slow with recursive/self-referencing models
):
    serialize_fields: list[str] | None = None
    serialize_ignore: list[str] | None = None

    def __repr__(self) -> str:
        cls = type(self)
        cls_name = cls.__name__
        if hasattr(self, "id"):
            return f"{cls_name} {self.id}"  # pyright: ignore[reportAttributeAccessIssue]
        return f"{cls_name} instance at {id(self)}"

    # needed to prevent pydantic from breaking sqlalchemy
    def __post_init__(self) -> None:
        pass

    def __iter__(self) -> Generator[tuple[str, Any], None, None]:
        """Iterate over the items of ``self.to_dict()``."""
        yield from self.to_dict().items()

    def json(self) -> str:
        """Return a JSON string representation of the object."""
        return json.dumps(self.to_dict(), default=serialize)

    def to_dict(
        self,
        *,
        exclude_relationships: bool = False,
        lazyload: bool = False,
    ) -> dict[str, Any]:
        try:
            serializer = ModelSerializer(
                exclude_relationships=exclude_relationships,
                lazyload=lazyload,
                ignore=self.serialize_ignore,
                fields=self.serialize_fields,
            )
            d = serializer.to_dict(self)
            #  d = self._serialize_fields(d)
            log.debug(f"DataclassMixin.to_dict {self.serialize_fields=} {d=}")
        except Exception as e:
            log.error(f"Error in to_dict: {e}")
            raise
        return d


@serializer
def serialize_dataclass_mixin(dataclass_mixin: DataclassMixin) -> dict:
    """Serialize any model using the mixin via its to_dict()."""
    return dataclass_mixin.to_dict()
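
For reference, a hypothetical model using the mixin could look like this (an untested sketch; table, column, and instance names are made up):

# Hypothetical usage of DataclassMixin (illustrative names only).
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(DataclassMixin, Base):
    __tablename__ = "user"

    id: Mapped[int] = mapped_column(primary_key=True, init=False)
    name: Mapped[str] = mapped_column(default="")


user = User(name="Ada")
print(user.to_dict())  # e.g. {"name": "Ada"} before the database assigns an id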

benedikt-bartscher avatar Nov 07 '24 00:11 benedikt-bartscher

I currently have something like this

Very interesting, thanks for sharing.

I think at least some of the recursive serialize calls may not be necessary with the new var system; I'll see if I can incorporate some of this. I also have some test cases that I'll be adding, but I suspect you might have some additional interesting cases that would be good to include as well.

masenf avatar Nov 07 '24 00:11 masenf

Very interesting, thanks for sharing.

You are welcome, thanks for implementing this! I think my approach could benefit from yours by splitting it into multiple @serializers.
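
For example (a rough, untested sketch): in addition to serialize_dataclass_mixin above, a second registered serializer could cover plain DeclarativeBase models that don't use the mixin, reusing the same ModelSerializer:

# Rough sketch: a separate @serializer entry point for models without the mixin.
from sqlalchemy.orm import DeclarativeBase

from reflex.utils.serializers import serializer


@serializer
def serialize_plain_model(model: DeclarativeBase) -> dict:
    return ModelSerializer().to_dict(model)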

I think at least some of the recursive serialize calls may not be necessary with the new var system; I'll see if I can incorporate some of this.

If you have recursion in your data structure, the var system won't save you. Reflex currently does not track which attributes are used and always serializes the "whole" model. This will cause recursion errors if your models have a cycle.
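
For example, an ordinary bidirectional relationship already forms such a cycle (user.posts[0].author leads back to user); this is standard SQLAlchemy, shown only to make the cycle concrete:

# A plain bidirectional relationship: naive recursive serialization of User walks
# posts -> author -> posts -> ... forever.
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"

    id: Mapped[int] = mapped_column(primary_key=True)
    posts: Mapped[list["Post"]] = relationship(back_populates="author")


class Post(Base):
    __tablename__ = "post"

    id: Mapped[int] = mapped_column(primary_key=True)
    author_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
    author: Mapped["User"] = relationship(back_populates="posts")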

benedikt-bartscher avatar Nov 07 '24 00:11 benedikt-bartscher

This will cause recursion errors if your models have a cycle.

How do we expect it to serialize if it has recursion, though? :thinking:

adhami3310 avatar Nov 07 '24 01:11 adhami3310

How do we expect it to serialize if it has recursion, though? :thinking:

Currently it just stops at the first level of repetition. A cooler approach would allow referential serialization, as I did in a PoC with pickle (https://github.com/reflex-dev/reflex/pull/3340).
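
To illustrate the idea with a toy sketch on plain dicts (not the pickle-based PoC from the PR): objects that have already been seen get a {"$ref": ...} marker instead of being re-serialized or silently dropped.

# Toy referential serialization: repeated/cyclic nodes become reference markers.
from typing import Any


def to_refs(obj: Any, seen: dict[int, str] | None = None) -> Any:
    seen = {} if seen is None else seen
    if isinstance(obj, dict):
        if id(obj) in seen:
            return {"$ref": seen[id(obj)]}
        seen[id(obj)] = ref = f"obj{len(seen)}"
        return {"$id": ref, **{k: to_refs(v, seen) for k, v in obj.items()}}
    if isinstance(obj, list):
        return [to_refs(v, seen) for v in obj]
    return obj


a: dict[str, Any] = {"name": "a"}
b: dict[str, Any] = {"name": "b", "friend": a}
a["friend"] = b  # cycle
print(to_refs(a))
# {'$id': 'obj0', 'name': 'a', 'friend': {'$id': 'obj1', 'name': 'b', 'friend': {'$ref': 'obj0'}}}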

benedikt-bartscher avatar Nov 07 '24 09:11 benedikt-bartscher

repro-sqlalchemy-association-proxy.tar.gz

Uploading for testing.

masenf avatar Nov 07 '24 17:11 masenf

Uploading for testing.

I just looked at it; are there any specific issues you want to showcase?

Btw, we have the following for lambda creators:

# Assumed imports; `dynamic_creators` is an (otherwise empty) module in the project
# that acts as a registry for the renamed lambdas ("myapp" is a placeholder).
import types
from typing import Callable

from myapp import dynamic_creators


def globalize_creator(lambda_func: Callable) -> Callable:
    # Only actual lambda functions are supported.
    if not (
        isinstance(lambda_func, types.LambdaType) and lambda_func.__name__ == "<lambda>"
    ):
        raise ValueError("globalize_creator expects a lambda function")

    # Create a unique name for the lambda function
    proposed_name = f"lambda_creator_{len(dir(dynamic_creators))}"
    while hasattr(dynamic_creators, proposed_name):
        proposed_name += "_next"

    # Add the lambda function to the dynamic_creators module
    setattr(dynamic_creators, proposed_name, lambda_func)

    # Set the name, qualname and module of the lambda function
    lambda_func.__name__ = proposed_name
    lambda_func.__qualname__ = proposed_name
    lambda_func.__module__ = dynamic_creators.__name__

    return lambda_func

It's basically a copy of your workaround in reflex, but for lambdas.
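
A hypothetical usage, just to show the effect (module and attribute names depend on the project layout):

# After globalizing, the lambda has a stable, importable identity instead of "<lambda>",
# so it can be referenced (or e.g. pickled) by module path and name.
creator = globalize_creator(lambda name: {"name": name})
print(creator.__module__, creator.__qualname__)
# e.g.: myapp.dynamic_creators lambda_creator_<n>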

I suggest trying MappedAsDataclass and using back_populates with explicit declaration of both sides instead of secondary lambdas. I would also move the @ModelRegistry.register to Base instead of User.
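
A rough sketch of the relationship part of that suggestion (illustrative names; the actual models in the repro may differ, and MappedAsDataclass would be layered in via a base or mixin as shown earlier): declare the association table once and pass it by object to both sides with back_populates, instead of secondary lambdas.

# Many-to-many with an explicit association table and both sides declared.
from sqlalchemy import Column, ForeignKey, Table
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


user_group = Table(
    "user_group",
    Base.metadata,
    Column("user_id", ForeignKey("user.id"), primary_key=True),
    Column("group_id", ForeignKey("group.id"), primary_key=True),
)


class User(Base):
    __tablename__ = "user"

    id: Mapped[int] = mapped_column(primary_key=True)
    groups: Mapped[list["Group"]] = relationship(
        secondary=user_group, back_populates="users"
    )


class Group(Base):
    __tablename__ = "group"

    id: Mapped[int] = mapped_column(primary_key=True)
    users: Mapped[list["User"]] = relationship(
        secondary=user_group, back_populates="groups"
    )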

benedikt-bartscher avatar Nov 07 '24 19:11 benedikt-bartscher