tf-keras
New optimizers are incompatible with `jit_compile` and `MirroredStrategy`
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.11.0-rc1, 2.12.0.dev20221026
- Python version: 3.9.9
- GPU model and memory: 4 x NVIDIA T4 on GCP VM
Describe the problem.
The new optimizers in Keras 2.11 seem to be incompatible with multi-GPU training using `MirroredStrategy` in combination with JIT (XLA) compilation.
Describe the current behavior.
The following code example, taken from the docs, fails in a multi-GPU environment with the runtime error shown below:
```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(100).batch(10)

with tf.distribute.MirroredStrategy().scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss="mse", jit_compile=True, optimizer=tf.keras.optimizers.SGD())

model.fit(dataset, epochs=2)
```
```
RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the
function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains
a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn`
uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet
supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a
synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside
a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid
decorating overridden methods `test_step` and `train_step` in `tf.function`.
```
The same example works correctly when running on a single GPU or when switching back to the legacy optimizers.
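Compiling the model without JIT compilation also fails:
```python
model.compile(loss="mse", optimizer=tf.keras.optimizers.SGD())
```
This time the error comes from XLA not being able to access resources located on a different device. XLA is still involved because the new optimizer compiles its own update step by default (its `jit_compile` argument defaults to `True`; note `_update_step_xla` in the traceback):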
```
  File "/home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1216, in _distributed_apply_gradients_fn
    distribution.extended.update(
  File "/home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1211, in apply_grad_to_update_var
    return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'update_1_1/StatefulPartitionedCall'
4 root error(s) found.
(0) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 (defined @ /home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py:378) located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:3
Cf. https://www.tensorflow.org/xla/known_issues#tfvariable_on_a_different_device
[[{{node update_3_1/StatefulPartitionedCall}}]]
(1) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 (defined @ /home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py:378) located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:2
Cf. https://www.tensorflow.org/xla/known_issues#tfvariable_on_a_different_device
[[{{node update_2_1/StatefulPartitionedCall}}]]
[[GroupCrossDeviceControlEdges_1/Identity_5/_219]]
(2) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 (defined @ /home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py:378) located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:2
Cf. https://www.tensorflow.org/xla/known_issues#tfvariable_on_a_different_device
[[{{node update_2_1/StatefulPartitionedCall}}]]
(3) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 (defined @ /home/lgeiger/.local/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py:378) located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:1
Cf. https://www.tensorflow.org/xla/known_issues#tfvariable_on_a_different_device
[[{{node update_1_1/StatefulPartitionedCall}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_2295]
```
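To make the pattern in these errors easier to see, here is a small diagnostic sketch in plain Python. `summarize_cross_device_errors` is a hypothetical helper, not part of TensorFlow or Keras, and its regular expression assumes the exact `Trying to access resource ... located in device ... from device ...` wording printed above by TF 2.11; other versions may word the message differently.

```python
import re
from collections import Counter

# Matches the XLA resource-placement errors shown above. Assumes the
# "Trying to access resource <name> ... located in device <owner>
# from device <accessor>" wording; adjust if your TF version differs.
_ACCESS_RE = re.compile(
    r"Trying to access resource (?P<resource>\S+)"
    r".*?located in device (?P<owner>\S+)"
    r" from device (?P<accessor>\S+)"
)


def summarize_cross_device_errors(log: str) -> list:
    """Return (resource, owner_device, accessing_device) for each error entry."""
    return [
        (m["resource"], m["owner"], m["accessor"])
        for m in _ACCESS_RE.finditer(log)
    ]


# Excerpt of the error output above; the "(defined @ ...)" paths are dropped
# for brevity, the rest of each message is unchanged.
LOG = """\
(0) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:3
(1) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:2
(2) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:2
(3) INVALID_ARGUMENT: Trying to access resource Resource-25-at-0x557111e6b130 located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:GPU:1
"""

if __name__ == "__main__":
    entries = summarize_cross_device_errors(LOG)
    # Count which device owns the resource and which devices try to read it.
    owners = Counter(owner for _, owner, _ in entries)
    accessors = Counter(accessor for _, _, accessor in entries)
    print("owner devices:", dict(owners))
    print("accessing devices:", dict(accessors))
```

Applied to the four entries above, it reports a single resource, `Resource-25-at-0x557111e6b130`, located on GPU:0 and accessed from GPU:1, GPU:2 (twice) and GPU:3. In other words, the update steps on every GPU other than the one holding that resource fail, which is the situation described by the XLA known-issue page linked in the message.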
With the new optimizers, only disabling JIT compilation completely, on both the model and the optimizer, prevents this error:
```python
model.compile(loss="mse", jit_compile=False, optimizer=tf.keras.optimizers.SGD(jit_compile=False))
```
However, this can often cause a large performance regression.
Describe the expected behavior.
Since the optimizers are no longer experimental and the code example comes directly from the official multi-GPU guide, I would expect the new optimizers to support multi-GPU training. This seems like an essential feature for anyone not training on TPUs, and the current behavior will likely break many users who try to upgrade from TF 2.10.0.
@fchollet This looks to me like an issue with the new Keras optimizers, but let me know if it is caused by TensorFlow itself. Either way, it would be great to get this fixed before the stable release.
- Do you want to contribute a PR? (yes/no): no (I would like to, if I knew of an easy fix)
- If yes, please read this page for instructions
- Briefly describe your candidate solution (if contributing):