
Creating instances of `jaxtyped` dataclasses is slow

Open padix-key opened this issue 1 year ago • 5 comments

Annotating a dataclass with @jaxtyped makes creating instances of that class ~1000x slower. This is especially problematic in cases where the entire package is jaxtyped with install_import_hook(), because it is not possible to exclude a frequently used dataclass from being jaxtyped.

Here is a small benchmark:

from dataclasses import dataclass
import time
from jaxtyping import jaxtyped
from beartype import beartype


N = 1000


class VanillaClass:

    def __init__(self, foo: str):
        self.foo = foo

@dataclass
class VanillaDataclass:

    foo: str

@jaxtyped(typechecker=beartype)
class JaxtypedClass:

    def __init__(self, foo: str):
        self.foo = foo

@jaxtyped(typechecker=beartype)
@dataclass
class JaxtypedDataclass:

    foo: str

@beartype
class BeartypeClass:

    def __init__(self, foo: str):
        self.foo = foo

@beartype
@dataclass
class BeartypeDataclass:

    foo: str


for c in [
    VanillaClass,
    VanillaDataclass,
    JaxtypedClass,
    JaxtypedDataclass,
    BeartypeClass,
    BeartypeDataclass
]:
    now = time.time_ns()
    for _ in range(N):
        c("foo")
    run_time = (time.time_ns() - now) / N
    print(f"{c.__name__:>25}: {run_time} ns")

Output:

             VanillaClass: 98.0 ns
         VanillaDataclass: 128.0 ns
            JaxtypedClass: 125.0 ns
        JaxtypedDataclass: 282535.0 ns
            BeartypeClass: 223.0 ns
        BeartypeDataclass: 243.0 ns

padix-key, Jul 17 '24 10:07

I do not know if this helps, but quick profiling revealed that most of the time is spent in the _check_dataclass_annotations() method.

padix-key, Jul 17 '24 10:07

Hmm, is this not just the overhead from doing the actual type checking itself?

For what it's worth I don't think we currently respect @typing.no_type_check for dataclasses (only functions), but I'd be happy to take a PR updating that!

patrick-kidger, Jul 18 '24 19:07
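
As a rough illustration of the opt-out idea mentioned above, here is a minimal stdlib-only sketch of how a class decorator could detect @typing.no_type_check on a dataclass. This is not jaxtyping's implementation: `is_opted_out` and `maybe_check` are hypothetical names, and the `isinstance` check is only a stand-in for a real type checker. It relies on `typing.no_type_check` setting `__no_type_check__ = True` on the decorated class.

```python
import dataclasses
import functools
import typing

_MISSING = object()  # sentinel for fields that were never initialised


def is_opted_out(cls: type) -> bool:
    """Return True if `cls` itself was decorated with @typing.no_type_check."""
    # Look only at the class's own namespace, so that subclasses do not
    # silently inherit the opt-out (a design choice made for this sketch).
    return bool(vars(cls).get("__no_type_check__", False))


def maybe_check(cls: type) -> type:
    """Hypothetical decorator: check fields after __init__ unless opted out."""
    if is_opted_out(cls):
        return cls  # leave the class untouched, so there is no overhead

    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        for field in dataclasses.fields(self):
            value = getattr(self, field.name, _MISSING)
            # Stand-in check: only plain classes such as `str` are handled.
            if value is _MISSING or not isinstance(field.type, type):
                continue
            if not isinstance(value, field.type):
                raise TypeError(
                    f"{cls.__name__}.{field.name} expected "
                    f"{field.type.__name__}, got {type(value).__name__}"
                )

    cls.__init__ = __init__
    return cls


@maybe_check
@dataclasses.dataclass
class Checked:
    foo: str


@maybe_check
@typing.no_type_check  # must sit below maybe_check so it is applied first
@dataclasses.dataclass
class Unchecked:
    foo: str
```

In this sketch `Checked(1)` raises a `TypeError`, while `Unchecked(1)` is constructed without any check because the decorator returns the class unchanged.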

I am not very familiar with the beartype and jaxtyping internals, so take everything I say with a grain of salt, but I find it suspicious that beartype is ~1000x faster than jaxtyped here. As far as I can tell, both do the same work in this case, since no jaxtyping-specific array syntax is used.

padix-key, Jul 22 '24 11:07

Did a bit of line profiling on the above example, see below.

So while there's ~20% of time that could probably be saved (e.g. by caching results of the operations except the jaxtyped() call itself), it's mostly the type checking call, unfortunately.

Edit: hold on, isn't the actual type checking call (the last line) just 3.6% of the time? The jaxtyped() call that wraps f accounts for 81.5%.

In that case, can't we simply cache the wrapped f, either on the dataclass itself or somewhere else?

Timer unit: 1e-06 s

Total time: 0.201161 s
Function: _check_dataclass_annotations at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def _check_dataclass_annotations(self, typechecker):
     2                                               """Creates and calls a function that checks the attributes of `self`
     3                                           
     4                                               `self` should be a dataclass instance. `typechecker` should be e.g.
     5                                               `beartype.beartype` or `typeguard.typechecked`.
     6                                               """
     7      1000       1748.0      1.7      0.9      parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
     8      1000        113.0      0.1      0.1      values = {}
     9      2000       1397.0      0.7      0.7      for field in dataclasses.fields(self):
    10      1000        124.0      0.1      0.1          annotation = field.type
    11      1000        161.0      0.2      0.1          if isinstance(annotation, str):
    12                                                       # Don't check stringified annotations. These are basically impossible to
    13                                                       # resolve correctly, so just skip them.
    14                                                       continue
    15      1000        742.0      0.7      0.4          if get_origin(annotation) is type:
    16                                                       args = get_args(annotation)
    17                                                       if len(args) == 1 and isinstance(args[0], str):
    18                                                           # We also special-case this one kind of partially-stringified type
    19                                                           # annotation, so as to support Equinox <v0.11.1.
    20                                                           # This was fixed in Equinox in
    21                                                           # https://github.com/patrick-kidger/equinox/pull/543
    22                                                           continue
    23      1000         75.0      0.1      0.0          try:
    24      1000        537.0      0.5      0.3              value = getattr(self, field.name)  # noqa: F841
    25      1000        117.0      0.1      0.1          except AttributeError:
    26      1000        109.0      0.1      0.1              continue  # allow uninitialised fields, which are allowed on dataclasses
    27                                           
    28                                                   parameters.append(
    29                                                       inspect.Parameter(
    30                                                           field.name,
    31                                                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
    32                                                           annotation=field.type,
    33                                                       )
    34                                                   )
    35                                                   values[field.name] = value
    36                                           
    37      1000       1853.0      1.9      0.9      signature = inspect.Signature(parameters)
    38      2000      22255.0     11.1     11.1      f = _make_fn_with_signature(
    39      1000        152.0      0.2      0.1          self.__class__.__name__,
    40      1000        137.0      0.1      0.1          self.__class__.__qualname__,
    41      1000        137.0      0.1      0.1          self.__class__.__module__,
    42      1000         73.0      0.1      0.0          signature,
    43      1000         70.0      0.1      0.0          output=False,
    44                                               )
    45      1000     164024.0    164.0     81.5      f = jaxtyped(f, typechecker=typechecker)
    46      1000       7337.0      7.3      3.6      f(self, **values)

aldanor, Aug 24 '24 00:08

Oh interesting! Thank you for profiling this -- I agree, caching sounds reasonable.

I'd be happy to take a PR on this!

patrick-kidger, Aug 25 '24 12:08
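
The caching idea discussed above could look roughly like the following stdlib-only sketch. It is not the jaxtyping code: `_build_checker`, `check_dataclass_annotations`, `_CHECKER_CACHE` and `BUILD_LOG` are hypothetical names, and the `isinstance` loop stands in for the expensive step of building and wrapping the checking function. One caveat it ignores: the profiled function skips uninitialised fields when building the signature, so a real cache would probably have to key on the set of initialised fields as well as the class.

```python
import dataclasses
import weakref

# One checker per dataclass type. Weak keys mean the cache does not by
# itself keep a class alive after it is no longer used elsewhere.
_CHECKER_CACHE = weakref.WeakKeyDictionary()
BUILD_LOG = []  # records each run of the (stand-in) expensive build step


def _build_checker(cls):
    """Stand-in for the expensive work: build a checker for `cls` once."""
    BUILD_LOG.append(cls.__name__)
    expected = {
        field.name: field.type
        for field in dataclasses.fields(cls)
        if isinstance(field.type, type)  # skip stringified/generic annotations
    }

    def check(instance):
        for name, expected_type in expected.items():
            try:
                value = getattr(instance, name)
            except AttributeError:
                continue  # uninitialised fields are allowed on dataclasses
            if not isinstance(value, expected_type):
                raise TypeError(
                    f"{cls.__name__}.{name} expected "
                    f"{expected_type.__name__}, got {type(value).__name__}"
                )

    return check


def check_dataclass_annotations(instance):
    """Look up the cached checker for the instance's class, building it once."""
    cls = type(instance)
    checker = _CHECKER_CACHE.get(cls)
    if checker is None:
        checker = _build_checker(cls)
        _CHECKER_CACHE[cls] = checker
    checker(instance)


@dataclasses.dataclass
class Point:
    x: int
    y: int

    def __post_init__(self):
        check_dataclass_annotations(self)


for i in range(1000):
    Point(i, i)
```

After the loop, `BUILD_LOG` holds a single entry: the build step ran only for the first instance, and every later construction paid just for the cheap per-instance check.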