Apply gradients with distributed_shampoo fails

Open · c0derzer0 opened this issue 3 years ago · 1 comment

Hello,

I am trying to run the distributed Shampoo optimizer implemented in lingvo. I have a TensorFlow version that is compatible with lingvo:

Python 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> tf.__version__
'2.7.1'
>>>

But when I call opt.apply_gradients(), it raises a TypeError saying that a tf.Variable is unhashable. apply_gradients() works fine with other optimizers.

File "program.py", line 177, in train_step
  optimizer.apply_gradients(
File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/tensorflow/python/training/optimizer.py", line 605, in apply_gradients
  self._create_slots(var_list)
File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/lingvo/core/distributed_shampoo.py", line 336, in _create_slots
  self._partitioner_metadata[v] = TensorPartitioner.partition_metadata(
File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/tensorflow/python/ops/variables.py", line 1085, in __hash__
  raise TypeError(
TypeError: Variable is unhashable. Instead, use variable.ref() as the key. (Variable: <tf.Variable 'conv2d_23/kernel:0' shape=(7, 7, 8, 32) dtype=float32, numpy=
array([[[[-6.29460454e-01, -4.73986834e-01, -3.37867022e-01, ...,
         5.98385096e-01, -1.81423143e-01,  4.59209830e-02],
       [ 1.96352646e-01, -1.11289904e-01,  3.36863637e-01, ...,
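For context, the TypeError above is TensorFlow 2's standard behavior: a tf.Variable cannot be used directly as a dict key, and the message suggests keying by variable.ref() instead. The traceback shows lingvo's _create_slots doing exactly that kind of keying (self._partitioner_metadata[v] = ...). The minimal sketch below shows only the general TensorFlow pattern, assuming TF 2.x with eager execution. The names metadata and "partition info" are hypothetical placeholders, not lingvo's actual data. It does not claim to be a fix for lingvo, and it does not explain the CUDA error reported later in this thread.

```python
import tensorflow as tf

# A variable with the same shape as the kernel in the traceback above.
v = tf.Variable(tf.zeros([7, 7, 8, 32]), name="kernel")

# In TF2 eager mode, tf.Variable.__hash__ raises TypeError, so the variable
# itself cannot be used as a dict key (this mirrors the reported error).
try:
    bad = {v: "partition info"}
    hash_failed = False
except TypeError as err:
    hash_failed = True
    print("TypeError:", err)

# Supported pattern: key by v.ref(), which is hashable. Two separate calls to
# v.ref() compare equal because they wrap the same variable object.
metadata = {v.ref(): "partition info"}  # hypothetical placeholder value

# Every lookup must also use .ref(); looking up with plain v would hash v
# and raise the same TypeError again.
print(metadata[v.ref()])

# .deref() recovers the original variable from the reference.
print(v.ref().deref() is v)
```

If code is patched this way, every read and write of a variable-keyed dict has to use .ref() consistently. Changing only the line named in the traceback may leave other accesses keyed by the raw variable.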

I'd appreciate any help fixing this issue.

c0derzer0 · Mar 30 '22

I had the same issue. I tried using v.ref() instead of v as the key on that line, and the error message then became:

2022-06-17 09:24:55.671316: W tensorflow/core/framework/op_kernel.cc:1733] INTERNAL: 'cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, 0, reinterpret_cast<CUstream>(stream), params, nullptr)' failed with 'CUDA_ERROR_ILLEGAL_ADDRESS'

DyeKuu · Jun 17 '22