strawberry icon indicating copy to clipboard operation
strawberry copied to clipboard

Fields defined with `... = field(resolver=...)` are treated like regular dataclass attributes by type checkers

Open lukaspiatkowski opened this issue 6 months ago • 0 comments

Describe the Bug

The code below produces no warnings with mypy (using the Strawberry plugin) or Pyright:

from strawberry import field, type


def resolver() -> int:
    """Resolve to a fixed integer; stands in for a real resolver in this example."""
    answer: int = 42
    return answer


# Reproduction: `field2` is backed by a resolver but annotated as a plain
# `int`, so type checkers treat it as an ordinary instance attribute.
@type
class Foo:
    field1: str
    # Resolver-backed field: it is not an __init__ argument (see the revealed
    # __init__ signature below), and reading it from an instance yields the
    # bound resolver, not an int.
    field2: int = field(resolver=resolver)

    # Decorator form: the decorated method is used as the field's resolver.
    @field
    def field3(self) -> str:
        return "Hello, world!"


from typing import TYPE_CHECKING

# `reveal_type` is a type-checker intrinsic, not a runtime builtin, so calling
# it unguarded raises NameError before the `print` below is reached. Type
# checkers treat `TYPE_CHECKING` as true and still report these types.
if TYPE_CHECKING:
    reveal_type(Foo.__init__)  # def (self: interface.graphql.schema.Foo, *, field1: builtins.str)
    reveal_type(Foo.field1)  # builtins.str
    reveal_type(Foo.field2)  # builtins.int
    reveal_type(Foo.field3)  # strawberry.field.StrawberryField

print(Foo(field1="test").field2)
# At runtime the above prints
# <bound method resolver of Foo(field1='test')>


def bar(foo: Foo) -> int:
    """Read ``field2`` as an ``int`` (demonstrates the missed type error).

    Type checkers accept this because ``field2`` is annotated ``int``, but at
    runtime ``foo.field2`` is the bound resolver method, not an ``int``.
    """
    # This should not work, foo.field2 is a method
    return foo.field2


def baz(foo: Foo, val: int) -> None:
    """Assign an ``int`` to ``field2`` (demonstrates the missed type error).

    Type checkers accept this because ``field2`` is annotated ``int``, even
    though the attribute is backed by a resolver.
    """
    # This should not work, foo.field2 is a method
    foo.field2 = val

So in short: when you define `field2: int = field(resolver=resolver)`, type checkers treat `field2` as an attribute on instances of your class and let you read from or write to it as if it were a plain `int` rather than a resolver. This leads to runtime errors.

System Information

  • Python: 3.11.5
  • Mypy: 1.7.0
  • PyRight: 1.1.331
  • Strawberry: 0.219.2

Additional Context

First of all, I want to thank you for this amazing project. I've spent 5 months rewriting our codebase of ~1000 objects and ~1000 resolvers from Graphene to Strawberry, and the typing improvements we got from it are amazing. There were a few issues along the way, mostly due to our customizations, but this one bothered me because Strawberry didn't help catch it.

The original Graphene code I had was something like this:

def resolve_field2(self: "Foo") -> int:
    """Return ``self.field2`` after the (elided) auth checks.

    ``Foo`` is quoted because the class is defined after this function; an
    unquoted annotation is evaluated at definition time and raises NameError.
    """
    # do some auth checks
    return self.field2

class Foo(graphene.ObjectType):
    field2: int


# Attach the resolver after the class exists. NOTE(review): this relies on
# graphene's `resolve_<field_name>` naming convention — confirm.
Foo.resolve_field2 = resolve_field2

Because I was rewriting many objects at once, I used a codemod, so I didn't catch the bug in this version:

def resolve_field2(self: "Foo") -> int:
    """Return ``self.field2`` after the (elided) auth checks.

    This is the buggy codemod output: on the Strawberry ``Foo`` below,
    ``self.field2`` is the resolver itself, not the stored integer.
    ``Foo`` is quoted because it is defined after this function.
    """
    # do some auth checks
    return self.field2

@strawberry.type
class Foo:
    # The bug this issue is about: annotated as `int`, but the value comes from
    # `resolve_field2`, so `self.field2` is not an int at runtime.
    field2: int = strawberry.field(resolver=resolve_field2)

Running this code results in a runtime error, because `self.field2` is a method, not an integer. Instead, I had to write something like this:

def resolve_field2(self: "Foo") -> int:
    """Return the privately stored value after the (elided) auth checks.

    Reads ``_field2`` rather than ``field2``, because ``field2`` is the
    resolver-backed GraphQL field. ``Foo`` is quoted because it is defined
    after this function.
    """
    # do some auth checks
    return self._field2

@strawberry.type
class Foo:
    # Private storage read by `resolve_field2`; `strawberry.Private` excludes
    # it from the GraphQL schema.
    _field2: strawberry.Private[int]
    # Public GraphQL field, computed by the resolver from `_field2`.
    field2: int = strawberry.field(resolver=resolve_field2)

