strawberry-sqlalchemy
Relationship fails when attribute name != column name
My model has a `Survey` class with an `owner_id` attribute that maps to a differently named column (`user_id`) for historical reasons:
class User(Base):
    """ORM model for the ``user`` table.

    The Python attribute ``user_id`` is stored in the SQL column named ``id``,
    so attribute name and column name differ.
    """
    __tablename__ = "user"
    # Attribute ``user_id`` <-> column ``id`` (primary key).
    user_id: Mapped[int] = mapped_column("id", primary_key=True)
    username: Mapped[str]
class Survey(Base):
    """ORM model for the ``survey`` table.

    Attribute names differ from column names: ``survey_id`` is stored in column
    ``id`` and ``owner_id`` in column ``user_id``. Code that reads the foreign
    key off a Survey instance must therefore use ``owner_id``, not ``user_id``.
    """
    __tablename__ = "survey"
    # Attribute ``survey_id`` <-> column ``id`` (primary key).
    survey_id: Mapped[int] = mapped_column("id", primary_key=True)
    name: Mapped[str]
    # Attribute ``owner_id`` <-> column ``user_id`` (name kept for historical
    # reasons); references ``user.id``.
    owner_id: Mapped[int] = mapped_column("user_id", ForeignKey("user.id"))
    # Many-to-one link to User; ``backref`` also adds a ``surveys`` attribute on User.
    owner: Mapped[User] = relationship("User", backref="surveys", lazy=True)
import models
@strawberry_sqlalchemy_mapper.type(models.User)
class User:
    # GraphQL fields are derived from models.User by the mapper decorator;
    # the class body intentionally declares nothing extra.
    ...
@strawberry_sqlalchemy_mapper.type(models.Survey)
class Survey:
    # GraphQL fields (including the ``owner`` relationship) are derived from
    # models.Survey by the mapper decorator; nothing extra is declared here.
    ...
@strawberry.type
class Query:
    # Root GraphQL query type: looks up a single survey by its primary key.
    @strawberry.field
    def survey(self, info: Info, survey_id: int) -> typing.Optional[Survey]:
        # The request context carries the database handle under "db"
        # (presumably a SQLAlchemy Session, given the execute/scalars calls).
        session = info.context["db"]
        stmt = select(models.Survey).where(models.Survey.survey_id == survey_id)
        result = session.execute(stmt)
        # First matching row, or None when no survey has that id.
        return result.scalars().first()
In `relationship_resolver_for`, the code calls `getattr(self, sql_column_name)` instead of `getattr(self, python_attr_name)`, so the following query fails:
# Selecting `owner` goes through the generated relationship resolver, which
# reads the foreign-key value off the Survey row and fails when the mapped
# attribute name (owner_id) differs from the column name (user_id).
query MyQuery {
  survey(surveyId: 1) {
    name
    owner {
      username
    }
  }
}
File ".../strawberry_sqlalchemy_mapper/mapper.py", line 409, in <listcomp>
getattr(self, local.key)
AttributeError: 'Survey' object has no attribute 'user_id'
@TimDumol, we ran into this issue ourselves and see errors from two places where the relationship value is read from the row using the `sql_column_name` rather than the `python_attr_name`:
StrawberrySQLAlchemyLoader#loader_for
def group_by_remote_key(row: Any) -> Tuple:
    # Excerpt of the original (pre-fix) upstream code, quoted to show the bug.
    # Groups a loaded related row by the values of the relationship's remote
    # columns. ``remote.key`` is the *column* key (e.g. "id"), which need not
    # equal the mapped attribute name (e.g. ``user_id``), so getattr can raise
    # AttributeError.
    return tuple(
        [
            getattr(row, remote.key)  # <- uses sql_column_name
            for _, remote in relationship.local_remote_pairs
        ]
    )
