
Memory leak from evaluate() when using custom training loop and GPU

Open msvensson222 opened this issue 3 years ago • 2 comments

I am using multiple GPUs and AI Platform for training my model, with a custom training loop in order to speed up validation between epochs. During each epoch the memory slowly builds up, until an OOM exception is thrown a couple of epochs into training. Code for reference:

import time

import tensorflow as tf
import tensorflow_recommenders as tfrs

devices = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(devices)
with strategy.scope():
    model = create_model(...)
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

    min_loss = 9999999999  # Used for early stopping
    patience_counter = 0
    for epoch in range(N_EPOCHS):
        start_training = time.time()
        model.fit(train.batch(BATCH_SIZE).cache())

        model.retrieval_task.factorized_metrics = (
            tfrs.metrics.FactorizedTopK(
                candidates=tfrs.layers.factorized_top_k.BruteForce().index_from_dataset(
                    items_ds.batch(1024).map(lambda item: (item["item_no"], model.item_model(item)))
                )
            )
        )
        model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

        val_metrics = model.evaluate(val.batch(BATCH_SIZE), return_dict=True)
        loss = val_metrics['total_loss']

        # Early stopping, tracking total validation loss
        if loss < min_loss:
            best_model = model
            min_loss = loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > PATIENCE:
                model = best_model
                break

After around 4 epochs, during the .evaluate(...) call, the GPUs run out of memory: "Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.45GiB (rounded to 3708977408)requested by op retrieval_1/brute_force_2/TopKV2"

Any idea why, and what I can do to prevent this?

msvensson222 · Oct 14 '21 05:10

I suspect the model object holds on to references to the old BruteForce layers from previous iterations.

This is in principle fixable in Keras itself, but it may be more practical to work around this by re-creating the model (and restoring from a checkpoint) every couple of iterations. Would this work for you?

maciejkula · Oct 18 '21 23:10
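
A minimal sketch of that workaround, reusing strategy, create_model, train, BATCH_SIZE, LEARNING_RATE and N_EPOCHS from the snippet above (RECREATE_EVERY is an illustrative constant, not part of the original code), might look like this:

import tensorflow as tf

RECREATE_EVERY = 2  # rebuild the model every couple of epochs

with strategy.scope():
    model = create_model(...)
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=1)

    for epoch in range(N_EPOCHS):
        model.fit(train.batch(BATCH_SIZE).cache())
        # ... rebuild factorized_metrics, recompile and evaluate as in the snippets above ...
        manager.save()

        if (epoch + 1) % RECREATE_EVERY == 0:
            # Build a fresh model so the previous BruteForce layers can be released,
            # then restore the trained weights from the latest checkpoint.
            model = create_model(...)
            model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
            ckpt = tf.train.Checkpoint(model=model)
            ckpt.restore(manager.latest_checkpoint)
            manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=1)

The idea is that, once the model object itself is replaced, the BruteForce layers built for earlier evaluations are no longer reachable from it and can be garbage collected, while the trained weights are carried over through the checkpoint.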

I tried to use TensorFlow's Checkpoint and CheckpointManager as described in the documentation, but unfortunately encountered a similar out-of-memory error:

  (4) Resource exhausted:  OOM when allocating tensor with shape[3708977407] and type int8 on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node retrieval_1/brute_force/TopKV2 (defined at root/.local/lib/python3.7/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:571) ]]

It should be noted, however, that I am very unfamiliar with this Checkpoint usage and therefore might have made mistakes in the implementation. Below you will find the reference code; please let me know if you see any room for improvement, and what I could test to avoid this OOM error.

N_EPOCHS = 15
PATIENCE = 2

with strategy.scope():
    model = create_model(...)
    model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=PATIENCE + 1)

    min_loss = 9999999999
    patience_counter = 0
    for epoch in range(N_EPOCHS):

        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        model.fit(train.batch(GLOBAL_BATCH_SIZE).cache())  # Global batch size is split across the GPUs
        model.retrieval_task.factorized_metrics = (
            tfrs.metrics.FactorizedTopK(
                candidates=tfrs.layers.factorized_top_k.BruteForce().index_from_dataset(
                    items_ds.batch(1024).map(lambda item: (item["item_no"], model.item_model(item)))
                )
            )
        )
        model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))

        val_accuracy = model.evaluate(val.batch(GLOBAL_BATCH_SIZE), return_dict=True)
        loss = val_accuracy['total_loss']
        print("Top 100 accuracy: ", val_accuracy['factorized_top_k/top_100_categorical_accuracy'])  # total_loss
        print("Total loss: ", loss)

        # Early stopping
        if loss < min_loss:
            min_loss = loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > PATIENCE:
                ckpt.restore(manager.checkpoints[0])
                print("EARLY STOPPING ACTIVATED")
                print("Restored from {}".format(manager.checkpoints[0]))
                break

        save_path = manager.save()
        print("Saved checkpoint for epoch {}: {}".format(epoch, save_path))

msvensson222 · Oct 20 '21 08:10
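
For reference, one way to check whether stale BruteForce layers really are being retained across iterations is to watch the model's tracked submodule count between epochs. A small diagnostic sketch, assuming the model and loop from the snippets above:

for epoch in range(N_EPOCHS):
    model.fit(train.batch(GLOBAL_BATCH_SIZE).cache())
    # ... rebuild factorized_metrics, recompile and evaluate as above ...

    # If the model keeps references to the BruteForce layers created in earlier
    # epochs, this number grows every iteration.
    print("Epoch {}: {} tracked submodules".format(epoch, len(list(model.submodules))))

A count that keeps climbing even though factorized_metrics is reassigned each epoch would support the explanation above.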