starlette-admin

Enhancement: docs could illustrate multi-tenancy (and is this the right way to do it?)

Status: Open. sglebs opened this issue 1 year ago • 6 comments

https://jowilf.github.io/starlette-admin/api/contrib/sqlalchemy/modelview/#starlette_admin.contrib.sqla.view.ModelView.get_list_query is an interesting feature. Right away I thought: great, I can have Admin show everything to any SUPERUSER who is logged in, but only show data for the company I am registered with if I have the ADMINISTRATOR role. I would implement this by automatically filtering out, say, all Tasks where Task.tenant_id != loggedUser.tenant_id. Something like this:

    def get_list_query(self):
        if <currentUser>.role.value == UserRole.SUPERUSER:
            return super().get_list_query()
        else:
            return super().get_list_query().where(User.company_id == <currentUser>.company_id)

    def get_count_query(self):
        if <currentUser>.role.value == UserRole.SUPERUSER:
            return super().get_count_query()
        else:
            return super().get_count_query().where(User.company_id == <currentUser>.company_id)
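To see why both hooks need the same filter, here is a minimal sketch in plain SQL using the standard-library `sqlite3` module rather than starlette-admin or SQLAlchemy. The `tasks` table, the role strings and the `tenant_predicate`, `list_tasks` and `count_tasks` helpers are all hypothetical. Superusers get no tenant restriction and everyone else gets `company_id = ?`. Because the list query and the count query share one predicate, the page contents and the reported total stay consistent.

```python
import sqlite3
from typing import Tuple

# Hypothetical role names, mirroring the UserRole enum in the snippet above.
SUPERUSER = "SUPERUSER"
ADMINISTRATOR = "ADMINISTRATOR"


def tenant_predicate(role: str, company_id: int) -> Tuple[str, tuple]:
    """Return a WHERE fragment and its bound parameters for the current user.

    Superusers see every tenant; everyone else is restricted to their company.
    """
    if role == SUPERUSER:
        return "1 = 1", ()
    return "company_id = ?", (company_id,)


def list_tasks(conn, role, company_id, limit=10, offset=0):
    # The list query applies the tenant predicate...
    where, params = tenant_predicate(role, company_id)
    sql = f"SELECT id, title FROM tasks WHERE {where} ORDER BY id LIMIT ? OFFSET ?"
    return conn.execute(sql, (*params, limit, offset)).fetchall()


def count_tasks(conn, role, company_id):
    # ...and the count query applies exactly the same one.
    where, params = tenant_predicate(role, company_id)
    return conn.execute(f"SELECT COUNT(*) FROM tasks WHERE {where}", params).fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (id INTEGER PRIMARY KEY, title TEXT, company_id INTEGER)")
conn.executemany(
    "INSERT INTO tasks (title, company_id) VALUES (?, ?)",
    [("a", 1), ("b", 1), ("c", 2), ("d", 3)],
)

print(count_tasks(conn, SUPERUSER, 1))      # 4: every tenant
print(count_tasks(conn, ADMINISTRATOR, 1))  # 2: company 1 only
print(list_tasks(conn, ADMINISTRATOR, 1))   # [(1, 'a'), (2, 'b')]
```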

The problem is: how can I get the object/model that represents the currently logged-in user from here? Or at least get to the session (I could cache the role and the company_id there).

Maybe this documentation page is the perfect place to insert such a tip?

If, on the other hand, multi-tenancy filtering can or must be done another way, please let me know.

Thanks for listening!

sglebs, Aug 30 '23 01:08

Interesting: the search hook get_search_query(request, term) does receive a Request as a parameter. Unfortunately, the two methods above do not.

sglebs, Aug 30 '23 10:08

Indeed, adding a request parameter to those methods would solve the issue, but you can also override the find_all and count methods to support this.

jowilf, Aug 30 '23 13:08

How about adding these two methods to the superclass:

    def get_list_query_for_request(self, request):
        # Backwards compatible for people who overrode the original version, which does not take a request
        return self.get_list_query()

    def get_count_query_for_request(self, request):
        # Backwards compatible for people who overrode the original version, which does not take a request
        return self.get_count_query()

If the framework added these two hooks to the top-level superclass and called them internally, it would stay backwards compatible with people who already overrode the existing methods, while allowing people like me to override the new hooks with behavior that depends on the request. Win-win.
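The proposal above is a classic "new hook that delegates to the old hook" pattern. Here is a minimal sketch with stand-in classes: `BaseView`, `LegacyView`, `TenantView` and `FakeRequest` are hypothetical and are not starlette-admin's real `ModelView` or Starlette's `Request`. Queries are modelled as lists of SQL fragments purely for illustration. The framework always calls the request-aware hook, whose default falls back to the legacy hook, so old overrides keep working.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeRequest:
    # Hypothetical stand-in for starlette's Request; only .session is modelled.
    session: Dict[str, Any] = field(default_factory=dict)


class BaseView:
    """Stand-in for the framework superclass (not starlette-admin's real ModelView)."""

    def get_list_query(self) -> List[str]:
        # Legacy hook: builds the base query without any request context.
        return ["SELECT * FROM task"]

    def get_list_query_for_request(self, request: FakeRequest) -> List[str]:
        # New hook: delegates to the legacy hook by default, so subclasses
        # that only override get_list_query() keep their behaviour.
        return self.get_list_query()

    def find_all(self, request: FakeRequest) -> str:
        # The framework itself always calls the request-aware hook.
        return " ".join(self.get_list_query_for_request(request))


class LegacyView(BaseView):
    # Written before the new hook existed; its override is still honoured.
    def get_list_query(self) -> List[str]:
        return super().get_list_query() + ["WHERE archived = 0"]


class TenantView(BaseView):
    # Opts into the new hook to filter by the tenant cached in the session.
    def get_list_query_for_request(self, request: FakeRequest) -> List[str]:
        query = self.get_list_query()
        if request.session["user_role"] == "SUPERADMIN":
            return query
        # String building is for illustration only; real code would use .where().
        return query + [f"WHERE company_id = {request.session['user_company_id']}"]


req = FakeRequest(session={"user_role": "ADMINISTRATOR", "user_company_id": 7})
print(LegacyView().find_all(req))  # SELECT * FROM task WHERE archived = 0
print(TenantView().find_all(req))  # SELECT * FROM task WHERE company_id = 7
```

Neither subclass has to know about the other hook: `LegacyView` never mentions a request, and `TenantView` never touches the legacy override path.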

Just a suggestion.

Meanwhile, I will study the two methods you suggested. The documentation is very sparse, so I have no idea how to use them yet; I will browse the code to see if I can figure things out.

Thanks for listening.

sglebs, Aug 31 '23 01:08

Perhaps even simpler would be to change the signature of the existing methods in a non-breaking way:

    def get_list_query(self, request=None):
        ...

    def get_count_query(self, request=None):
        ...

I can see that when you call it from find_all, you do have the request object available:

    async def find_all(
        self,
        request: Request,
        skip: int = 0,
        limit: int = 100,
        where: Union[Dict[str, Any], str, None] = None,
        order_by: Optional[List[str]] = None,
    ) -> Sequence[Any]:
        session: Union[Session, AsyncSession] = request.state.session
        stmt = self.get_list_query().offset(skip)
...

The patch could be:

    async def find_all(
        self,
        request: Request,
        skip: int = 0,
        limit: int = 100,
        where: Union[Dict[str, Any], str, None] = None,
        order_by: Optional[List[str]] = None,
    ) -> Sequence[Any]:
        session: Union[Session, AsyncSession] = request.state.session
        stmt = self.get_list_query(request=request).offset(skip)
...

Anyway, just trying to find a convenient solution for users.
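One caveat worth checking (this is a reading of plain Python semantics, not something confirmed in the thread): a `request=None` default keeps existing *callers* working, but if the framework starts calling `self.get_list_query(request=request)`, any subclass that already overrode `get_list_query(self)` without the parameter would raise `TypeError`. A framework could guard against that by inspecting the override's signature with the standard `inspect.signature`. In the sketch below, `BaseView`, `OldStyleView`, `NewStyleView` and the `_call_list_query` helper are hypothetical stand-ins, not starlette-admin code, and queries are plain strings for brevity.

```python
import inspect
from typing import Any, Optional


class BaseView:
    def get_list_query(self, request: Optional[Any] = None) -> str:
        return "SELECT * FROM task"

    def _call_list_query(self, request: Any) -> str:
        # Pass the request only if the (possibly overridden) hook accepts it.
        params = inspect.signature(self.get_list_query).parameters
        if "request" in params:
            return self.get_list_query(request=request)
        return self.get_list_query()

    def find_all(self, request: Any) -> str:
        return self._call_list_query(request)


class OldStyleView(BaseView):
    # Overridden before the new parameter existed: no "request" argument.
    def get_list_query(self) -> str:  # type: ignore[override]
        return "SELECT * FROM task WHERE archived = 0"


class NewStyleView(BaseView):
    # Uses the request (a plain dict here) when the framework supplies one.
    def get_list_query(self, request: Optional[Any] = None) -> str:
        base = super().get_list_query(request)
        if request is None:
            return base
        return f"{base} WHERE company_id = {request['company_id']}"


print(OldStyleView().find_all({"company_id": 3}))  # SELECT * FROM task WHERE archived = 0
print(NewStyleView().find_all({"company_id": 3}))  # SELECT * FROM task WHERE company_id = 3
```

Calling `OldStyleView().get_list_query(request={})` directly still fails, which is exactly the breakage the signature check avoids.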

sglebs, Aug 31 '23 01:08

I created a workaround class that replicates both methods mentioned in https://github.com/jowilf/starlette-admin/issues/274#issuecomment-1699237492 (find_all and count) but changes the calls to self.get_list_query_for_request and self.get_count_query_for_request.

The class:


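from typing import Any, Dict, Optional, Sequence, Union

import anyio.to_thread
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from starlette.requests import Request
from starlette_admin._types import RequestAction
from starlette_admin.contrib.sqla import ModelView
from starlette_admin.contrib.sqla.helpers import build_order_clauses, build_query
from starlette_admin.fields import RelationField


# This class only exists because of https://github.com/jowilf/starlette-admin/issues/274#issuecomment-1699237492
# It copies ModelView.find_all() and ModelView.count() but routes them through
# request-aware hooks, which subclasses override to add tenant filtering.
class WorkaroundModelView(ModelView):
    def get_list_query_for_request(self, request: Request):
        # Default: same query as the stock ModelView.
        return self.get_list_query()

    def get_count_query_for_request(self, request: Request):
        # Default: same count query as the stock ModelView.
        return self.get_count_query()

    async def find_all(
        self,
        request: Request,
        skip: int = 0,
        limit: int = 100,
        where: Union[Dict[str, Any], str, None] = None,
        order_by: Optional[List[str]] = None,
    ) -> Sequence[Any]:
        session: Union[Session, AsyncSession] = request.state.session
        # The only change from ModelView.find_all: the base query depends on the request.
        stmt = self.get_list_query_for_request(request).offset(skip)
        if limit > 0:
            stmt = stmt.limit(limit)
        if where is not None:
            if isinstance(where, dict):
                where = build_query(where, self.model)
            else:
                where = await self.build_full_text_search_query(
                    request, where, self.model
                )
            stmt = stmt.where(where)  # type: ignore
        stmt = stmt.order_by(*build_order_clauses(order_by or [], self.model))
        for field in self.get_fields_list(request, RequestAction.LIST):
            if isinstance(field, RelationField):
                stmt = stmt.options(joinedload(getattr(self.model, field.name)))
        if isinstance(session, AsyncSession):
            return (await session.execute(stmt)).scalars().unique().all()
        return (
            (await anyio.to_thread.run_sync(session.execute, stmt))
            .scalars()
            .unique()
            .all()
        )

    async def count(
        self,
        request: Request,
        where: Union[Dict[str, Any], str, None] = None,
    ) -> int:
        session: Union[Session, AsyncSession] = request.state.session
        # The only change from ModelView.count: the base query depends on the request.
        stmt = self.get_count_query_for_request(request)
        if where is not None:
            if isinstance(where, dict):
                where = build_query(where, self.model)
            else:
                where = await self.build_full_text_search_query(
                    request, where, self.model
                )
            stmt = stmt.where(where)  # type: ignore
        if isinstance(session, AsyncSession):
            return (await session.execute(stmt)).scalar_one()
        return (await anyio.to_thread.run_sync(session.execute, stmt)).scalar_one()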

Now, in my views, I override the request-aware hooks like this:

    def get_list_query_for_request(self, request: Request):
        if request.session["user_role"] in [UserRole.SUPERADMIN.name]:
            return self.get_list_query()
        else:
            return self.get_list_query().where(User.company_id == request.session["user_company_id"])

    def get_count_query_for_request(self, request: Request):
        if request.session["user_role"] in [UserRole.SUPERADMIN.name]:
            return self.get_count_query()
        else:
            return self.get_count_query().where(User.company_id == request.session["user_company_id"])

And indeed, rows that do not belong to the user's tenant are filtered out when the user is not a SUPERADMIN.

Unfortunately, there seems to be a problem with pagination: it lists only 5 items but seems to think it is listing 10.

Maybe I need to override something else too?

sglebs, Sep 01 '23 22:09

The inconsistency above may have been caused by the lack of a proper join. I have now changed some of the views to look like this, for example (note the join):

    def get_list_query_for_request(self, request: Request):
        if request.session["user_role"] in [UserRole.SUPERADMIN.name]:
            return self.get_list_query()
        else:
            return self.get_list_query().join(Attachment.task).where(Task.company_id == request.session["user_company_id"])

    def get_count_query_for_request(self, request: Request):
        if request.session["user_role"] in [UserRole.SUPERADMIN.name]:
            return self.get_count_query()
        else:
            return self.get_count_query().join(Attachment.task).where(Task.company_id == request.session["user_company_id"])

It looks like the counters are ok now. Checking some more.
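A plausible explanation of the 5-versus-10 mismatch (an interpretation, not confirmed in the thread): when a `.where()` references a column of a table that is not otherwise in the FROM clause, SQLAlchemy adds that table to FROM with no join condition, which is an implicit cross join (recent SQLAlchemy versions typically warn about such a "cartesian product"). The limit and the count then operate on the multiplied rows, while the ORM's `.unique()` collapses duplicate entities afterwards, and rows from other tenants can also leak in. The sketch below reproduces the effect in plain SQL with the standard-library `sqlite3` module; the `task` and `attachment` tables and their data are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE task (id INTEGER PRIMARY KEY, company_id INTEGER);
    CREATE TABLE attachment (id INTEGER PRIMARY KEY, task_id INTEGER REFERENCES task(id));
    -- company 1 owns tasks 1 and 2, company 2 owns task 3
    INSERT INTO task (id, company_id) VALUES (1, 1), (2, 1), (3, 2);
    -- attachments 1-2 belong to task 1, 3 to task 2, 4-5 to task 3
    INSERT INTO attachment (task_id) VALUES (1), (1), (2), (3), (3);
    """
)

# Filter on task.company_id WITHOUT a join: "task" is simply added to FROM,
# so every attachment is paired with every company-1 task (a cross join).
cross_rows = conn.execute(
    "SELECT attachment.id FROM attachment, task WHERE task.company_id = 1"
).fetchall()

# The same filter WITH an explicit join on the foreign key.
joined_rows = conn.execute(
    "SELECT attachment.id FROM attachment JOIN task ON attachment.task_id = task.id"
    " WHERE task.company_id = 1 ORDER BY attachment.id"
).fetchall()

print(len(cross_rows))                      # 10 rows: 5 attachments x 2 tasks
print(len({row[0] for row in cross_rows}))  # 5 distinct attachments, incl. company 2's
print([row[0] for row in joined_rows])      # [1, 2, 3]
```

With the explicit join, each attachment appears once and only the tenant's rows survive, so the list and the count agree again.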

sglebs, Sep 01 '23 22:09