trax
JAX Backend error on using GPU in Kaggle Kernel
Description
Whenever I try to use a GPU in a Kaggle Kernel, importing trax fails with the following error:
TypeError: pmap() got an unexpected keyword argument 'donate_argnums'
Environment information
Kaggle Kernel with a GPU selected as the accelerator
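To make the environment report concrete, the installed versions of `jax`, `jaxlib` and `trax` can be read from inside the kernel. This is a minimal sketch, assuming Python 3.8+ for `importlib.metadata`; on Python 3.7 (as in the traceback paths) it assumes the third-party `importlib_metadata` backport is installed. The helper name `installed_versions` is hypothetical.

```python
# Sketch: report installed versions of the packages involved in this error.
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: assumes the importlib_metadata backport is installed
    from importlib_metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # package is not installed in this environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["jax", "jaxlib", "trax"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```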
For bugs: reproduction and error logs
# Steps to reproduce:
1. Create a new Kaggle Kernel
2. Install the Trax library
3. Set the accelerator to a GPU
4. Try `import trax`
# Error logs:
```
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in <module>
    988
    989
--> 990 @functools.partial(fastmath.pmap, axis_name='devices', donate_argnums=(0,))
    991 def _make_weights_and_state_same_across_hosts(weights_and_state):
    992   """Makes train and eval model's weights and state the same across hosts."""

/opt/conda/lib/python3.7/site-packages/trax/fastmath/ops.py in pmap(*args, **kwargs)
    314 def pmap(*args, **kwargs):
    315   """Parallel-map to apply a function on multiple accelerators in parallel."""
--> 316   return backend()['pmap'](*args, **kwargs)
    317
    318

TypeError: pmap() got an unexpected keyword argument 'donate_argnums'
```
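The traceback points to a likely version mismatch: trax's `fastmath.pmap` forwards its keyword arguments unchanged to the backend's `pmap` (line 316 of `ops.py`), so the `TypeError` most likely comes from an installed `jax` whose `pmap` has no `donate_argnums` parameter. The sketch below uses only the standard-library `inspect` module to check whether a callable accepts a given keyword. `accepts_kwarg`, `old_style_pmap` and `new_style_pmap` are hypothetical names for illustration; the two stand-ins are not real JAX functions. Note that it should be pointed at `jax.pmap` itself: trax's wrapper takes `**kwargs`, so it would always appear to accept the keyword.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True/False if fn does/doesn't accept `name` as a keyword argument,
    or None if its signature cannot be inspected."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return None  # e.g. some builtins expose no signature
    for p in params:
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # **kwargs accepts any keyword at call time
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


def old_style_pmap(fun, axis_name=None, devices=None, backend=None, axis_size=None):
    """Hypothetical stand-in for a pmap without donate_argnums."""
    return fun


def new_style_pmap(fun, axis_name=None, *, donate_argnums=()):
    """Hypothetical stand-in for a pmap that accepts donate_argnums."""
    return fun


if __name__ == "__main__":
    print(accepts_kwarg(old_style_pmap, "donate_argnums"))  # False
    print(accepts_kwarg(new_style_pmap, "donate_argnums"))  # True
    try:
        import jax  # inspect the real backend function, not trax's **kwargs wrapper
    except ImportError:
        print("jax is not installed")
    else:
        print(jax.__version__, accepts_kwarg(jax.pmap, "donate_argnums"))
```

If the check reports `False` for `jax.pmap`, upgrading `jax` and `jaxlib` to versions compatible with the installed trax release is the likely remedy; the exact versions depend on trax's own requirements.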