spectree icon indicating copy to clipboard operation
spectree copied to clipboard

feat: extract the default response from annotations

Open kemingy opened this issue 1 year ago • 8 comments

Describe the feature

Extract the default response model from the endpoint's return annotation and use it as the HTTP 200 response.

Additional context

No response

kemingy avatar Nov 26 '24 02:11 kemingy

I've been playing around and came up with the following concept (examples at the end). It covers most cases I could think of, and I thought it might be helpful for implementing this feature.

import inspect
from types import UnionType
from typing import Any, Literal, NoReturn, Optional, Union, get_args, get_origin

import flask
from pydantic import BaseModel
from spectree._pydantic import generate_root_model
from typing_extensions import TypedDict


def parse_model(cls: Any, code=200) -> str | tuple[int, Any]:
    if inspect.isclass(cls) and issubclass(cls, flask.Response):
        return f"HTTP_{code}"
    if cls is Union or cls is UnionType:
        cls = get_args(cls)
    else:
        cls = (cls,)
    return (code, cls)


def parse_tuple(ret: type[tuple]) -> str | tuple[int, Any]:
    args = get_args(ret)
    assert len(args) == 2, "Return type should be tuple[model, Literal[code]]"
    cls, code = args
    assert get_origin(code) is Literal, "Return type should be tuple[model, Literal[code]]"
    code = get_args(code)[0]
    if cls:
        return parse_model(cls, code)
    else:
        return f"HTTP_{code}"


def wrap_return_type(item: tuple[str, set[Any]], operation_id: Optional[str] = None) -> "str | tuple[str, type[BaseModel]]":
    """Turn one aggregated ``(code_key, models)`` entry into a spectree response.

    Args:
        item: ``("HTTP_<code>", models)``, where ``models`` is any iterable of
            the types declared for that status code (possibly empty).
        operation_id: Optional prefix for the name of a generated root model.

    Returns:
        The bare code key when no body model was declared. Otherwise
        ``(code_key, model)``, where ``model`` is either the declared pydantic
        ``BaseModel`` subclass or a root model generated around the declared
        type (a ``Union`` when several models share the code).
    """
    code_key = item[0]
    models = tuple(item[1])
    # Entries created for bare codes (e.g. a ``flask.Response`` return) carry no
    # models. The previous ``len(...) < 0`` test could never be true, so they
    # fell through to ``Union[()]`` and raised TypeError.
    if not models or (len(models) == 1 and not models[0]):
        return code_key

    model = models[0] if len(models) == 1 else Union[models]

    if not inspect.isclass(model) or not issubclass(model, BaseModel):
        prefix = f"{operation_id}_" if operation_id else ""
        name = getattr(model, "__name__", "_model")
        model = generate_root_model(model, f"{prefix}{name}")  # type: ignore

    return code_key, model


def parse_return_type(return_type: Any, operation_id: Optional[str] = None) -> list[str | tuple[str, type[BaseModel]]]:
    assert return_type is not inspect._empty, "Missing return type"

    responses: list[str | tuple[int, tuple]] = []
    if get_origin(return_type) is tuple:
        responses.append(parse_tuple(return_type))
    elif get_origin(return_type) is UnionType or get_origin(return_type) is Union:
        args = get_args(return_type)
        tples = [t for t in args if get_origin(t) is tuple]
        rest = tuple(t for t in args if get_origin(t) is not tuple)
        for t in tples:
            responses.append(parse_tuple(t))
        if rest:
            r = Union[rest]
            responses.append(parse_model(r))
    elif return_type is not NoReturn:
        responses.append(parse_model(return_type))

    # Aggregate response by code, e.g. `tuple[Foo, Literal[201]] | tuple[Bar, Literal[201]]` should become `201: Union[Foo, Bar]`
    responses_by_code: dict[str, set[Any]] = {}
    for r in responses:
        if isinstance(r, str):
            if r not in responses_by_code:
                responses_by_code[r] = set()
        else:
            code, model = r
            c = f"HTTP_{code}"
            if c not in responses_by_code:
                responses_by_code[c] = set()
            for m in model:
                responses_by_code[c].add(m)

    return [wrap_return_type(item, operation_id) for item in responses_by_code.items()]


"""Examples"""