StrawberrySQLAlchemyMapper#relationship_resolver_for
# Excerpt of the original (pre-fix) upstream code, quoted to show the bug:
# ``local.key`` is the column key (e.g. "user_id"), not the mapped attribute
# name (e.g. ``owner_id``), which yields the AttributeError shown above.
relationship_key = tuple(
    [
        getattr(self, local.key)  # <- uses sql_column_name
        for local, _ in relationship.local_remote_pairs
    ]
)
We have a temporary workaround: we override the respective methods and build a column-name-to-attribute-name map from the relevant relationship mapper. However, we're keen to have a central fix for this.
I'm happy to contribute a fix if we can agree an approach.
Example fix:
def build_get_col(mapper):
    """
    Build an accessor that reads a column's value from a mapped instance.

    Relationship metadata (``local_remote_pairs``) refers to ``Column`` objects,
    whose ``key`` may differ from the Python attribute they are mapped to. For
    example, ``owner_id: Mapped[int] = mapped_column("user_id", ...)`` has
    column key ``"user_id"`` but attribute ``owner_id``. The returned
    ``get_col(row, col)`` translates a column key into the attribute name
    before calling ``getattr``.

    :param mapper: the SQLAlchemy ``Mapper`` whose instances will be read.
    :returns: ``get_col(row, col)``. Column keys with no mapped attribute are
        looked up on the row unchanged, matching the previous behaviour.
    """
    # ``Mapper.column_attrs`` yields every ColumnProperty. A single property can
    # map several columns (e.g. joined-table inheritance), so index all of them.
    # Key by ``Column.key`` because that is what callers pass (``local.key`` /
    # ``remote.key``).
    col_to_attr = {
        column.key: prop.key
        for prop in mapper.column_attrs
        for column in prop.columns
    }

    def get_col(row: Any, col: str):
        return getattr(row, col_to_attr.get(col, col))

    return get_col
## StrawberrySQLAlchemyLoader
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
    """
    Retrieve or create a DataLoader for the given relationship.

    Loaders are cached in ``self._loaders``, one per relationship. Each load key
    is a tuple of local-column values, ordered like
    ``relationship.local_remote_pairs``. For every key the loader returns a
    list of related rows when ``relationship.uselist`` is true; otherwise it
    returns the first related row or None.
    """
    try:
        return self._loaders[relationship]
    except KeyError:
        related_model = relationship.entity.entity
        # Rows must be read by mapped *attribute* name, which can differ from the
        # column key (e.g. attribute ``user_id`` stored in column ``id``).
        # ``relationship.mapper`` is the target Mapper; the mapped class itself
        # does not expose a ``mapper`` attribute.
        get_col = build_get_col(relationship.mapper)

        async def load_fn(keys: List[Tuple]) -> List[Any]:
            query = select(related_model).filter(
                tuple_(
                    *[remote for _, remote in relationship.local_remote_pairs]
                ).in_(keys)
            )
            if relationship.order_by:
                query = query.order_by(*relationship.order_by)
            rows = self.bind.scalars(query).all()

            def group_by_remote_key(row: Any) -> Tuple:
                # Same column order as the tuple_ filter above, so the result
                # lines up with the incoming ``keys``.
                return tuple(
                    [
                        get_col(row, remote.key)
                        for _, remote in relationship.local_remote_pairs
                    ]
                )

            grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
            for row in rows:
                grouped_keys[group_by_remote_key(row)].append(row)
            if relationship.uselist:
                return [grouped_keys[key] for key in keys]
            else:
                return [
                    grouped_keys[key][0] if grouped_keys[key] else None
                    for key in keys
                ]

        self._loaders[relationship] = DataLoader(load_fn=load_fn)
        return self._loaders[relationship]
## StrawberrySQLAlchemyMapper
def relationship_resolver_for(
    self, relationship: RelationshipProperty
) -> Callable[..., Awaitable[Any]]:
    """
    Build an async GraphQL resolver for ``relationship``.

    If the relationship is already loaded on the instance, its value is
    returned as-is. Otherwise the related objects are fetched through the
    request's ``sqlalchemy_loader`` DataLoader, so lookups from sibling
    instances are batched (avoiding the n+1 query problem).
    """
    # Translates the parent mapper's column keys into mapped attribute names.
    read_column = build_get_col(relationship.parent)

    async def resolve(self, info: Info):
        # ``self`` here is the ORM instance being resolved, not the mapper.
        state = cast(InstanceState, inspect(self))
        if relationship.key not in state.unloaded:
            return getattr(self, relationship.key)

        key_values = tuple(
            read_column(self, local_column.key)
            for local_column, _ in relationship.local_remote_pairs
        )
        # Any NULL local key means there is nothing to look up.
        if any(value is None for value in key_values):
            return [] if relationship.uselist else None

        context = info.context
        if isinstance(context, dict):
            loader = context["sqlalchemy_loader"]
        else:
            loader = context.sqlalchemy_loader
        return await loader.loader_for(relationship).load(key_values)

    setattr(resolve, _IS_GENERATED_RESOLVER_KEY, True)
    return resolve
Hi @cpsnowden — sorry, I totally forgot I had assigned myself to this. Your proposed fix looks good to me. Feel free to open a PR!
Thanks @TimDumol — I see that @gravy-jones-locker is addressing this in https://github.com/strawberry-graphql/strawberry-sqlalchemy/pull/25