Distributed training not working (batch size calculation)
Describe the bug
This is an issue I am having with keras-nlp, but I am not sure whether it can be solved here or should be reported under keras or tensorflow.
Currently, the batch size is not calculated correctly when performing multi-worker distributed training with the JAX backend:
Traceback (most recent call last):
File "mycode.py", line 293, in <module>
history = classifier.fit(
File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/utils/pipeline_model.py", line 194, in fit
return super().fit(
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py", line 467, in distribute_dataset
raise ValueError(
ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`
To Reproduce
Run (multi-worker?) distributed training with the JAX backend.
The issue seems to stem from https://github.com/keras-team/keras-nlp/blob/778ccd72fe5d74e8eedc7d38dfb57561821b7851/keras_nlp/src/utils/pipeline_model.py#L181, where mapping a preprocessor over the dataset leads to a failure at https://github.com/keras-team/keras/blob/3105247028bb0a7e6d2f05f5daa44c9cfafd3e67/keras/src/distribution/distribution_lib.py#L465
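For context, the failing check appears to be roughly the following (a sketch reconstructed from the traceback above, not the verbatim keras source; the real code may differ in detail):

from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

# `dataset` is the tf.data.Dataset after pipeline_model.py has mapped the
# preprocessor over it. compute_batch_size() returns -1 when the batch size
# cannot be determined, which triggers the ValueError seen above.
batch_size = tf_data_distribute.compute_batch_size(dataset)
if batch_size < 0:
    raise ValueError(
        "The batch size of the input dataset is unknown. Please config the "
        "batch size for the input dataset, e.g via `dataset.batch(batch_size)`"
    )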
Here is a minimal example in which tensorflow.python.data.experimental.ops.distribute.compute_batch_size() returns -1 after mapping:
import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute
from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight

ds = tf.data.Dataset.range(8)
ds = ds.batch(3)
# Measure the size of the first batch directly from the data.
print(f"True batch size (before): {len(next(iter(ds.as_numpy_iterator())))}")
print(f"Calculated batch size (before): {tf_data_distribute.compute_batch_size(ds)}")

# The same map/prefetch that pipeline_model.py applies to the dataset.
ds = ds.map(pack_x_y_sample_weight, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
print(f"True batch size (after): {len(next(iter(ds.as_numpy_iterator())))}")
print(f"Calculated batch size (after): {tf_data_distribute.compute_batch_size(ds)}")  # -1: unknown
Expected behavior
A batched tf.data.Dataset is recognized as being batched.