
Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'

Open jaanli opened this issue 1 year ago • 4 comments

I need to remove type hints from functions that are type checked and need to be called in joblib.Parallel or other multiprocessing pipelines; getting tracebacks like this:

joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
    call_item = call_queue.get(block=True, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
    return hash(cls._get_props())
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
    main(cfg)
  File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
    trainer.evaluate(
  File "/home/ray/trainer.py", line 427, in evaluate
    predictions, objective = model_and_objective(batch)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
    predictions = Parallel(n_jobs=self.n_jobs)(
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
    return output if self.return_generator else list(output)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
    yield from self._retrieve()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
    self._raise_error_fast()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
    error_job.get_result(self.timeout)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
    return self._return_or_raise()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
    raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.

jaanli avatar Apr 05 '24 14:04 jaanli

Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.

If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)

patrick-kidger avatar Apr 05 '24 23:04 patrick-kidger

I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.

  File "/root/optimizer.py", line 89, in run_train_on_modal
    optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f))
                                                                             ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
  File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__
    self.data[ref(key, self._remove)] = value
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__
    return hash(cls._get_props())
                ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props
    cls.index_variadic,
    ^^^^^^^^^^^^^^^^^^
AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'

sachith-gunasekara avatar May 31 '24 09:05 sachith-gunasekara

Do you have a MWE?

patrick-kidger avatar May 31 '24 10:05 patrick-kidger

Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

checkpoint_params = {
    "optimizer_state": optimizer_state
}

with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)

sachith-gunasekara avatar May 31 '24 19:05 sachith-gunasekara

Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.

If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)

I am also encountering this issue, but only with ray (just using cloudpickle on its own seems to work now). This MWE reproduces the issue:

pip install jaxtyping jax 'ray[default]'
import jax
import ray
from jax import numpy as jnp

from jaxtyping import Int


ray.init()


@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2


a = ray.put(jnp.arange(10))
ray.get(f.remote(a))

I tried implementing __setstate__ and __getstate__ as follows:

# jaxtyping/_array_types.py

@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...

        def __getstate__(cls):
            return cls._get_props()

        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props
        
        # ...

But as best I can tell, neither one gets called at all.

LoganWalls avatar Jul 30 '24 14:07 LoganWalls

It looks like ray.cloudpickle.cloudpickle internally synthesises a class via types.new_class:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L536

and then immediately tries to hash it:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L124

which fails, as this class does not yet have our attributes set.

ray's approach seems a bit dodgy due to exactly the kind of failure we're seeing here! Anyway, I've worked around this in #237 by just always hashing to zero.
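
For readers following along, the failure mode described above can be reproduced without ray or jaxtyping. The following is a minimal sketch, not the actual jaxtyping or cloudpickle code: `PropsMeta`, `ConstantHashMeta` and `track_skeleton` are hypothetical names, and the attribute names are borrowed from the traceback purely for illustration. It assumes only that a bare class is created via `types.new_class` and then used as a key in a `weakref.WeakKeyDictionary`, which hashes it before any attributes have been set.

```python
import types
import weakref


class PropsMeta(type):
    """Hypothetical metaclass: the hash depends on attributes set after creation."""

    def __hash__(cls):
        # Raises AttributeError if the class is hashed before these are assigned.
        return hash((cls.index_variadic, cls.dims))


class ConstantHashMeta(type):
    """Hypothetical workaround: a constant hash never touches class attributes."""

    def __hash__(cls):
        return 0


def track_skeleton(metaclass):
    """Create a bare 'skeleton' class and immediately use it as a weak dict key."""
    tracker = weakref.WeakKeyDictionary()
    skeleton = types.new_class("Skeleton", (), {"metaclass": metaclass})
    tracker[skeleton] = "tracker-id"  # hashing the key happens here
    return skeleton, tracker


try:
    track_skeleton(PropsMeta)
except AttributeError as exc:
    print("hashing the skeleton failed:", exc)

skeleton, tracker = track_skeleton(ConstantHashMeta)
print("tracked:", tracker[skeleton])
```

A constant hash is still valid, since equal objects keep equal hashes; the cost is that every such class lands in the same hash bucket, so lookups fall back to equality checks.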

Thank you for the MWE, that was invaluable to figure this one out! :)

patrick-kidger avatar Aug 01 '24 03:08 patrick-kidger

You guys, gals, and nonbinary pals rock!!

jaanli avatar Aug 01 '24 12:08 jaanli

I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and index_variadic. The error only happens when I set worker_count > 0:

ERROR:absl:Error occurred in child process with worker_index: 7
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 176, in _worker_loop
    element_producer = _get_element_producer_from_queue(
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 148, in _get_element_producer_from_queue
    element_producer_fn: GetElementProducerFn[Any] = cloudpickle.loads(
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 539, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 124, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/usr/local/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 339, in __hash__
    return hash(cls._get_props())
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 324, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Array, 'N C H W']' has no attribute 'index_variadic'

The above exception was the direct cause of the following exception:

danbnyn avatar Sep 01 '24 14:09 danbnyn

Ah, this has already been fixed and I just haven't done a new release for it yet.

I've done a version bump + new release in https://github.com/patrick-kidger/jaxtyping/pull/246

patrick-kidger avatar Sep 01 '24 14:09 patrick-kidger

Hey! I am seeing this issue again. Was the fix submitted?

purpledog avatar Feb 18 '25 18:02 purpledog

It was. FWIW the latest release of jaxtyping added support for pickle; it's possible that something else broke here.

The above MWE for ray still passes on the latest release, though. If you're experiencing an issue then can you verify that you're on version 0.2.38 and if so create a MWE?
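
One way to check which version is installed in the environment that actually runs the job is sketched below, using the standard library's `importlib.metadata`. The helper name `installed_version` is made up for illustration and is not part of jaxtyping.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Run this inside the same interpreter / virtualenv that the workers use.
print("jaxtyping:", installed_version("jaxtyping"))
```

With ray, joblib or Grain workers, it may be worth running the same check inside a worker process too, since workers can end up in a different environment from the driver.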

patrick-kidger avatar Feb 18 '25 18:02 patrick-kidger

I am running v0.2.33, so it looks like I am one behind...

purpledog avatar Feb 18 '25 18:02 purpledog