pydantic-core

Add a WalkCoreSchema

Open adriangb opened this issue 2 years ago • 7 comments

@samuelcolvin @dmontagu we discussed this previously and I shot it down: just porting what we have in pydantic to Rust would not be much faster (aside from the speedup of calling CPython APIs from Rust), because for every 2-3 key accesses we do in Rust (which would be faster) we'd be calling into Python and back. So the absolute gain may not be very large, and there's the FFI overhead to contend with.

I was thinking about it some more, and I think that if we change the API we have in pydantic to this one we can get a much larger speedup. Essentially, instead of having a single callback for all schemas, I'm using a separate callback for each schema type. This acts as a sort of "filter" that minimizes calls into Python. Of the ~3 "walks" we do in pydantic, this covers two:

  • https://github.com/pydantic/pydantic/blob/667cd3776ee40e06018d0b7ff477c6cd0199b098/pydantic/_internal/_core_utils.py#L463-L464
  • https://github.com/pydantic/pydantic/blob/667cd3776ee40e06018d0b7ff477c6cd0199b098/pydantic/_internal/_core_utils.py#L504 (there are some others for discriminated unions and such; I haven't looked into those)
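To make the per-type callback idea concrete, here is a minimal pure-Python model of it. It is a sketch under stated assumptions, not the PR's actual Rust API: `walk`, `_walk_value` and the `callbacks` mapping (schema type to callback) are hypothetical names, and any dict with a `'type'` key is treated as a nested schema. The point it shows is that the traversal always happens, but a callback (i.e. a call into Python) only happens for schema types someone registered.

```python
from typing import Any, Callable, Dict

Schema = Dict[str, Any]
Callback = Callable[[Schema], Schema]


def walk(schema: Schema, callbacks: Dict[str, Callback]) -> Schema:
    """Hypothetical walker: rebuild `schema` bottom-up, calling a callback only for registered types."""
    # Recurse into children first; in the proposal this traversal would be the cheap part done in Rust.
    new_schema = {key: _walk_value(value, callbacks) for key, value in schema.items()}
    callback = callbacks.get(new_schema['type'])
    if callback is None:
        return new_schema  # no callback registered for this type: no call into Python at all
    return callback(new_schema)


def _walk_value(value: Any, callbacks: Dict[str, Callback]) -> Any:
    # Simplification (assumption): any dict with a 'type' key is a nested schema (so serialization
    # schemas are walked too); other dicts/lists, e.g. union choices or tagged-union mappings,
    # are searched for schemas.
    if isinstance(value, dict):
        if 'type' in value:
            return walk(value, callbacks)
        return {key: _walk_value(v, callbacks) for key, v in value.items()}
    if isinstance(value, list):
        return [_walk_value(v, callbacks) for v in value]
    return value
```

For example, `walk({'type': 'list', 'items_schema': {'type': 'int'}}, {'int': f})` calls `f` exactly once, for the inner int schema, and never for the list.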

However, this does not cover the case where we need to visit every schema: https://github.com/pydantic/pydantic/blob/667cd3776ee40e06018d0b7ff477c6cd0199b098/pydantic/_internal/_core_utils.py#L449-L450

For that last case I see a couple of options:

  • Add a visit_all_schemas callback that slows things down significantly but allows visiting all schemas (and hence collecting all refs).
  • Add a visit_schema_with_ref callback that gets called for any schema with a ref. This seems somewhat reasonable, but it may be too "specialized" an implementation for our current use case. That is, it's a band-aid solution to a poor API.
  • Add a more powerful filter predicate system. For example, you could have Walk(visit=[if_schema_has_key("ref")(callback), if_schema_has_type("int")(callback), (if_schema_has_type("int") & if_schema_has_key("ref"))(callback)]). This might also let us get rid of the dozens of constructor arguments this implementation currently has.
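A minimal pure-Python sketch of the third option, assuming predicates are plain functions over a schema dict. The helper names `if_schema_has_key` / `if_schema_has_type` mirror the example in that bullet; `Filter` and `apply_visitors` are hypothetical. It only shows how combined predicates select callbacks at a single node, not the recursive walk itself.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

Schema = Dict[str, Any]
Predicate = Callable[[Schema], bool]
Callback = Callable[[Schema], Schema]


@dataclass(frozen=True)
class Filter:
    """Hypothetical predicate wrapper: combine with & and |, then call it with a callback."""

    predicate: Predicate

    def __and__(self, other: 'Filter') -> 'Filter':
        return Filter(lambda s: self.predicate(s) and other.predicate(s))

    def __or__(self, other: 'Filter') -> 'Filter':
        return Filter(lambda s: self.predicate(s) or other.predicate(s))

    def __call__(self, callback: Callback) -> Tuple[Predicate, Callback]:
        # Pair the predicate with the callback it guards, as in if_schema_has_key("ref")(callback).
        return (self.predicate, callback)


def if_schema_has_key(key: str) -> Filter:
    return Filter(lambda s: key in s)


def if_schema_has_type(type_: str) -> Filter:
    return Filter(lambda s: s.get('type') == type_)


def apply_visitors(schema: Schema, visit: List[Tuple[Predicate, Callback]]) -> Schema:
    # Run, in order, every callback whose predicate matches this node; skip the rest.
    for predicate, callback in visit:
        if predicate(schema):
            schema = callback(schema)
    return schema
```

A walker built this way would only evaluate predicates per node, and would only call back into Python for the (predicate, callback) pairs that match.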

adriangb, Nov 29 '23 21:11

Codecov Report

Attention: Patch coverage is 93.04207% with 43 lines in your changes missing coverage. Please review.

Project coverage is 89.83%. Comparing base (7fa450d) to head (772c8c3). Report is 211 commits behind head on main.

File                    | Patch % | Lines
src/walk_core_schema.rs | 93.01%  | 43 missing
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1099      +/-   ##
==========================================
+ Coverage   89.70%   89.83%   +0.12%     
==========================================
  Files         106      107       +1     
  Lines       16364    16982     +618     
  Branches       35       35              
==========================================
+ Hits        14680    15255     +575     
- Misses       1677     1720      +43     
  Partials        7        7              
File                             | Coverage Δ
python/pydantic_core/__init__.py | 92.59% <ø> (ø)
src/lib.rs                       | 87.50% <100.00%> (+0.35%) ⬆
src/walk_core_schema.rs          | 93.01% <93.01%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 7fa450d...772c8c3.

codecov[bot], Nov 29 '23 21:11

CodSpeed Performance Report

Merging #1099 will degrade performance by 18.81%

Comparing walk-core-schema (772c8c3) with main (7fa450d)

Summary

❌ 1 regression, ✅ 139 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Benchmarks breakdown

Benchmark            | main    | walk-core-schema | Change
test_core_future_str | 31.5 µs | 38.8 µs          | -18.81%
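The reported change appears to be relative to the new timing, i.e. a change in speed rather than in elapsed time; the two measurements are consistent with that reading, and the same numbers correspond to roughly a 23% increase in elapsed time:

```latex
\frac{31.5\,\mu\text{s}}{38.8\,\mu\text{s}} - 1 \approx -0.1881 = -18.81\%
\qquad
\frac{38.8\,\mu\text{s}}{31.5\,\mu\text{s}} - 1 \approx +23.2\%
```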

codspeed-hq[bot], Nov 29 '23 21:11

I like the predicate filter idea. Hopefully it'll lead to a smaller implementation that's also easier to maintain as we introduce new schema types.

You could even make the predicates user-supplied in Python, which may help with adopting the implementation (and we could move more predicates to Rust as needed).

davidhewitt, Dec 01 '23 14:12

That's an interesting idea. Something like:

VisitPredicate = Callable[[CoreSchema], bool]

This would get run from Rust, and we'd write any predicates we need in Python. For example:

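from dataclasses import dataclass
from typing import Callable

from pydantic_core import CoreSchema


class CombinablePredicate:
    # Mixin: lets predicates be combined with `|`.
    def __or__(self, other):
        return CombinedPredicate(lambda s: self(s) or other(s))


@dataclass
class CombinedPredicate(CombinablePredicate):
    # Subclassing CombinablePredicate means a combined predicate can itself be combined again.
    call: Callable[[CoreSchema], bool]

    def __call__(self, schema: CoreSchema) -> bool:
        return self.call(schema)


class HasRef(CombinablePredicate):
    def __call__(self, schema: CoreSchema) -> bool:
        return bool(schema.get('ref', False))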

And once those are stabilized we can move them to Rust. Is that what you had in mind?

adriangb, Dec 01 '23 15:12

@davidhewitt I implemented the filter API as discussed above.

adriangb, Dec 01 '23 19:12

@davidhewitt I benchmarked this, and it's coming out no faster than our existing Python version (which calls a Python function at every level in addition to doing the traversal in Python), even when there is no filter (so it never calls into Python).

import timeit
from typing import Any, Callable

from pydantic._internal._core_utils import walk_core_schema

from pydantic_core import CoreSchema, WalkCoreSchema
from pydantic_core import core_schema as cs


def plain_ser_func(x: Any) -> str:
    return 'abc'


def wrap_ser_func(x: Any, handler: cs.SerializerFunctionWrapHandler) -> Any:
    return handler(x)


def no_info_val_func(x: Any) -> Any:
    return x


def no_info_wrap_val_func(x: Any, handler: cs.ValidatorFunctionWrapHandler) -> Any:
    return handler(x)


class NamedClass:
    pass


schema = cs.union_schema(
    [
        cs.any_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.none_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.bool_schema(serialization=cs.simple_ser_schema('bool')),
        cs.int_schema(serialization=cs.simple_ser_schema('int')),
        cs.float_schema(serialization=cs.simple_ser_schema('float')),
        cs.decimal_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.str_schema(serialization=cs.simple_ser_schema('str')),
        cs.bytes_schema(serialization=cs.simple_ser_schema('bytes')),
        cs.date_schema(serialization=cs.simple_ser_schema('date')),
        cs.time_schema(serialization=cs.simple_ser_schema('time')),
        cs.datetime_schema(serialization=cs.simple_ser_schema('datetime')),
        cs.timedelta_schema(serialization=cs.simple_ser_schema('timedelta')),
        cs.literal_schema(
            expected=[1, 2, 3],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.is_instance_schema(
            cls=NamedClass,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.is_subclass_schema(
            cls=NamedClass,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.callable_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.list_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tuple_positional_schema(
            [cs.int_schema(serialization=cs.simple_ser_schema('int'))],
            extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tuple_variable_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.set_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.frozenset_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.generator_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.dict_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_after_validator_function(
            no_info_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_before_validator_function(
            no_info_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_wrap_validator_function(
            no_info_wrap_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_plain_validator_function(
            no_info_val_func,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.with_default_schema(
            cs.int_schema(),
            default=1,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.nullable_schema(
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.union_schema(
            [
                cs.int_schema(),
                cs.str_schema(),
            ],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tagged_union_schema(
            {
                'a': cs.int_schema(),
                'b': cs.str_schema(),
            },
            'type',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.chain_schema(
            [
                cs.int_schema(),
                cs.str_schema(),
            ],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.lax_or_strict_schema(
            cs.int_schema(),
            cs.str_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.json_or_python_schema(
            cs.int_schema(),
            cs.str_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.typed_dict_schema(
            {'a': cs.typed_dict_field(cs.int_schema())},
            computed_fields=[
                cs.computed_field(
                    'b',
                    cs.int_schema(),
                )
            ],
            extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.model_schema(
            NamedClass,
            cs.model_fields_schema(
                {'a': cs.model_field(cs.int_schema())},
                extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
                computed_fields=[
                    cs.computed_field(
                        'b',
                        cs.int_schema(),
                    )
                ],
            ),
        ),
        cs.dataclass_schema(
            NamedClass,
            cs.dataclass_args_schema(
                'Model',
                [cs.dataclass_field('a', cs.int_schema())],
                computed_fields=[
                    cs.computed_field(
                        'b',
                        cs.int_schema(),
                    )
                ],
            ),
            ['a'],
        ),
        cs.call_schema(
            cs.arguments_schema(
                [cs.arguments_parameter('x', cs.int_schema())],
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            no_info_val_func,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.custom_error_schema(
            cs.int_schema(),
            custom_error_type='CustomError',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.json_schema(
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.url_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.multi_host_url_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.definitions_schema(
            cs.int_schema(),
            [
                cs.int_schema(ref='#/definitions/int'),
            ],
        ),
        cs.definition_reference_schema(
            '#/definitions/int',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.uuid_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
    ]
)


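def walk_core() -> None:
    # Rust walker from this PR, with no filters configured (so it never calls into Python).
    WalkCoreSchema().walk(schema)


Recurse = Callable[[cs.CoreSchema, 'Walk'], cs.CoreSchema]
Walk = Callable[[cs.CoreSchema, Recurse], cs.CoreSchema]


def visit_pydantic(schema: cs.CoreSchema, recurse: Recurse) -> cs.CoreSchema:
    # Identity visitor: just keep recursing, so this measures the traversal itself.
    return recurse(schema, visit_pydantic)


def walk_pydantic() -> None:
    # pydantic's existing pure-Python walker.
    walk_core_schema(schema, visit_pydantic)


print(f'WalkCoreSchema (Rust):     {timeit.timeit(walk_core, number=1000):.3f}s per 1000 walks')
print(f'walk_core_schema (Python): {timeit.timeit(walk_pydantic, number=1000):.3f}s per 1000 walks')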
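A single timeit.timeit run per walker can be noisy. As a sketch, assuming you only want a per-call figure for each walker, a small helper can use timeit.repeat and keep the fastest run, which is what the Python timeit documentation recommends looking at; best_of is a hypothetical name, not part of pydantic or pydantic-core:

```python
import timeit
from typing import Callable, Dict


def best_of(funcs: Dict[str, Callable[[], None]], number: int = 1000, repeat: int = 5) -> Dict[str, float]:
    """Hypothetical helper: best per-call time in seconds for each named zero-argument function."""
    results: Dict[str, float] = {}
    for name, func in funcs.items():
        # timeit.repeat returns `repeat` totals, each for `number` calls; the minimum is the least
        # affected by other processes interfering with the measurement.
        results[name] = min(timeit.repeat(func, number=number, repeat=repeat)) / number
    return results
```

With the benchmark above, that would be best_of({'WalkCoreSchema': walk_core, 'walk_core_schema': walk_pydantic}).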

adriangb, Dec 03 '23 16:12

Let's revisit this once #1085 gets merged, which might improve performance significantly.

samuelcolvin, Jan 25 '24 13:01

Closing as stale, given that we closed #615 above

sydney-runkle, Aug 17 '24 17:08