keras
ValueError: Received incompatible devices for jitted computation. (demo_jax_distributed.py)
I get the following error when running examples/demo_jax_distributed.py on a Cloud TPU VM (TPU v2):
Traceback (most recent call last):
  File "/home/colby/jax_test.py", line 341, in <module>
    loss, accuracy = model.evaluate(eval_data)
  File "/home/colby/.local/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/colby/.local/lib/python3.10/site-packages/keras_core/src/backend/jax/trainer.py", line 597, in evaluate
    logs, state = self.test_function(state, data)
ValueError: Received incompatible devices for jitted computation. Got argument state[0][0] of JAXTrainer.make_test_function.<locals>.compiled_test_step with shape float32[3,3,1,12] and device ids [0, 1, 2, 3, 6, 7, 4, 5] on platform TPU and sharding_constraint inside jit with device ids [0] on platform TPU at /home/colby/.local/lib/python3.10/site-packages/keras_core/src/backend/jax/trainer.py:910 (_enforce_jax_state_sharding)
Package versions:
- Keras-core: 0.1.7
- jax[tpu]: 0.4.19
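The error says that jit received an argument (state[0][0], a float32[3,3,1,12] variable) placed on eight TPU devices, while a sharding constraint inside the same compiled function was built for device 0 only. A single jitted computation needs one consistent device assignment. Below is a minimal JAX sketch of the consistent case, assuming a 1-D mesh over all visible devices and fully replicated variables. It is not the Keras trainer code, and `mesh`, `replicated`, `kernel`, and `step` are hypothetical names: the point is only that the constraint reuses the argument's mesh, so both sides agree on the devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical setup: a 1-D mesh over every visible device
# (8 on a TPU v2-8 host, 1 on a plain CPU machine).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Replicate a tensor shaped like the kernel in the traceback across the mesh.
replicated = NamedSharding(mesh, P())
kernel = jax.device_put(jnp.ones((3, 3, 1, 12), jnp.float32), replicated)


@jax.jit
def step(x):
    # The constraint uses the same mesh as the argument, so jit sees one
    # consistent device set. A constraint tied to a different device set
    # (for example only device 0) is the kind of mismatch the error reports.
    y = x * 2.0
    return jax.lax.with_sharding_constraint(y, replicated)


out = step(kernel)
```

On a CPU-only machine jax.devices() returns a single device, so the sketch still runs; setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax is one way to try it with eight host devices.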
Thanks for reporting this. Let me take a look.
Ping