Memory leak from evaluate() when using custom training loop and GPU
I am using multiple GPUs and AI Platform for training my model, with a custom training loop in order to speed up validation between epochs. During each epoch, memory slowly builds up until an OOM exception is thrown a couple of epochs into training. Code for reference:
devices = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(devices)

with strategy.scope():
    model = create_model(...)
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

min_loss = 9999999999  # Used for early stopping

for epoch in range(N_EPOCHS):
    start_training = time.time()
    model.fit(train.batch(BATCH_SIZE).cache())

    model.retrieval_task.factorized_metrics = (
        tfrs.metrics.FactorizedTopK(
            candidates=tfrs.layers.factorized_top_k.BruteForce().index_from_dataset(
                items_ds.batch(1024).map(lambda item: (item["item_no"], model.item_model(item)))
            )
        )
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

    val_metrics = model.evaluate(val.batch(BATCH_SIZE), return_dict=True)
    loss = val_metrics['total_loss']

    # Early stopping, tracking total validation loss
    if loss < min_loss:
        best_model = model
        min_loss = loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter > PATIENCE:
            model = best_model
            break
After around 4 epochs, during the .evaluate(...) call, the GPUs run out of memory:
"Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.45GiB (rounded to 3708977408)requested by op retrieval_1/brute_force_2/TopKV2"
Any idea why, and what I can do to prevent this?
I suspect the model object holds on to references to the old BruteForce layers from previous iterations.
This is in principle fixable in Keras itself, but it may be more practical to work around this by re-creating the model (and restoring from a checkpoint) every couple of iterations. Would this work for you?
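For illustration, here is a minimal sketch of that workaround. It assumes create_model(...) rebuilds the same architecture and uses a hypothetical REBUILD_EVERY constant to control how often the model object is thrown away; the trained weights are carried over through a checkpoint, so only the stale BruteForce layers are dropped.

ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=1)

for epoch in range(N_EPOCHS):
    model.fit(train.batch(BATCH_SIZE).cache())
    manager.save()  # persist the latest weights before (possibly) rebuilding

    # ... rebuild the FactorizedTopK metric and run evaluate() as above ...

    if (epoch + 1) % REBUILD_EVERY == 0:
        # Replace the model with a fresh copy so the old BruteForce metric
        # layers become unreachable and can be garbage-collected.
        with strategy.scope():
            model = create_model(...)
            model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
        ckpt = tf.train.Checkpoint(model=model)
        ckpt.restore(manager.latest_checkpoint)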
I tried to use TensorFlow's Checkpoint and CheckpointManager as described in the documentation, but unfortunately encountered a similar out-of-memory error:
(4) Resource exhausted: OOM when allocating tensor with shape[3708977407] and type int8 on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[node retrieval_1/brute_force/TopKV2 (defined at root/.local/lib/python3.7/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:571) ]]
It should be noted, however, that I am not very familiar with this kind of Checkpoint usage and may have made mistakes in the implementation. The reference code is below; please let me know if you see any room for improvement and what I could try to avoid this OOM error.
N_EPOCHS = 15
PATIENCE = 2

with strategy.scope():
    model = create_model(...)
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=PATIENCE + 1)

min_loss = 9999999999

for epoch in range(N_EPOCHS):
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    model.fit(train.batch(GLOBAL_BATCH_SIZE).cache())  # Split global batch size across GPUs

    model.retrieval_task.factorized_metrics = (
        tfrs.metrics.FactorizedTopK(
            candidates=tfrs.layers.factorized_top_k.BruteForce().index_from_dataset(
                items_ds.batch(1024).map(lambda item: (item["item_no"], model.item_model(item)))
            )
        )
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

    val_metrics = model.evaluate(val.batch(GLOBAL_BATCH_SIZE), return_dict=True)
    loss = val_metrics['total_loss']
    print("Top 100 accuracy: ", val_metrics['factorized_top_k/top_100_categorical_accuracy'])
    print("Total loss: ", loss)

    # Early stopping
    if loss < min_loss:
        min_loss = loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter > PATIENCE:
            ckpt.restore(manager.checkpoints[0])
            print("EARLY STOPPING ACTIVATED")
            print("Restored from {}".format(manager.checkpoints[0]))
            break

    save_path = manager.save()
    print("Saved checkpoint for epoch {}: {}".format(epoch, save_path))