tensorflow-wavenet
TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64
When running "train.py", I encountered the following error:
Traceback (most recent call last):
File "C:/Users/96492/Anaconda3/Lib/tensorflow-wavenet-master/train.py", line 337, in
encoded = tf.one_hot( input_batch, depth=self.quantization_channels, dtype=tf.float32)
I tried to trace the problem from line 515 of "model.py". I think input_batch has the wrong type there, so I replaced it with tf.cast(input_batch, tf.int32).
With this change the code runs, but is this modification actually correct?
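Whether the cast is correct depends on what input_batch holds at that point. tf.one_hot needs integer class indices in [0, quantization_channels), so casting a float tensor to tf.int32 is harmless only if it already contains whole-number indices in that range. tf.cast drops the fractional part (the TensorFlow docs show [1.8, 2.2] becoming [1, 2]), so if the tensor held anything else, for example unquantized amplitudes, the cast would hide the real bug rather than fix it. The sketch below illustrates that check in plain Python so it runs without TensorFlow; safe_cast_indices and one_hot are hypothetical helpers, not part of tensorflow-wavenet or the TensorFlow API, and one_hot only mimics tf.one_hot for a 1-D batch.

```python
from typing import List, Sequence


def safe_cast_indices(values: Sequence[float], depth: int) -> List[int]:
    """Hypothetical helper: convert float-typed indices to ints, but refuse
    values that a plain int cast (like tf.cast) would silently corrupt."""
    indices = []
    for v in values:
        i = int(v)  # truncates, like tf.cast(float -> int32)
        if v != i:
            raise ValueError(f"{v} is not a whole number; casting would truncate it to {i}")
        if not 0 <= i < depth:
            raise ValueError(f"{i} is outside the valid class range [0, {depth})")
        indices.append(i)
    return indices


def one_hot(indices: Sequence[int], depth: int) -> List[List[float]]:
    """Plain-Python analogue of tf.one_hot(indices, depth, dtype=tf.float32)
    for a 1-D batch of integer indices."""
    return [[1.0 if c == i else 0.0 for c in range(depth)] for i in indices]


if __name__ == "__main__":
    # Float-typed, but every value is already a whole-number class index,
    # so casting is safe here. depth=4 is an arbitrary illustrative value.
    batch = [0.0, 3.0, 1.0]
    print(one_hot(safe_cast_indices(batch, depth=4), depth=4))
    # [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0]]

    # A fractional value would be truncated by a blind cast, so it is rejected.
    try:
        safe_cast_indices([1.8], depth=4)
    except ValueError as err:
        print("rejected:", err)
```

In the model itself, the equivalent of this check would be to confirm that the values reaching tf.one_hot are whole numbers in [0, quantization_channels) before relying on tf.cast(input_batch, tf.int32).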
@Ahapy Did you solve the problem? I got exactly the same problem.