beanie
beanie copied to clipboard
[QUESTION] chained call behavior when use find+limit+skip+aggregate
I tried to chain calls to the methods find, sort, limit, and skip, finishing with aggregate, but got an unexpected result.
a. insert 100 records into the db; b. find data using {} as the condition; c. use limit(10) to cap the number of returned records (I need 10); d. use aggregate to query across collections
expected: return 10 records
actually: return 100 records
my env:
Package Version
----------------- ---------
beanie 1.11.6
motor 3.0.0
pydantic 1.9.1
pymongo 4.2.0
pytest 7.1.2
pytest-asyncio 0.19.0
pytest-cov 3.0.0
Q1: can I chain the method calls like DocType.find().sort().limit().skip().aggregate()?
Q2: what is the correct result of the test case?
reproduce steps: the code is complete, it should run with pytest
- copy the code below
- install the libs
- run
pytest --setup-show --log-cli-level=INFO
codes
- test code dir
.
├── conftest.py
├── __init__.py
├── models.py
├── __pycache__
│ ├── conftest.cpython-38-pytest-7.1.2.pyc
│ ├── __init__.cpython-38.pyc
│ ├── models.cpython-38.pyc
│ └── test_chain_call.cpython-38-pytest-7.1.2.pyc
├── pytest.ini
├── requirements.txt
└── test_chain_call.py
- pytest.ini
[pytest]
asyncio_mode=auto
- requirements.txt
pytest
pytest-cov
pytest-asyncio
beanie
pydantic
- models.py
from beanie import PydanticObjectId, Document
from pydantic import BaseModel, Field
from pymongo import IndexModel, DESCENDING
class Cup(Document):
    """Beanie document stored in the "cup" collection."""

    name: str  # unique per the index below
    width: float
    height: float

    class Collection:
        # Collection name and index definitions used by init_beanie.
        name = "cup"
        indexes = [
            IndexModel([("name", DESCENDING)], unique=True),
        ]
class Water(Document):
    """Beanie document stored in the "water" collection."""

    name: str  # unique per the index below

    class Collection:
        # Collection name and index definitions used by init_beanie.
        name = "water"
        indexes = [
            IndexModel([("name", DESCENDING)], unique=True),
        ]
class WaterInCup(Document):
    """Link document joining a Cup and a Water by their ObjectIds."""

    cup: PydanticObjectId    # _id of a Cup document
    water: PydanticObjectId  # _id of a Water document

    class Collection:
        # The (cup, water) pair is unique: one link per combination.
        name = "water_in_cup"
        indexes = [
            IndexModel([("cup", DESCENDING), ("water", DESCENDING)], unique=True),
        ]
class CupInfo(BaseModel):
    """Projection of a Cup as embedded in the aggregation output."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from Mongo's _id
    name: str
    width: float
    height: float
class WaterInfo(BaseModel):
    """Projection of a Water as embedded in the aggregation output."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from Mongo's _id
    name: str