class Foo(BaseModel):
    name: str


class Bar(TypedDict):
    value: int


print(
    parse_return_type(Foo)[0][1].model_json_schema()
)  # {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}
print(
    parse_return_type(Bar)[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}}, '$ref': '#/$defs/Bar', 'title': 'Bar'}

print(
    parse_return_type(Foo | Bar)[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}, 'Foo': {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}], 'title': 'Union'}
print(
    parse_return_type(tuple[Foo | Bar, Literal[200]])[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}, 'Foo': {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}], 'title': '_model'}
print(
    parse_return_type(tuple[Foo, Literal[200]] | tuple[Bar, Literal[200]])[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}, 'Foo': {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}], 'title': 'Union'}
print(
    parse_return_type(tuple[Foo, Literal[200]] | Bar)[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}, 'Foo': {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}], 'title': 'Union'}
print(
    parse_return_type(tuple[Foo, Literal[200]] | tuple[Bar | str, Literal[200]])[0][1].model_json_schema()
)  # {'$defs': {'Bar': {'properties': {'value': {'title': 'Value', 'type': 'integer'}}, 'required': ['value'], 'title': 'Bar', 'type': 'object'}, 'Foo': {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'], 'title': 'Foo', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Bar'}, {'type': 'string'}, {'$ref': '#/$defs/Foo'}], 'title': 'Union'}
print(
    parse_return_type(tuple[Foo, Literal[200]] | tuple[Bar, Literal[201]])
)  # [('HTTP_200', <class 'app.api.v1.rest.api_helper.Foo'>), ('HTTP_201', <class 'abc.Bar'>)]
print(parse_return_type(str)[0][1].model_json_schema())  # {'title': 'str', 'type': 'string'}
print(parse_return_type(list[str])[0][1].model_json_schema())  # {'items': {'type': 'string'}, 'title': 'list', 'type': 'array'}

It can be applied as

# Feed the parsed annotation into a spectree response object `resp`
# (presumably a `spectree.Response`; `func` and `operation_id` come from the caller).
# Bare "HTTP_<code>" keys are appended to `resp.codes`; (key, model) pairs are
# registered with `resp.add_model`.
return_type = inspect.signature(func).return_annotation
for response in parse_return_type(return_type, operation_id):
    if isinstance(response, str):
        resp.codes.append(response)
    else:
        code, cls = response
        # `code` looks like "HTTP_201"; `code[5:]` drops the "HTTP_" prefix.
        resp.add_model(int(code[5:]), cls, description=resp.code_descriptions.get(code))

HitkoDev avatar Jun 19 '25 12:06 HitkoDev

We should design the user interface first. I would expect something like this:

  • only for HTTP_200 (most common case)
@app.route("/api/user", methods=["POST"])
@spec.validate()
def user_profile(json: Profile) -> Message:
    pass
  • multiple code:model

Since a `Response()` instance cannot be used directly as a type (unless we implement `Response[xxx]`), the extra codes would be attached through `Annotated`:

@app.route("/api/user", methods=["POST"])
@spec.validate()
def user_profile(json: Profile) -> Annotated[Message, Response(HTTP_403=None)]:
    pass

This is not the final design yet, so feedback is welcome.

kemingy avatar Jun 21 '25 00:06 kemingy

I've looked at the underlying frameworks and it seems that:

  • for Flask and Quart, views can return either T or tuple[T, S]
    def get() -> MyData:
      # Bare return value: the status is the framework default (HTTP 200).
      ...
      return MyData(...)
    
    def put() -> tuple[MyData, Literal[201]]:
      # (body, status) tuple: Literal[201] makes the status visible to type checkers.
      ...
      return MyData(...), 201
    
  • starlette always returns a Response object
    # Generic Response subclass so the view's return annotation carries both the
    # body type (T) and the status literal (S).
    class TypedResponse(starlette.Response, Generic[T, S]):
      content: T
      status_code: S
    
    async def view() -> TypedResponse[MyData, Literal[201]]:
      return TypedResponse(MyData(...), status_code=201)
    
  • Falcon views have no return value, so they would require a typed `resp` argument, e.g.
    # Falcon responders mutate `resp` instead of returning, so the body (T) and
    # status (S) types live on the `resp` parameter's annotation.
    class TypedResponse(falcon.Response, Generic[T, S]):
      media: T
      status: S
    
    def view(req: falcon.Request, resp: TypedResponse[MyData, Literal[201]]):
      resp.media = MyData(...)
      resp.status = 201
    