This got me thinking about how to improve the situation and remove this "footgun" before someone else stumbles upon it. Below is the best I could come up with without modifying Strawberry:

import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import (
    Any,
    Generic,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    final,
    overload,
)

import strawberry
from strawberry import type
from strawberry.extensions.field_extension import FieldExtension
from strawberry.field import _RESOLVER_TYPE, StrawberryField
from strawberry.permission import BasePermission
from strawberry.types.types import StrawberryObjectDefinition

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


@final
class Resolver(ABC, Generic[T_co]):  # type: ignore
    """Marker type for a field whose value is computed by a resolver.

    ``Resolver[X]`` is only meant to appear in annotations. The class is
    ``@final`` and keeps an unimplemented abstract method, so it cannot be
    instantiated (and type checkers reject subclassing). This stops code from
    reading or writing such a field as if it held an ``X``.
    """

    @abstractmethod
    def __do_not_instantiate_this(self) -> None:
        """Deliberately never implemented; keeps ``Resolver`` abstract."""


class ResolverFakeStrawberryObjectDefinition(StrawberryObjectDefinition):
    """Fake Strawberry object definition attached to ``Resolver``.

    Describes ``Resolver`` as a GraphQL-generic type with no fields of its own.
    ``resolve_generic`` maps ``Resolver[X]`` to the resolved Strawberry type of
    ``X``. NOTE(review): this relies on how Strawberry internally consumes
    ``__strawberry_definition__``, ``is_graphql_generic`` and
    ``resolve_generic`` — confirm against the Strawberry version in use.
    """

    def __init__(self) -> None:
        super().__init__(
            name=Resolver.__name__,
            is_input=False,
            is_interface=False,
            origin=Resolver,
            description=None,
            interfaces=[],
            extend=False,
            directives=None,
            is_type_of=None,
            resolve_type=None,
            fields=[],
            concrete_of=None,
            type_var_map={},
        )

    def resolve_generic(self, wrapped_cls: Type[Any]) -> Type[Any]:
        """Return the resolved Strawberry type of the single argument of ``wrapped_cls``.

        Args:
            wrapped_cls: A parameterised alias such as ``Resolver[int]``.

        Raises:
            AssertionError: If ``wrapped_cls`` does not have exactly one type argument.
        """
        # The snippet's module-level imports do not bring StrawberryAnnotation
        # into scope, so import it here; otherwise this method raises NameError.
        from strawberry.annotation import StrawberryAnnotation

        passed_types = wrapped_cls.__args__
        assert (
            len(passed_types) == 1
        ), f"Resolver should be generic over one arg: {passed_types}"
        return StrawberryAnnotation(passed_types[0]).resolve()  # type: ignore

    @property
    def is_graphql_generic(self) -> bool:
        """Always ``True``: ``Resolver`` is only meaningful when parameterised."""
        return True


# Register the fake definition on `Resolver`. The intent is that Strawberry then
# unwraps a `Resolver[X]` field annotation via `resolve_generic` (assumption
# about Strawberry internals — confirm).
Resolver.__strawberry_definition__ = ResolverFakeStrawberryObjectDefinition()  # type: ignore


# Overload 1: `field(resolver=...)` passed by keyword. Returns `Resolver[T]`
# instead of a plain value, so the attribute must be annotated `Resolver[T]`.
# `init: Literal[False]` tells PyRight the field is not an __init__ argument.
# NOTE: overload order matters to type checkers; keep this one first.
@overload
def field(
    *,
    resolver: _RESOLVER_TYPE[T],
    name: Optional[str] = None,
    is_subscription: bool = False,
    description: Optional[str] = None,
    init: Literal[False] = False,
    permission_classes: Optional[list[Type[BasePermission]]] = None,
    deprecation_reason: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
    metadata: Optional[Mapping[Any, Any]] = None,
    directives: Optional[Sequence[object]] = (),
    extensions: Optional[list[FieldExtension]] = None,
    graphql_type: Optional[Any] = None,
) -> Resolver[T]:
    ...


# Overload 2: no resolver (a plain data field). `init: Literal[True]` marks it
# as an __init__ argument; the `Any` return lets it be assigned to an attribute
# of any annotated type.
@overload
def field(
    *,
    name: Optional[str] = None,
    is_subscription: bool = False,
    description: Optional[str] = None,
    init: Literal[True] = True,
    permission_classes: Optional[list[Type[BasePermission]]] = None,
    deprecation_reason: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
    metadata: Optional[Mapping[Any, Any]] = None,
    directives: Optional[Sequence[object]] = (),
    extensions: Optional[list[FieldExtension]] = None,
    graphql_type: Optional[Any] = None,
) -> Any:
    ...


