cloudpickle + weakref + index_variadic
Cloudpickle (latest, 3.1.1 at time of writing) fails to pickle jaxtyped functions because of a weakref that jaxtyping introduced in version 0.2.35.
I have an MWE using uv with inline script metadata:
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "beartype",
#   "cloudpickle==3.1.1",
#   "jaxtyping==0.2.35",
#   "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()


def main():
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()
When I run this with uv run scratch.py, I get:
Traceback (most recent call last):
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
main()
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
dumped = cloudpickle.dumps(typechecked_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
cp.dump(obj)
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object
When I update jaxtyping to 0.3.2 in the script metadata:
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "beartype",
#   "cloudpickle==3.1.1",
#   "jaxtyping==0.3.2",
#   "numpy",
# ]
# ///
import beartype
import cloudpickle
import numpy as np
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
def typechecked_fn(x: Float[np.ndarray, " d"]) -> float:
    return np.sum(x).item()


def main():
    dumped = cloudpickle.dumps(typechecked_fn)
    print(dumped)
    fn = cloudpickle.loads(dumped)
    print(fn(np.array([1.0, 2.0])))


if __name__ == "__main__":
    main()
I get the same error:
Traceback (most recent call last):
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 29, in <module>
main()
File "/users/PAS1576/samuelstevens/projects/saev/scratch.py", line 22, in main
dumped = cloudpickle.dumps(typechecked_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1537, in dumps
cp.dump(obj)
File "/users/PAS1576/samuelstevens/.cache/uv/environments-v2/scratch-f0adbe6995ced415/lib/python3.12/site-packages/cloudpickle/cloudpickle.py", line 1303, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'weakref.ReferenceType' object
This MWE works on jaxtyping 0.2.34, but my other script fails on 0.2.34 with the same error as #198.
Interesting. Indeed, jaxtyping internally uses a weak reference.
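For context on why a weak reference is fatal here: the standard pickle machinery refuses weakref objects outright, and cloudpickle passes anything it does not special-case down to that machinery (the traceback ends in super().dump). cloudpickle also pickles functions defined in __main__ by value, so anything reachable from the jaxtyped wrapper, presumably including this weakref, gets handed to the pickler. A minimal stdlib-only sketch of the underlying failure (Target and try_pickle are illustrative names, not part of either library):

```python
import pickle
import weakref


class Target:
    """Illustrative class; plain class instances support weak references."""


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as err:
        return str(err)
    return None


target = Target()
print(try_pickle(target))               # None: the object itself pickles fine
print(try_pickle(weakref.ref(target)))  # a "cannot pickle ... weakref ..." TypeError message
```

On Python 3.12 the second message should match the one in the traceback, since both come from the same C-level pickling code.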
Tagging @ojw28 who introduced this as part of improving our garbage-handling in #258.
It's not immediately clear to me how this might be tackled. I'd be happy to take a suggestion or a PR on this if anyone has ideas.
It seems that cloudpickle is just a pile of hacks (see this comment), so perhaps it's possible to add a jaxtyping-specific hack to cloudpickle. That would be worthwhile because cloudpickle is used by submitit, a popular library for submitting jobs to Slurm.
@patrick-kidger do you have recommendations on how to serialize/deserialize jaxtyping decorators? I am looking at adding support to cloudpickle for jaxtyping decorators.
If the goal is to add special-case support for jaxtyping into cloudpickle (?), then you could perhaps detect that a function is decorated, unwrap it (.__wrapped__), and save the original function as normal. Then, during deserialization, simply wrap it back up again.
That's my best guess; it's probably much easier than trying to pickle the wrapped function itself.