trax

JAX backend error when using a GPU in a Kaggle Kernel

Open SauravMaheshkar opened this issue 4 years ago • 0 comments

Description

Whenever I try to use a GPU in a Kaggle Kernel, importing Trax fails with the following error:

`TypeError: pmap() got an unexpected keyword argument 'donate_argnums'`

Environment information

Kaggle Kernels with GPU as an accelerator
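Because `import trax` is the step that fails, version information has to be collected without importing Trax. Below is a minimal sketch, assuming Python 3.7 or later, that reads the installed distribution metadata for `trax`, `jax` and `jaxlib`. The helper names `installed_version` and `environment_report` are hypothetical and not part of any library.

```python
import platform


def installed_version(dist_name):
    """Return the installed version of a distribution without importing it, or None."""
    try:
        # Python 3.8+ ships importlib.metadata in the standard library.
        from importlib.metadata import version, PackageNotFoundError
    except ImportError:
        # Python 3.7 fallback via setuptools' pkg_resources.
        try:
            import pkg_resources
        except ImportError:
            return None
        try:
            return pkg_resources.get_distribution(dist_name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def environment_report():
    """Collect the interpreter version plus the versions relevant to this issue."""
    report = {"python": platform.python_version()}
    for name in ("trax", "jax", "jaxlib"):
        report[name] = installed_version(name)
    return report


if __name__ == "__main__":
    print(environment_report())
```

Running this in the kernel prints the versions to paste under this heading. A value of `None` means the distribution was not found.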

For bugs: reproduction and error logs

# Steps to reproduce:

1. Create a new Kaggle Kernel
2. Install the Trax library
3. Set the accelerator to a GPU
4. Run `import trax`

# Error logs:

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in <module>
    988 
    989 
--> 990 @functools.partial(fastmath.pmap, axis_name='devices', donate_argnums=(0,))
    991 def _make_weights_and_state_same_across_hosts(weights_and_state):
    992   """Makes train and eval model's weights and state the same across hosts."""

/opt/conda/lib/python3.7/site-packages/trax/fastmath/ops.py in pmap(*args, **kwargs)
    314 def pmap(*args, **kwargs):
    315   """Parallel-map to apply a function on multiple accelerators in parallel."""
--> 316   return backend()['pmap'](*args, **kwargs)
    317 
    318 

TypeError: pmap() got an unexpected keyword argument 'donate_argnums'
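The traceback shows that Trax's `fastmath.pmap` forwards its arguments unchanged to `backend()['pmap']`. For the JAX backend that is presumably `jax.pmap`. The sketch below is one way to check whether the installed JAX rejects `donate_argnums` on its own, independent of Trax. It applies `jax.pmap` with the same keyword arguments as the decorator at line 990 of `training.py`. It assumes only that `jax` may or may not be importable. `check_pmap_donation` and `identity` are hypothetical names.

```python
import functools


def check_pmap_donation():
    """Wrap a function with jax.pmap the way trax/supervised/training.py does.

    Returns "ok" if the installed JAX accepts donate_argnums, the TypeError
    message if it does not, or "jax not installed" if JAX cannot be imported.
    """
    try:
        import jax
    except ImportError:
        return "jax not installed"

    def identity(weights_and_state):
        # Stand-in for _make_weights_and_state_same_across_hosts.
        return weights_and_state

    try:
        # Same call shape as the decorator in the traceback. The function is only
        # wrapped here, never called, so no device computation takes place.
        functools.partial(jax.pmap, axis_name="devices", donate_argnums=(0,))(identity)
    except TypeError as err:
        return "TypeError: {}".format(err)
    return "ok"


if __name__ == "__main__":
    print(check_pmap_donation())
```

A `TypeError` result here would suggest that the JAX version preinstalled in the kernel predates `donate_argnums` support in `pmap`, rather than pointing to a bug in Trax itself.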

SauravMaheshkar, Oct 30 '20 08:10