# Overload 3: resolver passed positionally, which covers bare decorator usage
# (`@field` on a method). Returns a `StrawberryField`.
@overload
def field(
    resolver: _RESOLVER_TYPE[T],
    *,
    name: Optional[str] = None,
    is_subscription: bool = False,
    description: Optional[str] = None,
    permission_classes: Optional[list[Type[BasePermission]]] = None,
    deprecation_reason: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
    metadata: Optional[Mapping[Any, Any]] = None,
    directives: Optional[Sequence[object]] = (),
    extensions: Optional[list[FieldExtension]] = None,
    graphql_type: Optional[Any] = None,
) -> StrawberryField:
    ...


def field(
    resolver: Optional[_RESOLVER_TYPE[Any]] = None,
    *,
    name: Optional[str] = None,
    is_subscription: bool = False,
    description: Optional[str] = None,
    permission_classes: Optional[list[Type[BasePermission]]] = None,
    deprecation_reason: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
    metadata: Optional[Mapping[Any, Any]] = None,
    directives: Optional[Sequence[object]] = (),
    extensions: Optional[list[FieldExtension]] = None,
    graphql_type: Optional[Any] = None,
    # This init parameter is used by PyRight to determine whether this field
    # is added in the constructor or not. It is not used to change
    # any behavior at the moment.
    init: Literal[True, False, None] = None,
) -> Any:
    """Thin wrapper around ``strawberry.field`` with stricter static types.

    Every argument is forwarded unchanged to ``strawberry.field``, so runtime
    behavior is identical. Only the return types declared by the overloads
    above differ.
    """
    return strawberry.field(  # type: ignore
        resolver=resolver,  # type: ignore
        name=name,
        is_subscription=is_subscription,
        description=description,
        permission_classes=permission_classes,
        deprecation_reason=deprecation_reason,
        default=default,
        default_factory=default_factory,
        metadata=metadata,
        directives=directives,
        extensions=extensions,
        graphql_type=graphql_type,
        init=init,  # type: ignore
    )


def resolver() -> int:
    """Example resolver backing ``Foo.field2``; always yields the same integer."""
    value: int = 42
    return value


# Workaround applied: `field2` must now be annotated `Resolver[int]`, matching
# the `Resolver[T]` return type of the `field(resolver=...)` overload.
@type
class Foo:
    field1: str
    # Annotated as Resolver[int], so reading it as an `int` or assigning an
    # `int` to it is rejected by the type checker.
    field2: Resolver[int] = field(resolver=resolver)

    @field
    def field3(self) -> str:
        return "Hello, world!"


from typing import TYPE_CHECKING

# `reveal_type` is a type-checker intrinsic, not a runtime builtin, so calling
# it unguarded raises NameError before the `print` below is reached. Type
# checkers treat `TYPE_CHECKING` as true and still report these types.
if TYPE_CHECKING:
    reveal_type(Foo.__init__)  # def (self: interface.graphql.schema.Foo, *, field1: builtins.str)
    reveal_type(Foo.field1)  # builtins.str
    reveal_type(Foo.field2)  # Resolver[int]
    reveal_type(Foo.field3)  # strawberry.field.StrawberryField

print(Foo(field1="test").field2)
# At runtime the above prints
# <bound method resolver of Foo(field1='test')>

With the above changes, the code below produces type errors, as expected:

def bar(foo: Foo) -> int:
    """Read ``field2`` as an ``int``; now rejected by the type checker.

    ``field2`` is annotated ``Resolver[int]``, which is incompatible with the
    declared ``int`` return type.
    """
    # Expression of type "Resolver[int]" cannot be assigned to return type "int"
    #   "Resolver[int]" is incompatible with "int"
    return foo.field2


def baz(foo: Foo, val: int) -> None:
    """Assign an ``int`` to ``field2``; now rejected by the type checker.

    ``field2`` is annotated ``Resolver[int]``, so an ``int`` value is not
    assignable to it.
    """
    # Cannot assign member "field2" for type "Foo"
    #   "int" is incompatible with "Resolver[int]"
    foo.field2 = val


# Type-checker output for the line below:
# Cannot instantiate abstract class "Resolver"
#   "Resolver.__do_not_instantiate_this" is not implemented
# At runtime this line also raises TypeError, because `Resolver` still has an
# unimplemented abstract method.
Resolver()

The main idea is to change one of the overloads of `strawberry.field` to return `Resolver[T]` when `resolver=...` is provided. This forces you to annotate your field as `field2: Resolver[int] = ...`, and since `Resolver` cannot be instantiated, it prevents you from mistakenly reading or writing that field.

This would be a breaking change for Strawberry, and arguably it sacrifices readability for type safety, but I personally think it is worth it. I just wanted to share my thoughts and my solution here.

Upvote & Fund

  • We're using Polar.sh so you can upvote and help fund this issue.
  • We receive the funding once the issue is completed & confirmed by you.
  • Thank you in advance for helping prioritize & fund our backlog.
Fund with Polar

lukaspiatkowski avatar Feb 09 '24 14:02 lukaspiatkowski