This interface aligns well with existing type checking tools to make sure views return what they declare to the API. T could be annotated to add metadata as Annotated[MyData, ResponseMeta(description="Create my model")], and we can use unions to specify different responses or response types, e.g. Union[MyData, tuple[Annotated[None, ResponseMeta(description="Model not found")], Literal[404]]].

HitkoDev avatar Jun 23 '25 08:06 HitkoDev

  • Do we need to provide different Response types for different web frameworks?
  • Flask has a very flexible return type: it can be json, json+status, json+headers, json+status+headers, or a response object. It's hard to cover them all.
  • Does this work for multiple response-status pairs? Like Union[Annotated[MyData, HTTPStatus.OK], Annotated[AnotherData, HTTPStatus.CREATED]].

kemingy avatar Jun 24 '25 08:06 kemingy

When it comes to the Flask return type and multiple response-status pairs, I think the solution is pretty straightforward:

from types import UnionType
from typing import Annotated, Literal, Union, get_args, get_origin

# Flatten `A | B` / `Union[A, B]` into its members; anything else is a single
# response. A parameterised union is never identical to `Union` itself, so the
# check has to compare its origin.
responses = get_args(response_type) if get_origin(response_type) in (Union, UnionType) else (response_type,)

parsed_responses = []
for item in responses:
    response = ResponseMeta(description=None)
    model = item
    code = 200
    if get_origin(item) is Annotated:
        # get_args(Annotated[T, m1, m2, ...]) == (T, m1, m2, ...): keep the first
        # ResponseMeta instead of assuming exactly one metadata item.
        model, *metadata = get_args(item)
        response = next((m for m in metadata if isinstance(m, ResponseMeta)), response)
    if get_origin(model) is tuple:
        args = get_args(model)
        model = args[0]
        if len(args) > 1 and get_origin(args[1]) is Literal:  # json+status(+headers)
            code = get_args(args[1])[0]
        elif len(args) == 3 and get_origin(args[2]) is Literal:  # json+headers+status
            # NOTE(review): Flask documents (body, status, headers) and
            # (body, headers); keep this branch only if (body, headers, status)
            # should be accepted too.
            code = get_args(args[2])[0]
    parsed_responses.append((model, code, response))
    ...

# We now have a list of all `(model, code, meta)` for the given view

I'm not sure whether adding annotations to flask.Response has any benefit over the current approach, since it just adds more code without improving type safety:

@spectree.validate(
    resp=Response(
        HTTP_200=(MyModel, "Get my model"),
    )
)
def view() -> flask.Response:
    """Current API: response models are declared on the decorator."""
    return jsonify(data)

# vs

@spectree.validate()
def view() -> Annotated[
    flask.Response,
    Response(
        HTTP_200=(MyModel, "Get my model"),
    ),
]:
    """Same declaration moved into the return annotation as ``Annotated`` metadata."""
    return jsonify(data)

One way we could avoid having different Response types for different web frameworks would be for each plugin to wrap the view function, e.g. (for falcon):

@wraps(view)
def wrap(*args, **kwargs):
    """Sketch of a plugin-side wrapper (Falcon shown).

    The wrapped view returns a Flask-style value; the wrapper unpacks it and
    copies body, status and headers onto Falcon's ``resp``.
    """
    # NOTE(review): `resp` is not bound in this sketch; presumably it is the
    # Falcon response object taken from the responder's arguments.
    ...  # parse req data and add it to kwargs

    response = view(*args, **kwargs)  # flask-style return value
    data, code, headers = unpack(response)

    ...  # validate response

    # Falcon responders do not return a response; results are written onto `resp`.
    resp.media = data
    resp.status = code
    for k, v in headers.items():
        resp.append_header(k, v)

This would potentially also reduce the amount of framework-specific code we need, while providing the same level of e2e type safety regardless of the framework, which might not be possible if we strictly adhere to each framework's Response type. I don't think such solution would require any breaking changes either.

How does that sound?

HitkoDev avatar Jun 24 '25 15:06 HitkoDev

