Text generation with an RNN - error in the model class
```python
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)
    if return_state:
      return x, states
    else:
      return x

model = MyModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
```
TypeError: Layer.__init__() takes 1 positional argument but 2 were given
The error lies in how the superclass's __init__ is called in your MyModel class: super() already returns a proxy object bound to self, so super().__init__ must not be passed self as an argument. In Python 3, simply call super().__init__() with no arguments.
Change this line:
super().__init__(self)
to this:
super().__init__()
So, your corrected class definition should look like this:
```python
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)
    if return_state:
      return x, states
    else:
      return x
```
Even if this is done, an error is produced further down in the code. Right after creating an object of this class, there's this code in the tutorial:
```python
for input_example_batch, target_example_batch in dataset.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
```
which produces the following error:
```
<ipython-input-31-d5691f3250ba> in <cell line: 1>()
      1 for input_example_batch, target_example_batch in dataset.take(1):
----> 2     example_batch_predictions = model(input_example_batch)
      3     print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

1 frames
<ipython-input-29-84c06ee9303e> in call(self, inputs, states, return_state, training)
     12   x = self.embedding(x, training=training)
     13   if states is None:
---> 14     states = self.gru.get_initial_state(x)
     15   x, states = self.gru(x, initial_state=states, training=training)
     16   x = self.dense(x, training=training)

InvalidArgumentError: Exception encountered when calling MyModel.call().

{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name:

Arguments received by MyModel.call():
  • inputs=tf.Tensor(shape=(64, 100), dtype=int64)
  • states=None
  • return_state=False
  • training=False
```
This was produced in Google Colab with TensorFlow 2.17. I don't really know how to solve this. It would be nice if the TensorFlow team updated this tutorial with working code.
After some research, I have found four different and independent errors in the code present in the tutorial. I hope these notes help, and for the last two in particular, I hope other people can find a proper solution and integrate a fix into Tensorflow.
Error found by @808vita
Symptom
TypeError: Layer.__init__() takes 1 positional argument but 2 were given
Cause
A mistake in the Python code related to Python itself, not Keras or Tensorflow. Calling super() with no arguments inside a class returns a proxy object bound to the same instance as self, but resolving methods on the parent class. Because the proxy is already bound, self must not be passed again when calling its methods.
Incidentally, this mistake may have been present right from the start, with the Python version with which the tutorial was written.
Solution
Replace
super().__init__(self)
With
super().__init__()
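A minimal pure-Python reproduction (the class names here are illustrative, not from the tutorial) shows why the extra self triggers exactly this TypeError:

```python
class Parent:
    def __init__(self):
        self.ready = True

class Broken(Parent):
    def __init__(self):
        # Bug: super() already returns a proxy bound to self,
        # so this passes self twice to Parent.__init__.
        super().__init__(self)

class Fixed(Parent):
    def __init__(self):
        super().__init__()

try:
    Broken()
except TypeError as e:
    print(e)  # ... takes 1 positional argument but 2 were given

assert Fixed().ready
```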
Error found by @alexdrymonitis
Symptom
Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = []
This error needs an update to the Text generation with an RNN tutorial.
Cause
Breaking change in the Keras API introduced in Tensorflow version 2.16 (according to the documentation). Starting from Tensorflow version 2.16, the method tf.keras.layers.GRUCell.get_initial_state no longer takes a full batch input data tensor as an argument; instead, it takes a scalar integer tensor corresponding to the batch size.
The tutorial hasn't been updated to reflect the API change.
Solution
Replace
states = self.gru.get_initial_state(x)
With
states = self.gru.get_initial_state(tf.shape(x)[0])
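The shape-mismatch message makes sense once you see what each version of the method expects. A pure-Python sketch of the signature change (the function names and the zero-state layout are illustrative, not the real Keras internals):

```python
RNN_UNITS = 4

def get_initial_state_pre_216(inputs):
    # Tensorflow <= 2.15 style: the method received the full input batch
    # and derived the batch size from its first dimension.
    batch_size = len(inputs)  # inputs stands in for a [batch, time, embed] tensor
    return [[0.0] * RNN_UNITS for _ in range(batch_size)]

def get_initial_state_216(batch_size):
    # Tensorflow >= 2.16 style: the caller passes the batch size directly,
    # which is why the fix is get_initial_state(tf.shape(x)[0]).
    return [[0.0] * RNN_UNITS for _ in range(batch_size)]

batch = [[[0.0] * 256] * 100] * 64  # stands in for a (64, 100, 256) tensor

# Both conventions produce the same zero state; only the argument differs.
assert get_initial_state_pre_216(batch) == get_initial_state_216(len(batch))
```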
Error found by me
Symptom
This happens when running on a GPU, but not on a CPU. Observed with tensorflow==2.17.0 and keras==3.5.0 on Python 3.11, Cuda 12.6.1 and Cudnn 8.9.7.
```
Traceback (most recent call last):
  File "/mnt/d/dev/tftest/./testmodel.py", line 94, in <module>
    example_batch_predictions = model(input_example_batch)
  File "/mnt/d/dev/tftest/lvenv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/mnt/d/dev/tftest/./testmodel.py", line 78, in call
    x, states = self.gru(x, initial_state=states, training=training)
ValueError: Exception encountered when calling MyModel.call().

too many values to unpack (expected 2)
```
Cause
For unknown reasons, tf.keras.layers.GRU, when called on a GPU, returns one single flat list containing the output tensor as its first element and all state tensors as the subsequent elements, instead of a tuple containing the output tensor first and a list of state tensors next. This behavior wasn't present in version 2.12 of Tensorflow, but it was present with the LSTM RNN layer.
I assume this behavior is a bug in the RNN implementation on GPU (since it is also present with LSTM).
Workaround
Replace
x, states = self.gru(x, initial_state=states, training=training)
With
r = self.gru(x, initial_state=states, training=training)
x, states = r[0], r[1:]
This error needs a fix in the Tensorflow Cuda code.
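The unpacking failure and the workaround can be illustrated without Tensorflow; the plain lists below stand in for the tensors the GRU layer returns on GPU:

```python
# On GPU the layer comes back as one flat list: [output, state_1, ..., state_n].
gpu_result = ["output_tensor", "state_1", "state_2"]

try:
    x, states = gpu_result  # strict two-way unpacking
except ValueError as e:
    print(e)  # too many values to unpack (expected 2)

# Slicing tolerates any number of trailing state tensors,
# and still yields a list of states even when there is only one.
x, states = gpu_result[0], gpu_result[1:]
assert x == "output_tensor"
assert states == ["state_1", "state_2"]
```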
Error found by me
Symptom
This happens when running on a GPU, but not on a CPU. Observed with tensorflow==2.17.0 and keras==3.5.0 on Python 3.11, Cuda 12.6.1 and Cudnn 8.9.7.
This error only happens when training, not when running the model over a tensor extracted from the dataset.
```
Traceback (most recent call last):
  File "/mnt/d/dev/tftest/./testmodel.py", line 119, in <module>
    history = model.fit(dataset, epochs=EPOCHS, callbacks=[])
  File "/mnt/d/dev/tftest/lvenv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/mnt/d/dev/tftest/./testmodel.py", line 78, in call
    r = self.gru(x, initial_state=states, training=training)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling GRU.call().

Iterating over a symbolic `tf.Tensor` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.
```
Cause
Unknown. I am still investigating, so far without result. Like the error above, I assume it is a bug in the RNN or GRU Cuda code; it is most likely specific to GRU rather than RNNs in general, since this issue does not happen with the SimpleRNN or LSTM layers.
Solution or workaround
None found so far. I really hope someone finds one because it's blocking me from working on a related project.