jax
Cannot interpret 'key<fry>' as a data type
Description
2024-05-21 07:41:32,181 ERROR worker.py:405 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.run_executable
  File "/root/cy/temp/geesibling/python/geesibling/adapters/jax/pipeline/devicecontext.py", line 365, in run_executable
    self.do_recv(instruction.micro_batch_id, ...
  File "/root/cy/temp/geesibling/python/geesibling/adapters/jax/pipeline/devicecontext.py", line 392, in do_recv
    recv_buffer = cupy.zeros(var.aval.shape, dtype=var.aval.dtype)
  File "/root/miniconda3/envs/framework-cy/lib/python3.9/site-packages/cupy/_creation/basic.py", line 248, in zeros
    a = cupy.ndarray(shape, dtype, order=order)
  File "cupy/_core/core.pyx", line 132, in cupy._core.core.ndarray.__new__
  File "cupy/_core/core.pyx", line 204, in cupy._core.core._ndarray_base._init
  File "cupy/_core/_dtype.pyx", line 61, in cupy._core._dtype.get_dtype_with_itemsize
TypeError: Cannot interpret 'key<fry>' as a data type
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.7
jaxlib: 0.4.7
numpy: 1.23.0
python: 3.9.13 (main, Oct 13 2022, 21:15:33) [GCC 11.2.0]
jax.devices (8 total, 8 local): [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0) StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0) ... StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0) StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0)]
process_count: 1
Can you provide a way for us to reproduce the error you're seeing?
The stack trace suggests the error is coming from cupy, so my guess would be that cupy doesn't accept a custom JAX dtype, but having a reproducer would still help us diagnose this.
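A minimal reproducer might look like the sketch below. It assumes a recent JAX release that provides `jax.random.key` (the 0.4.7 listed above may not), and it uses NumPy in place of cupy so it runs without a GPU. The error text in the traceback matches NumPy's own "Cannot interpret ... as a data type" message, which suggests cupy hands the dtype to NumPy, but that is an assumption. `try_alloc` is a hypothetical helper.

```python
import jax
import numpy as np


def try_alloc(shape, dtype):
    """Allocate a zero buffer the way do_recv does; return the error text on failure."""
    try:
        np.zeros(shape, dtype=dtype)
    except TypeError as err:
        return str(err)
    return None


key = jax.random.key(0)          # typed PRNG key with an extended (non-NumPy) dtype
raw = jax.random.key_data(key)   # the same key as a plain uint32 array

print(key.dtype)
print(try_alloc(key.shape, key.dtype))  # NumPy rejects the key dtype
print(try_alloc(raw.shape, raw.dtype))  # the raw uint32 data allocates fine
```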
import cupy
import jax
import numpy as np
import ray.util.collective as col  # assumed: `col` is Ray's collective communication module

def do_recv(self, micro_batch_id, input_vars, src_rank, group_name='default'):
    src_gpu_idx = 0
    for var in input_vars:
        with cupy.cuda.Device(0):
            # bool values are transferred as int32 and cast back after receiving
            if var.aval.dtype == np.bool_:
                recv_buffer = cupy.zeros(var.aval.shape, dtype=np.int32)
            else:
                recv_buffer = cupy.zeros(var.aval.shape, dtype=var.aval.dtype)
            col.recv_multigpu(recv_buffer, src_rank, src_gpu_idx, group_name)
            cupy.cuda.Device(0).synchronize()
        recv_buffer = recv_buffer.get()  # copy from GPU to a host NumPy array
        if var.aval.dtype == np.bool_:
            recv_buffer = recv_buffer.astype(np.bool_)
        val = jax.device_put(recv_buffer)
        if var in self.buffers[-1]:
            self.buffers[-1][var] = val
        else:
            self.buffers[micro_batch_id][var] = val
When the incoming variable has the dtype key<fry>, how should I receive that data here, and how should I process it?
It looks like you need something like this at the beginning of your function
if jax.dtypes.issubdtype(var.dtype, jax.dtypes.prng_key):
    impl = jax.random.key_impl(var)  # read the impl while var is still a typed key
    var = jax.random.key_data(var)   # raw uint32 data that cupy can allocate and send
And if you need to convert back to a typed key, use var = jax.random.wrap_key_data(var, impl=impl).
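Putting these pieces together on the receiving side might look roughly like the sketch below. It rests on several assumptions: it uses NumPy instead of cupy so it runs without a GPU; `recv_buffer_spec` and `receive` are hypothetical helpers; the `recv_into` callback stands in for `col.recv_multigpu`; the sender is assumed to send `jax.random.key_data(...)` rather than the typed keys; and both sides are assumed to agree on the key implementation. The buffer shape is derived with `jax.eval_shape`, so the receiver only needs the variable's shape and dtype (for example `var.aval.shape` and `var.aval.dtype`).

```python
import jax
import numpy as np


def is_key_dtype(dtype):
    # True for typed PRNG key dtypes such as key<fry>.
    return jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key)


def recv_buffer_spec(shape, dtype):
    """Shape and dtype of a plain buffer that can carry an array of (shape, dtype).

    Typed keys cannot be allocated by NumPy or cupy, so for them the buffer is
    sized for the raw uint32 data that jax.random.key_data would return.
    """
    if is_key_dtype(dtype):
        raw = jax.eval_shape(jax.random.key_data, jax.ShapeDtypeStruct(shape, dtype))
        return raw.shape, raw.dtype
    return shape, dtype


def receive(shape, dtype, recv_into, impl=None):
    """Allocate a buffer, fill it via the hypothetical `recv_into`, rebuild the value.

    `impl` must match the sender's key implementation (None means JAX's default).
    """
    buf_shape, buf_dtype = recv_buffer_spec(shape, dtype)
    buf = np.zeros(buf_shape, dtype=buf_dtype)
    recv_into(buf)
    if is_key_dtype(dtype):
        return jax.random.wrap_key_data(jax.device_put(buf), impl=impl)
    return jax.device_put(buf)


# Sender side: unwrap typed keys into raw data before sending.
keys = jax.random.split(jax.random.key(0), 3)
impl = jax.random.key_impl(keys)
sent = np.asarray(jax.random.key_data(keys))

# Receiver side: copying into the buffer simulates the transport.
received = receive(keys.shape, keys.dtype, lambda buf: np.copyto(buf, sent), impl=impl)
```

This mirrors the existing `np.bool_` special case in `do_recv`: allocate the buffer in a dtype the transport understands, then convert back once the data has arrived.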