Default compilation won't work for tf.distribute
Describe the bug
Instantiating the model under tf.distribute.MirroredStrategy and training leads to the following error:
InvalidArgumentError
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_23/2054434371.py in
/kaggle/working/keras-nlp/keras_nlp/utils/pipeline_model.py in fit(self, x, y, batch_size, sample_weight, validation_data, validation_split, **kwargs)
    195                 sample_weight=None,
    196                 validation_data=validation_data,
--> 197                 **kwargs,
    198             )
    199
/opt/conda/lib/python3.7/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
68 # To get the full stack trace, call:
69 # tf.debugging.disable_traceback_filtering()
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51     ctx.ensure_initialized()
     52     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 53                                         inputs, attrs, num_outputs)
     54   except core._NotOkStatusException as e:
     55     if name is not None:
InvalidArgumentError: Graph execution error:
2 root error(s) found.
  (0) INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_run_step_111932[] on XLA_GPU_JIT: CollectiveGatherV2 (No registered 'CollectiveGatherV2' OpKernel for XLA_GPU_JIT devices compatible with node {{node CollectiveGatherV2}}){{node CollectiveGatherV2}}
    The op is created at:
    File "threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "threading.py", line 926, in _bootstrap_inner
      self.run()
    File "site-packages/keras/engine/training.py", line 1222, in run_step
      outputs = model.train_step(data)
    File "site-packages/keras/engine/training.py", line 1027, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 527, in minimize
      self.apply_gradients(grads_and_vars)
    File "site-packages/keras/mixed_precision/loss_scale_optimizer.py", line 1301, in apply_gradients
      grads_and_vars = self._optimizer.aggregate_gradients(grads_and_vars)
    File "site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1105, in aggregate_gradients
      return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
    File "site-packages/keras/optimizers/optimizer_v2/utils.py", line 38, in all_reduce_sum_gradients
      tf.distribute.ReduceOp.SUM, grads
     [[replica_1/StatefulPartitionedCall]]
  (1) INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_run_step_106651[] on XLA_GPU_JIT: CollectiveGatherV2 (No registered 'CollectiveGatherV2' OpKernel for XLA_GPU_JIT devices compatible with node {{node CollectiveGatherV2}}){{node CollectiveGatherV2}}
    The op is created at:
    File "threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "threading.py", line 926, in _bootstrap_inner
      self.run()
    File "site-packages/keras/engine/training.py", line 1222, in run_step
      outputs = model.train_step(data)
    File "site-packages/keras/engine/training.py", line 1027, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 527, in minimize
      self.apply_gradients(grads_and_vars)
    File "site-packages/keras/mixed_precision/loss_scale_optimizer.py", line 1301, in apply_gradients
      grads_and_vars = self._optimizer.aggregate_gradients(grads_and_vars)
    File "site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1105, in aggregate_gradients
      return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
    File "site-packages/keras/optimizers/optimizer_v2/utils.py", line 38, in all_reduce_sum_gradients
      tf.distribute.ReduceOp.SUM, grads
     [[StatefulPartitionedCall]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_112237]
To Reproduce
Workaround
Recompiling the model without jit_compile makes everything work fine:
import keras_nlp
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")
    model_dist.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
        weighted_metrics=keras.metrics.SparseCategoricalAccuracy(),
    )
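Equivalently (continuing the snippet above), the flag can be disabled explicitly; a minimal sketch, assuming compile() forwards jit_compile to keras.Model.compile as it does in stock Keras:

with strategy.scope():
    # Explicit form of the workaround: turn XLA off rather than relying on
    # the recompile dropping the preset's jit_compile=True default.
    model_dist.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
        weighted_metrics=keras.metrics.SparseCategoricalAccuracy(),
        jit_compile=False,
    )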
Additional context #858
Would you like to help us fix it? Yes
Ideally our models should not throw this error when using a distributed strategy; it makes for a bad UX for the majority of our users.
Potential Fix
One option would be to add an extra argument in our base classes that lets the model know whether a distributed strategy is in use; in that case we should pass jit_compile=False.
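As a sketch of that idea (a hypothetical helper, not an existing keras-nlp API), the default could be derived from the active strategy:

import tensorflow as tf

def default_jit_compile():
    # Hypothetical helper: keep XLA on for single-replica setups, but fall
    # back to jit_compile=False under multi-replica strategies such as
    # MirroredStrategy, whose gradient aggregation emits CollectiveGatherV2
    # (no XLA_GPU_JIT kernel is registered for it).
    strategy = tf.distribute.get_strategy()
    return strategy.num_replicas_in_sync <= 1

The base task's default compile() call could then use jit_compile=default_jit_compile(), while still letting users override it explicitly.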
@shivance Could you check your TF version on the notebook? The error looks like a stale one which the XLA team has fixed.
Hey @chenmoneygithub, on Kaggle the TF version is 2.11

I am actually very confused about the error, which appears to come from an optimizer line. But we have never heard of XLA being broken in the optimizer so far. Could you try verifying whether you can reproduce this error with BertClassifier?
import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss='mse', optimizer='sgd', jit_compile=True)

dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
This works well, weird.
BertClassifier also gives the same error! Colab
Thanks for checking! It's a problem with MirroredStrategy + XLA + sparse gradients; I am talking to the TF teams about a solution.
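For anyone trying to narrow this down, here is an untested sketch of what I believe the minimal trigger looks like: unlike the Dense-only example above, an Embedding layer produces sparse (IndexedSlices) gradients, and MirroredStrategy aggregates those with an all-gather, which is presumably where CollectiveGatherV2 comes from. The toy model and data are made up for illustration.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Embedding gradients are tf.IndexedSlices; aggregating them across
    # replicas uses an all-gather rather than a plain all-reduce.
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=100, output_dim=8),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss="mse", optimizer="adam", jit_compile=True)

x = tf.random.uniform((32, 4), maxval=100, dtype=tf.int32)
y = tf.random.uniform((32, 1))
# On a multi-GPU machine this should hit the same CollectiveGatherV2 error.
model.fit(x, y, epochs=1, batch_size=8)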
Hey @chenmoneygithub, did the TF team fix this up?
@mattdangerw shall I close this issue?