class OutWaterInCup(BaseModel):
    """Aggregation result: a WaterInCup link plus the joined cup/water docs."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from Mongo's _id
    cup: PydanticObjectId
    water: PydanticObjectId
    cup_info: CupInfo      # filled by the $lookup/$unwind "cup_info" stages
    water_info: WaterInfo  # filled by the $lookup/$unwind "water_info" stages
- conftest.py
import pytest
import motor.motor_asyncio
from pydantic import BaseSettings
from beanie import init_beanie
import logging
from .models import Cup, Water, WaterInCup
LOGGER = logging.getLogger(__name__)
class Settings(BaseSettings):
    """Mongo connection settings; values overridable via environment variables."""

    uri: str = "mongodb://127.0.0.1:27017/pytest"
    db_name: str = "test_db"


settings = Settings()


@pytest.fixture
def motor_client():
    """Create a fresh Motor client per test."""
    LOGGER.info(">>>>>>>>>>init_client")
    return motor.motor_asyncio.AsyncIOMotorClient(settings.uri)


@pytest.fixture
def db(motor_client):
    """Return the handle to the test database."""
    LOGGER.info(">>>>>>>>>>init_db")
    db = motor_client[settings.db_name]
    return db


@pytest.fixture(autouse=True)
async def lifespan(motor_client, db):
    """Initialize Beanie models before each test; drop the test DB afterwards."""
    # !pre:
    LOGGER.info(">>>>>>>>>>init entity")
    await init_beanie(database=db, document_models=[Cup, Water, WaterInCup])  # type: ignore
    LOGGER.info(">>>>>>>>>>init entity completed")
    yield None
    # !after: everything written by the test is discarded here.
    LOGGER.info(f">>>>>>>>>>clean database:{settings.db_name}")
    await motor_client.drop_database(settings.db_name)
    LOGGER.info(f">>>>>>>>>>clean finished.")
- test_chain_call.py
import pytest
from .models import Cup, Water, WaterInCup, OutWaterInCup
import logging
LOGGER = logging.getLogger(__name__)
class TestPaginate:
    """Reproduces the bug: find().sort().limit().skip().aggregate() drops
    the chained sort/skip/limit state, so aggregate sees all documents."""

    @pytest.fixture(autouse=True)
    async def prepare_data(self):
        # Seed 100 cups, 100 waters, and one WaterInCup link per pair.
        for i in range(100):
            cup = await Cup(name=f"Cup_{i}", width=40, height=100).save()
            water = await Water(name=f"water_{i}").save()
            await WaterInCup(cup=cup.id, water=water.id).save()

    async def test_aggregate_query(self):
        query = {}
        # $lookup/$unwind pipeline joining the cup and water documents onto
        # each WaterInCup link (inner-join semantics: unmatched docs dropped).
        pipeline = [
            {
                "$lookup": {
                    "from": "cup",
                    "let": {"cid": "$cup"},
                    "pipeline": [
                        {"$match": {"$expr": {"$eq": ["$_id", "$$cid"]}}},
                    ],
                    "as": "cup_info",
                }
            },
            {
                "$unwind": {
                    "path": "$cup_info",
                    "preserveNullAndEmptyArrays": False,
                }
            },
            {
                "$lookup": {
                    "from": "water",
                    "let": {"wid": "$water"},
                    "pipeline": [
                        {"$match": {"$expr": {"$eq": ["$_id", "$$wid"]}}},
                    ],
                    "as": "water_info",
                }
            },
            {
                "$unwind": {
                    "path": "$water_info",
                    "preserveNullAndEmptyArrays": False,
                }
            },
        ]
        from beanie.odm.enums import SortDirection

        res = (
            await WaterInCup.find(query)
            .sort([("_id", SortDirection.DESCENDING)])
            .limit(10)
            .skip(0)
            .aggregate(pipeline, projection_model=OutWaterInCup)
            .to_list()
        )
        LOGGER.critical(
            ">>>>>> the expected result length is 10 but got 100 instead. <<<<<<"
        )
        # The two asserts are intentionally contradictory: the first records
        # the buggy behavior observed on beanie 1.11.6 (limit ignored), the
        # second states the expected behavior and currently fails.
        assert len(res) == 100  # PASSED
        assert len(res) == 10  # FAILED
After debugging I found that the find query's limit, skip, and sort_expressions are not passed on to the aggregation. See https://github.com/roman-right/beanie/blob/main/beanie/odm/queries/find.py#L527-L556
def aggregate(
    self,
    aggregation_pipeline: List[Any],
    projection_model: Optional[Type[FindQueryProjectionType]] = None,
    session: Optional[ClientSession] = None,
    ignore_cache: bool = False,
    **pymongo_kwargs,
) -> Union[
    AggregationQuery[Dict[str, Any]],
    AggregationQuery[FindQueryProjectionType],
]:
    """
    Provide search criteria to the [AggregationQuery](https://roman-right.github.io/beanie/api/queries/#aggregationquery)

    :param aggregation_pipeline: list - aggregation pipeline. MongoDB doc:
        <https://docs.mongodb.com/manual/core/aggregation-pipeline/>
    :param projection_model: Type[BaseModel] - Projection Model
    :param session: Optional[ClientSession] - PyMongo session
    :param ignore_cache: bool
    :return: [AggregationQuery](https://roman-right.github.io/beanie/api/queries/#aggregationquery)
    """
    self.set_session(session=session)
    return self.AggregationQueryType(
        aggregation_pipeline=aggregation_pipeline,
        document_model=self.document_model,
        projection_model=projection_model,
        find_query=self.get_filter_query(),
        ignore_cache=ignore_cache,
        # BUG FIX: forward the chained sort/skip/limit state so that
        # find().sort().limit().skip().aggregate() prepends the matching
        # $sort/$skip/$limit stages instead of silently dropping them.
        # NOTE(review): AggregationQueryType.__init__ must accept and store
        # these kwargs, and get_aggregation_pipeline must emit the stages
        # — confirm against beanie's AggregationQuery implementation.
        sort_expressions=self.sort_expressions,
        skip_number=self.skip_number,
        limit_number=self.limit_number,
        **pymongo_kwargs,
    ).set_session(session=self.session)
Perhaps we can pass them along like this:
return self.AggregationQueryType(
aggregation_pipeline=aggregation_pipeline,
document_model=self.document_model,
projection_model=projection_model,
find_query=self.get_filter_query(),
ignore_cache=ignore_cache,
limit=self.limit_number, # pass the limit
skip=self.skip_number, # pass the skip
sort_expressions=self.sort_expressions, # pass the sort exp
**pymongo_kwargs,
).set_session(session=self.session)
and then use the params in aggregation.py's get_aggregation_pipeline function
def get_aggregation_pipeline(
    self,
) -> List[Mapping[str, Any]]:
    """Assemble the full pipeline: optional $match/$sort/$skip/$limit stages
    derived from the find-chain state, then the user pipeline, then an
    optional $project stage from the projection model."""
    head: List[Mapping[str, Any]] = []
    if self.find_query:
        head.append({"$match": self.find_query})

    # Translate the (field, direction) pairs into a $sort document; emit the
    # stage only when there is at least one sort expression.
    sort_doc = {field: direction for field, direction in self.sort_expressions}
    if sort_doc:
        head.append({"$sort": sort_doc})

    # A skip/limit of 0 (or None) means "not set" — no stage is emitted.
    if self.skip_number:
        head.append({"$skip": self.skip_number})
    if self.limit_number:
        head.append({"$limit": self.limit_number})

    tail: List[Mapping[str, Any]] = []
    if self.projection_model:
        projection = get_projection(self.projection_model)
        if projection is not None:
            tail = [{"$project": projection}]

    return head + self.aggregation_pipeline + tail
@roman-right
Hey! Sorry for the delay. It looks like a bug. I'll pick it up soon. Thank you
This issue is stale because it has been open 30 days with no activity.
This issue was closed because it has been stalled for 14 days with no activity.