jaxtyping

cloudpickle + weakref + index_variadic

Open: samuelstevens opened this issue 6 months ago • 4 comments

Cloudpickle (latest, 3.1.1 at time of writing) fails to pickle functions decorated with @jaxtyped because of a weakref that jaxtyping introduced in version 0.2.35.

Here is a minimal working example (MWE), run with uv using inline script metadata:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "beartype",
#     "cloudpickle==3.1.1",
#     "jaxtyping==0.2.35",
#     "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()


def main():
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()

Running it with uv run scratch.py gives:

Traceback (most recent call last):
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
    main()
  File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
    dumped = cloudpickle.dumps(typechecked_fn)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
    cp.dump(obj)
  File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object

After bumping jaxtyping to 0.3.2 in the script metadata:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "beartype",
#     "cloudpickle==3.1.1",
#     "jaxtyping==0.3.2",
#     "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()


def main():
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()

I get the same error, with a traceback identical to the one above:

TypeError: cannot pickle 'weakref.ReferenceType' object

This MWE works on jaxtyping 0.2.34, but another script of mine fails on 0.2.34 with the same error as in #198.

samuelstevens, Jun 26 '25 19:06

Interesting. Indeed jaxtyping internally uses a weak reference.

Tagging @ojw28, who introduced this as part of improving our garbage handling in #258.

It's not immediately clear to me how this might be tackled. I'd be happy to take a suggestion or a PR on this if someone has any ideas.

patrick-kidger, Jun 26 '25 22:06
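For context, here is a generic, standard-library-only sketch (not jaxtyping's actual code) of the two properties of weakref.ref at play here: a weak reference does not keep its target alive, which is why it helps with garbage handling, and the reference object itself cannot be pickled, which is presumably what cloudpickle hits when it pickles the decorated function by value and walks into its closure. Target is a hypothetical placeholder for whatever jaxtyping refers to weakly.

```python
import gc
import pickle
import weakref


class Target:
    """Hypothetical placeholder for the object jaxtyping refers to weakly."""


target = Target()
ref = weakref.ref(target)  # does not keep `target` alive on its own

assert ref() is target  # reachable while a strong reference still exists

# Property 1: once the last strong reference is gone, the referent can be freed.
del target
gc.collect()  # CPython frees it immediately via refcounting; collect() for other runtimes
print(ref())  # None

# Property 2: the weakref object itself cannot be pickled, so any pickler that
# reaches it (e.g. while serializing a closure by value) raises TypeError.
error_message = ""
try:
    pickle.dumps(ref)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # on Python 3.12: cannot pickle 'weakref.ReferenceType' object
```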

It seems that cloudpickle is largely a pile of hacks (see this comment), so perhaps it's possible to add one more hack to cloudpickle for jaxtyping. That seems worthwhile, since cloudpickle is used by submitit, a popular tool for submitting jobs to Slurm.

samuelstevens, Jun 27 '25 02:06

@patrick-kidger, do you have recommendations on how to serialize and deserialize jaxtyping-decorated functions? I am looking at adding support for them to cloudpickle.

samuelstevens, Jul 18 '25 01:07

If the goal is to add special-case support for jaxtyping to cloudpickle (?), then you could perhaps detect that a function is decorated, unwrap it (via .__wrapped__), and pickle the original function as normal. Then, during deserialization, simply wrap it back up again.

That's my best guess; it's probably easier than trying to pickle the wrapped function itself.

patrick-kidger, Jul 18 '25 09:07
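A minimal sketch of the unwrap-and-rewrap approach suggested above, using only the standard library's pickle.Pickler.reducer_override hook (Python 3.8+). Everything here is hypothetical: weak_decorator stands in for @jaxtyped (its wrapper's closure holds a weakref), _weak_decorated is a made-up marker attribute, and _rewrap, UnwrappingPickler and dumps_unwrapped are illustrative names. The undecorated function keeps its own module-level name (plain_sum) because stdlib pickle stores functions by reference. cloudpickle stores __main__ functions by value, so a real cloudpickle patch would presumably subclass its pickler and defer to its own reducer_override rather than returning NotImplemented; that variant is untested here.

```python
import functools
import io
import pickle
import weakref


class _Registry:
    """Hypothetical object the decorator refers to weakly (not jaxtyping's internals)."""


_registry = _Registry()


def weak_decorator(fn):
    """Hypothetical stand-in for @jaxtyped: the wrapper's closure holds a weakref."""
    registry_ref = weakref.ref(_registry)

    @functools.wraps(fn)  # also sets wrapper.__wrapped__ = fn
    def wrapper(*args, **kwargs):
        if registry_ref() is None:  # closure keeps the (unpicklable) weakref
            raise RuntimeError("registry was garbage-collected")
        return fn(*args, **kwargs)

    wrapper._weak_decorated = True  # hypothetical marker to detect decorated functions
    return wrapper


def _rewrap(original):
    """Run at load time: re-apply the decorator to the restored original function."""
    return weak_decorator(original)


class UnwrappingPickler(pickle.Pickler):
    def reducer_override(self, obj):
        # For a decorated function, pickle only the undecorated original and
        # record that _rewrap(original) should rebuild the wrapper when loading.
        if callable(obj) and getattr(obj, "_weak_decorated", False):
            return _rewrap, (obj.__wrapped__,)
        return NotImplemented  # everything else uses pickle's default behaviour


def dumps_unwrapped(obj):
    buf = io.BytesIO()
    UnwrappingPickler(buf).dump(obj)
    return buf.getvalue()


def plain_sum(xs):
    return sum(xs)


checked_sum = weak_decorator(plain_sum)

restored = pickle.loads(dumps_unwrapped(checked_sum))
print(restored([1.0, 2.0]))  # 3.0
print(restored is checked_sum)  # False: a freshly re-wrapped function
```

For jaxtyping itself, the rewrap step would also need to know which typechecker was passed to jaxtyped (beartype.beartype in the MWE), so that information would have to be pickled alongside the original function or otherwise recovered at load time.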