I'm not sure whether adding annotations to flask.Response has any benefit over the current approach, since it just adds more code without improving type safety.

Agree. This doesn't improve any user experience.

One way we could avoid having different Response types for different web frameworks would be for each plugin to wrap the view function.

I don't quite get it. Can you provide a flask + falcon user interface example?

kemingy avatar Jun 26 '25 09:06 kemingy

Flask

class FindMyModel(MethodView):
    """Example Flask view where responses are declared only by return annotations.

    Each method returns either the model (HTTP 200) or ``(None, 404)``; the
    ``RespMeta`` in each ``Annotated`` member carries that response's description.
    """

    @spectree.validate()
    def get(
        self,
        id: int,
    ) -> Union[
        Annotated[
            MyModel,
            RespMeta(description="Model for the id"),
        ],
        Annotated[
            tuple[None, Literal[404]],
            RespMeta(description="Model not found"),
        ],
    ]:
        """Return the model for ``id``, or ``(None, 404)`` when the lookup is falsy."""
        model = orm.get(MyModel, id)

        if not model:
            return None, 404

        return model

    @spectree.validate()
    def post(
        self,
        id: int,
        json: MyModelCreate,
    ) -> Union[
        Annotated[
            MyModel,
            RespMeta(description="Updated model"),
        ],
        Annotated[
            tuple[None, Literal[404]],
            RespMeta(description="Model not found"),
        ],
    ]:
        """Apply ``json`` to the model for ``id`` and return it; ``(None, 404)`` if absent."""
        model = orm.get(MyModel, id)

        if not model:
            return None, 404

        model.update(json)

        return model


# The same view class serves every method for this URL rule.
app.add_url_rule("/my-model/<int:id>", view_func=FindMyModel.as_view("find-my-model"))

Falcon

class FindMyModel:
    """Example Falcon resource written in the same style as the Flask view.

    Responders return the model (HTTP 200) or ``(None, 404)`` instead of setting
    ``resp.media``/``resp.status``; ``spectree.validate()`` is expected to unpack
    that return value onto ``resp``.
    """

    # Decorated like ``on_post``: without ``@spectree.validate()`` nothing
    # unpacks the returned value, so the model or 404 would never reach ``resp``.
    @spectree.validate()
    def on_get(
        self,
        req: falcon.Request,
        resp: falcon.Response,
        id: int,
    ) -> Union[
        Annotated[
            MyModel,
            RespMeta(description="Model for the id"),
        ],
        Annotated[
            tuple[None, Literal[404]],
            RespMeta(description="Model not found"),
        ],
    ]:
        """Return the model for ``id``, or ``(None, 404)`` when the lookup is falsy."""
        model = orm.get(MyModel, id)

        if not model:
            return None, 404

        return model

    @spectree.validate()
    def on_post(
        self,
        req: falcon.Request,
        resp: falcon.Response,
        id: int,
        json: MyModelCreate,
    ) -> Union[
        Annotated[
            MyModel,
            RespMeta(description="Updated model"),
        ],
        Annotated[
            tuple[None, Literal[404]],
            RespMeta(description="Model not found"),
        ],
    ]:
        """Apply ``json`` to the model for ``id`` and return it; ``(None, 404)`` if absent."""
        model = orm.get(MyModel, id)

        if not model:
            return None, 404

        model.update(json)

        return model


app.add_route("/my-model/{id:int}", FindMyModel())

Falcon code is mostly identical to Flask, and we don't set resp.media; instead, FalconPlugin.validate unpacks the returned value and sets resp attributes for us. (note: if return value is None, we still look at resp.media and resp.status_code - that way we maintain backwards compatibility and support for custom return values, i.e. file downloads)

HitkoDev avatar Jun 27 '25 12:06 HitkoDev

I see. This might affect more than we thought.

  1. It's backward compatible, but returning a jsonify() will break the type check in flask
  2. Some users are using the returned object from falcon, even though it's not recommended (but we're also using it here). This might affect their logic (when returning a Pydantic model) or break the type check (when returning their original objects)

As we discussed above, this doesn't improve the user experience. I don't have a good solution for now. Maybe we should put this issue on hold until we have a better solution.

kemingy avatar Jun 27 '25 15:06 kemingy