Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
I have to remove the type hints from type-checked functions that are called via joblib.Parallel or other multiprocessing pipelines; otherwise I get tracebacks like this:
joblib.externals.loky.process_executor._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
return hash(cls._get_props())
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
main(cfg)
File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
trainer.evaluate(
File "/home/ray/trainer.py", line 427, in evaluate
predictions, objective = model_and_objective(batch)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
predictions = Parallel(n_jobs=self.n_jobs)(
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
return output if self.return_generator else list(output)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
yield from self._retrieve()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
self._raise_error_fast()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
return self._return_or_raise()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
joblib.externals.loky.process_executor._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
return hash(cls._get_props())
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
main(cfg)
File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
trainer.evaluate(
File "/home/ray/trainer.py", line 427, in evaluate
predictions, objective = model_and_objective(batch)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
predictions = Parallel(n_jobs=self.n_jobs)(
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
return output if self.return_generator else list(output)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
yield from self._retrieve()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
self._raise_error_fast()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
return self._return_or_raise()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.
If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)
I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.
  File "/root/optimizer.py", line 89, in run_train_on_modal
    optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f))
                                                                             ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
  File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__
    self.data[ref(key, self._remove)] = value
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__
    return hash(cls._get_props())
                ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props
    cls.index_variadic,
    ^^^^^^^^^^^^^^^^^^
AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'
Do you have a MWE?
Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.
lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
checkpoint_params = { "optimizer_state": optimizer_state }
with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)
Looks like they're not getting de/serialised correctly, so the
index_variadic attribute doesn't make it across. If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement
__setstate__ and __getstate__?)
I am also encountering this issue, but only with ray (just using cloudpickle on its own seems to work now). This MWE reproduces the issue:
pip install jaxtyping jax 'ray[default]'
import jax
import ray
from jax import numpy as jnp
from jaxtyping import Int
ray.init()
@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
return x * 2
a = ray.put(jnp.arange(10))
ray.get(f.remote(a))
I tried implementing __setstate__ and __getstate__ as follows:
# jaxtyping/_array_types.py
@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
# ...
def __getstate__(cls):
return cls._get_props()
def __setstate__(cls, props):
(
cls.index_variadic,
cls.dims,
cls.array_type,
cls.dtypes,
cls.dim_str,
) = props
# ...
But as best I can tell, neither one gets called at all.
It looks like ray.cloudpickle.cloudpickle internally synthesises a class via types.new_class:
https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L536
and then immediately tries to hash it:
https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L124
which fails, as this class does not yet have our attributes set.
ray's approach seems a bit dodgy due to exactly the kind of failure we're seeing here! Anyway, I've worked around this in #237 by just always hashing to zero.
Thank you for the MWE, that was invaluable to figure this one out! :)
You guys, gals, and nonbinary pals rock!!
I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and index_variadic. This error only happens when I set worker_count > 0:
ERROR:absl:Error occurred in child process with worker_index: 7
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 176, in _worker_loop
element_producer = _get_element_producer_from_queue(
File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 148, in _get_element_producer_from_queue
element_producer_fn: GetElementProducerFn[Any] = cloudpickle.loads(
File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 539, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 124, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/usr/local/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 339, in __hash__
return hash(cls._get_props())
File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 324, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Array, 'N C H W']' has no attribute 'index_variadic'
The above exception was the direct cause of the following exception:
Ah, this has already been fixed and I just haven't done a new release for it yet.
I've done a version bump + new release in https://github.com/patrick-kidger/jaxtyping/pull/246
Hey! I am seeing this issue again. Was the fix submitted?
It was. FWIW, the latest release of jaxtyping added support for pickle, so it's possible that something else broke here.
The above MWE for ray still passes on the latest release, though. If you're still experiencing an issue, can you verify that you're on version 0.2.38 and, if so, create a MWE?
I am running v0.2.33, so it looks like I am one behind...