Array has been deleted

Open fding opened this issue 6 months ago • 4 comments

Hi, we are trying out the Orbax (0.4.1) AsyncCheckpointer (used through CheckpointManager) and are getting "Array has been deleted" errors. It seems as if the async checkpointer is trying to copy a jax.Array from device to host memory, but that array is no longer available. The Orbax documentation says that "From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host", but I wonder if there are any gotchas in how we should use Orbax checkpointing.
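For context, here is a minimal sketch of the pattern the quoted documentation describes ("blocking copy from device to host, then write in the background"), written with plain JAX, NumPy and `threading`. It is not Orbax's implementation or API: `async_save_sketch` is a hypothetical helper, the "write" is just `tobytes()`, and calling `.delete()` is an assumed stand-in for whatever frees the device buffer in a real training loop. What it shows is that once every leaf is copied to the host before the call returns, deleting the device array afterwards cannot break the background write.

```python
import threading

import jax
import jax.numpy as jnp
import numpy as np


def async_save_sketch(tree, results):
    """Hypothetical helper (not an Orbax API): blocking copy, then background write."""
    # Blocking phase: np.array copies each leaf into a host buffer that
    # no longer depends on the device array.
    host_tree = jax.tree_util.tree_map(np.array, tree)

    # Asynchronous phase: the worker thread only ever sees the host copies.
    def _write():
        # Stand-in for real file I/O: serialize each host array to raw bytes.
        results["bytes"] = {k: v.tobytes() for k, v in host_tree.items()}

    thread = threading.Thread(target=_write)
    thread.start()
    return thread


params = {"bias": jnp.ones((256,), dtype=jnp.float32)}
results = {}
worker = async_save_sketch(params, results)

# Simulate the training loop freeing the device buffer right after "save" returns
# (for example a donated buffer; that cause is an assumption in this sketch).
params["bias"].delete()

worker.join()
print(len(results["bytes"]["bias"]))  # 1024 (256 float32 values x 4 bytes)
```

The trade-off in this sketch is a synchronous device-to-host transfer and a full host copy on every save, in exchange for the training loop being free to delete or reuse device buffers as soon as the call returns. If a leaf were instead handed to the worker thread as a live `jax.Array` and converted there, the same `.delete()` call would make that conversion fail.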

Here is the stack trace:

Exception in thread Thread-314 (_finalize):
Traceback (most recent call last):
File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
     self.run()
File "/usr/lib/python3.11/threading.py", line 982, in run
     self._target(*self._args, **self._kwargs)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 956, in _finalize
     self.wait_until_finished(join_finalize_thread=False)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 888, in wait_until_finished
     checkpointer.wait_until_finished()  # pytype: disable=attribute-error
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 262, in wait_until_finished
     self._async_manager.wait_until_finished()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 154, in wait_until_finished
     self.check_for_errors()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 145, in check_for_errors
     raise exception  # pylint: disable=raising-bad-type
     ^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 97, in _thread_func
     future.result()
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 456, in result
     return self.__get_result()
            ^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
     raise self._exception
File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
     result = self.fn(*self.args, **self.kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/aggregate_handlers.py", line 75, in _serialize_fn
     msgpack = msgpack_utils.msgpack_serialize(serializable_dict)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 216, in msgpack_serialize
     return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/msgpack/__init__.py", line 36, in packb
     return Packer(**kwargs).pack(o)
File "msgpack/_packer.pyx", line 285, in msgpack._cmsgpack.Packer._pack
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 78, in _msgpack_ext_pack
     return msgpack.ExtType(_MsgpackExtType.NDARRAY, _ndarray_to_bytes(x))
                                                     ^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 40, in _ndarray_to_bytes
     arr = np.array(arr)
           ^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 377, in __array__
     return np.asarray(self._value, dtype=dtype)
                       ^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
     return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 562, in _value
     self._check_if_deleted()
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 530, in _check_if_deleted
     raise RuntimeError(
RuntimeError: Array has been deleted with shape=float32[256].
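The last frames are the telling ones: `_serialize_fn` in `aggregate_handlers.py` runs on a thread-pool worker (see the `concurrent/futures/thread.py` frame) and calls `np.array(arr)` on a `jax.Array` there, so for this leaf the device-to-host conversion appears to happen in the asynchronous phase rather than in the blocking copy. The error itself can be reproduced in isolation with a few lines of JAX. In this sketch, `x.delete()` is an assumed stand-in for whatever released the buffer in the real program; buffer donation in a jitted train step (`donate_argnums`) is one common way that happens, but the traceback alone does not confirm it.

```python
import jax.numpy as jnp
import numpy as np

# Same shape and dtype as the array in the traceback.
x = jnp.zeros((256,), dtype=jnp.float32)

# Assumed stand-in for the real cause (e.g. buffer donation); frees the device buffer.
x.delete()

message = None
try:
    # The same conversion msgpack_utils._ndarray_to_bytes performs in the traceback.
    np.array(x)
except RuntimeError as err:
    message = str(err)

print(message)
```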

fding, Dec 13 '23 19:12