Creating instances of `jaxtyped` dataclasses is slow
Decorating a dataclass with `@jaxtyped` makes creating instances of that class roughly 1000x slower.
This is especially problematic when an entire package is checked via `install_import_hook()`, because there is then no way to exclude a frequently used dataclass from being jaxtyped.
Here is a small benchmark:
from dataclasses import dataclass
import time

from beartype import beartype
from jaxtyping import jaxtyped

N = 1000  # number of instances to create per class


class VanillaClass:
    def __init__(self, foo: str):
        self.foo = foo


@dataclass
class VanillaDataclass:
    foo: str


@jaxtyped(typechecker=beartype)
class JaxtypedClass:
    def __init__(self, foo: str):
        self.foo = foo


@jaxtyped(typechecker=beartype)
@dataclass
class JaxtypedDataclass:
    foo: str


@beartype
class BeartypeClass:
    def __init__(self, foo: str):
        self.foo = foo


@beartype
@dataclass
class BeartypeDataclass:
    foo: str


for c in [
    VanillaClass,
    VanillaDataclass,
    JaxtypedClass,
    JaxtypedDataclass,
    BeartypeClass,
    BeartypeDataclass,
]:
    now = time.time_ns()
    for _ in range(N):
        c("foo")
    run_time = (time.time_ns() - now) / N  # mean time per instantiation
    print(f"{c.__name__:>25}: {run_time} ns")
Output:
VanillaClass: 98.0 ns
VanillaDataclass: 128.0 ns
JaxtypedClass: 125.0 ns
JaxtypedDataclass: 282535.0 ns
BeartypeClass: 223.0 ns
BeartypeDataclass: 243.0 ns
I don't know whether this helps, but some quick profiling showed that most of the time is spent in `_check_dataclass_annotations()`.
Hmm, is this not just the overhead from doing the actual type checking itself?
For what it's worth, I don't think we currently respect `@typing.no_type_check` on dataclasses (only on functions), but I'd be happy to take a PR adding that!
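A minimal sketch of how a checker could honour `@typing.no_type_check` on a dataclass. It relies on the fact that `typing.no_type_check` sets `__no_type_check__ = True` on the class it decorates; `should_check_dataclass` is a hypothetical helper for illustration, not part of jaxtyping's API.

```python
import dataclasses
import typing


def should_check_dataclass(cls: type) -> bool:
    """Hypothetical helper (not part of jaxtyping): decide whether a dataclass
    should have its fields type-checked.

    `typing.no_type_check` sets `__no_type_check__ = True` on the class it
    decorates, so a checker can look for that flag and skip the class.
    """
    return dataclasses.is_dataclass(cls) and not getattr(cls, "__no_type_check__", False)


@typing.no_type_check
@dataclasses.dataclass
class Skipped:
    foo: str


@dataclasses.dataclass
class Checked:
    foo: str


print(should_check_dataclass(Skipped))  # False
print(should_check_dataclass(Checked))  # True
```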
I'm not very familiar with the beartype and jaxtyping internals, so take everything I say with a grain of salt, but I find it suspicious that beartype is ~1000x faster than jaxtyped here. As far as I can tell, both do the same thing in this case, since no jaxtyping-specific array annotations are used.
I did a bit of line profiling on the above example; see below.
So while ~20% of the time could probably be saved (e.g. by caching the results of every operation except the `jaxtyped()` call itself), most of it is the type-checking call, unfortunately.
Edit: hold on, isn't the actual type-checking call (the last line) only 3.6% of the time? The 81.5% goes to `jaxtyped(f, typechecker=typechecker)`, i.e. to constructing the checked function rather than calling it.
In that case, can't we simply cache this `f`, either on the dataclass itself or somewhere else?
Timer unit: 1e-06 s
Total time: 0.201161 s
Function: _check_dataclass_annotations at line 1
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def _check_dataclass_annotations(self, typechecker):
     2                                               """Creates and calls a function that checks the attributes of `self`
     3
     4                                               `self` should be a dataclass instance. `typechecker` should be e.g.
     5                                               `beartype.beartype` or `typeguard.typechecked`.
     6                                               """
     7      1000       1748.0      1.7      0.9      parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
     8      1000        113.0      0.1      0.1      values = {}
     9      2000       1397.0      0.7      0.7      for field in dataclasses.fields(self):
    10      1000        124.0      0.1      0.1          annotation = field.type
    11      1000        161.0      0.2      0.1          if isinstance(annotation, str):
    12                                                       # Don't check stringified annotations. These are basically impossible to
    13                                                       # resolve correctly, so just skip them.
    14                                                       continue
    15      1000        742.0      0.7      0.4          if get_origin(annotation) is type:
    16                                                       args = get_args(annotation)
    17                                                       if len(args) == 1 and isinstance(args[0], str):
    18                                                           # We also special-case this one kind of partially-stringified type
    19                                                           # annotation, so as to support Equinox <v0.11.1.
    20                                                           # This was fixed in Equinox in
    21                                                           # https://github.com/patrick-kidger/equinox/pull/543
    22                                                           continue
    23      1000         75.0      0.1      0.0          try:
    24      1000        537.0      0.5      0.3              value = getattr(self, field.name)  # noqa: F841
    25      1000        117.0      0.1      0.1          except AttributeError:
    26      1000        109.0      0.1      0.1              continue  # allow uninitialised fields, which are allowed on dataclasses
    27
    28                                                   parameters.append(
    29                                                       inspect.Parameter(
    30                                                           field.name,
    31                                                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
    32                                                           annotation=field.type,
    33                                                       )
    34                                                   )
    35                                                   values[field.name] = value
    36
    37      1000       1853.0      1.9      0.9      signature = inspect.Signature(parameters)
    38      2000      22255.0     11.1     11.1      f = _make_fn_with_signature(
    39      1000        152.0      0.2      0.1          self.__class__.__name__,
    40      1000        137.0      0.1      0.1          self.__class__.__qualname__,
    41      1000        137.0      0.1      0.1          self.__class__.__module__,
    42      1000         73.0      0.1      0.0          signature,
    43      1000         70.0      0.1      0.0          output=False,
    44                                               )
    45      1000     164024.0    164.0     81.5      f = jaxtyped(f, typechecker=typechecker)
    46      1000       7337.0      7.3      3.6      f(self, **values)
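A quick check of the arithmetic behind the edit above, using the per-line times from this profile (microseconds, summed over the 1000 instantiations). It assumes, optimistically, that caching would remove the `_make_fn_with_signature(...)` and `jaxtyped(f, ...)` costs entirely; a real cache lookup is not free, so the actual saving would be somewhat smaller.

```python
# Per-line times from the profile above, in microseconds, summed over 1000 calls.
total_us = 201_161
make_fn_us = 22_255  # f = _make_fn_with_signature(...)
wrap_us = 164_024    # f = jaxtyped(f, typechecker=typechecker)
call_us = 7_337      # f(self, **values): the actual check
calls = 1000

cacheable_us = make_fn_us + wrap_us
print(f"check call share:  {call_us / total_us:.1%}")       # 3.6%
print(f"cacheable share:   {cacheable_us / total_us:.1%}")  # 92.6%
print(f"left per instance: {(total_us - cacheable_us) / calls:.1f} us")  # 14.9 us
```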
Oh interesting! Thank you for profiling this -- I agree, caching sounds reasonable.
I'd be happy to take a PR on this!