DALI
Cryptic error message from jax integration
Version
1.40.0
Describe the bug.
When obtaining a batch from the data iterator, I get the error below. I am not sure why it happens or how to debug it myself.
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 189, in __next__
    return self._next_impl()
           ^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 170, in _next_impl
    category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 196, in _gather_outputs_for_category
    _to_jax_array(pipelines_outputs[pipeline_id][category_id].as_tensor())
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/integration.py", line 43, in _to_jax_array
    return jax_array.copy()
           ^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2689, in copy
    return array(a, copy=True, order=order)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2595, in array
    out = _array_copy(object) if copy else object
          ^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4650, in _array_copy
    return copy_p.bind(arr)
           ^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4691, in _copy_impl
    return dispatch.apply_primitive(prim, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Buffer passed to Execute() as argument 0 to replica 0 is on device cuda:1, but replica is assigned to device cuda:0.
Minimum reproducible example
No response
Relevant log output
No response
Other/Misc.
No response
Check for duplicates
- [X] I have searched the open bugs/issues and have found no duplicates for this bug report
If I reduce the batch size, the problem does not seem to occur. Sorry, I can't provide a minimal reproducible snippet.
Hello @quanvuong, thanks for reporting the issue. Could you tell us more about your setup? It looks like you are using multiple GPUs here. What are the other parameters, and what are the batch sizes you mentioned?
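One way to collect those setup details is a short script like the one below. It is only a sketch: `describe_setup` is a hypothetical helper name, not part of DALI or JAX, and it reports only what the environment and the installed packages expose (`CUDA_VISIBLE_DEVICES`, `jax.devices()`, `jax.default_backend()`, and the package versions). It does not reproduce or fix the error.

```python
import os


def describe_setup():
    """Collect basic device and version info to paste into a bug report.

    Hypothetical helper for illustration; not part of DALI or JAX.
    """
    info = {
        # Which GPUs the process is allowed to see, if restricted at all.
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>"),
    }

    try:
        import jax
    except ImportError:
        info["jax"] = "not installed"
    else:
        info["jax"] = jax.__version__
        info["jax_backend"] = jax.default_backend()
        # All devices JAX can use, e.g. one entry per visible GPU.
        info["jax_devices"] = [str(d) for d in jax.devices()]

    try:
        import nvidia.dali as dali
    except ImportError:
        info["dali"] = "not installed"
    else:
        info["dali"] = getattr(dali, "__version__", "<unknown>")

    return info


if __name__ == "__main__":
    for key, value in describe_setup().items():
        print(f"{key}: {value}")
```

Posting this output together with the batch sizes that do and do not trigger the error, and the arguments passed to the DALI iterator, should make it easier to see why a buffer ends up on `cuda:1` while the computation is assigned to `cuda:0`.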