_validate_params fails on zero-sized arrays
Hi,
@niketkumar @cpgaffney1,
cc @dionhaefner
The following attempts to serialize a zero-sized array, but it fails validation in _validate_params.
I believe the problem is that _validate_params expects every 'foo/.zarray' metadata entry to have a matching data entry 'foo/0'. However, this code produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but not 'z/0', since the z tensor contains no data.
I'm not sure whether tensorstore is supposed to save a 'z/0' entry in this case, or what the intended behavior should be.
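To make the suspected mismatch concrete, here is a minimal sketch of the kind of check described above. It is not the actual Orbax implementation: find_missing_params is a hypothetical helper, and the flat list of keys is an assumption about what the store contains for this example.

```python
def find_missing_params(keys):
    """Return names that have a '.zarray' entry but no data entry.

    Hypothetical helper for illustration only; it is not Orbax code.
    """
    suffix = "/.zarray"
    keys = set(keys)
    # Names for which array metadata was written.
    with_metadata = {k[: -len(suffix)] for k in keys if k.endswith(suffix)}
    # Names for which at least one data (chunk) entry was written.
    with_data = {
        k.rsplit("/", 1)[0]
        for k in keys
        if "/" in k and not k.endswith(suffix)
    }
    return sorted(with_metadata - with_data)


# Keys observed for the example in this issue: 'z' has metadata but no chunk.
keys = ["a/0", "a/.zarray", "z/.zarray"]
print(find_missing_params(keys))  # ['z']
```

Under this assumption, any array whose shape yields no chunks would always be reported as missing, which matches the "1/2 params are missing" error.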
Any insight would be greatly appreciated!
import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp
target = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
}
orbax_checkpointer = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler()
)
with tempfile.TemporaryDirectory() as ckpt_path:
    overwrite = True
    save_args = jtu.tree_map(lambda _: ocp.SaveArgs(), target)
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
(jax_env) henry@henry-gs65:orbax$ python flax4309.py
Traceback (most recent call last):
  File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
    self._handler_impl.finalize(directory)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
    asyncio_utils.run_sync(
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
'base': {
'driver': 'file',
'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
},
'cache_pool': 'cache_pool#ocdbt',
'config': {
'compression': {'id': 'zstd'},
'max_decoded_node_bytes': 100000000,
'max_inline_value_bytes': 1024,
'uuid': '3ef941407cca4f778414e9e92b15dedb',
'version_tree_arity_log2': 4,
},
'context': {
'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
'data_copy_concurrency': {},
'file_io_concurrency': {'limit': 128},
'file_io_sync': True,
'ocdbt_coordinator': {},
},
'driver': 'ocdbt',
'experimental_read_coalescing_interval': '1ms',
'experimental_read_coalescing_merged_bytes': 500000000000,
'experimental_read_coalescing_threshold_bytes': 1000000,
}).
Thanks for spotting this. Handling of 0-sized arrays is not well defined, and we have no tests (internal or external) for it. We will clarify the intended behavior, add tests, resolve the validation issue, and get back to you.
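One possible shape-aware variant of such a check is sketched below, purely as an illustration of the intended behavior question. It assumes the store can be viewed as a dict of keys to values and that each '.zarray' value is zarr v2 JSON metadata with a "shape" field. The names expects_data and find_missing_params are hypothetical, and this is not necessarily how Orbax resolves the issue.

```python
import json
import math


def expects_data(zarray_json):
    """True if an array with this zarr v2 metadata should have chunk data."""
    shape = json.loads(zarray_json)["shape"]
    # Any zero-length dimension means the array holds no elements.
    return math.prod(shape) > 0


def find_missing_params(store):
    """Return names whose metadata implies data, but no data entry exists.

    Hypothetical helper for illustration only; `store` maps keys to values.
    """
    suffix = "/.zarray"
    missing = []
    for key, value in store.items():
        if not key.endswith(suffix):
            continue
        name = key[: -len(suffix)]
        has_data = any(k.startswith(name + "/") and k != key for k in store)
        if expects_data(value) and not has_data:
            missing.append(name)
    return sorted(missing)


store = {
    "a/.zarray": json.dumps({"shape": [3], "chunks": [3], "dtype": "<i4"}),
    "a/0": b"\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00",
    "z/.zarray": json.dumps({"shape": [0], "chunks": [1], "dtype": "<f4"}),
}
print(find_missing_params(store))  # []
```

With this rule, the zero-sized 'z' entry passes validation on metadata alone, while a non-empty array with no chunk entry is still reported.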
@cpgaffney1 Still open, checkpoint saving still crashing. Any updates? :)
Hi, addressing this in https://github.com/google/orbax/pull/1570.