[BUG] Error when re-compiling and fitting a model
Bug description
Create a TwoTowerModel or a WideAndDeepModel, first compile the model with the "adam" optimizer and fit it, then call model.compile again with "sgd" as the optimizer and fit again. This raises the error: "AssertionError: Called a function referencing variables which have been deleted."
Steps/Code to reproduce bug
import os
import tensorflow as tf
import merlin
from merlin.datasets.synthetic import generate_data
import merlin.models.tf as ml
from merlin.schema import Schema, Tags
DATA_FOLDER = os.environ.get("DATA_FOLDER", "/workspace/data/")
NUM_ROWS = 1000000
train, valid = generate_data("e-commerce-large", int(NUM_ROWS), set_sizes=(0.7, 0.3))
schema = train.schema
model = ml.TwoTowerModel(
    schema,
    query_tower=ml.MLPBlock([512, 256]),
    item_tower=ml.MLPBlock([512, 256]),
)
#first time compile
model.compile(optimizer="adam")
model.fit(train, batch_size=1024, epochs=1)
#second time compile
model.compile(optimizer="sgd", metrics=[tf.keras.metrics.AUC()])
model.fit(train, batch_size=1024, epochs=1)
Error:
AssertionError Traceback (most recent call last)
Input In [7], in <cell line: 2>()
1 model.compile(optimizer="sgd", metrics=[tf.keras.metrics.AUC()])
----> 2 model.fit(train, batch_size=1024, epochs=1)
File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:713, in BaseModel.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, train_metrics_steps, **kwargs)
705 callbacks = self._add_metrics_callback(callbacks, train_metrics_steps)
707 fit_kwargs = {
708 k: v
709 for k, v in locals().items()
710 if k not in ["self", "kwargs", "train_metrics_steps", "__class__"]
711 }
--> 713 return super().fit(**fit_kwargs)
File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
File /tmp/__autograph_generated_fileh2iicnzm.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__train_function(iterator)
13 try:
14 do_return = True
---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
16 except:
17 do_return = False
File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:580, in BaseModel.train_step(self, data)
577 # Run backwards pass.
578 self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
--> 580 metrics = self.compute_metrics(outputs, training=True)
581 # Adding regularization loss to metrics
582 metrics["regularization_loss"] = tf.reduce_sum(cast_losses_to_common_dtype(self.losses))
AssertionError: in user code:
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1051, in train_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1040, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1030, in run_step **
outputs = model.train_step(data)
File "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py", line 580, in train_step
metrics = self.compute_metrics(outputs, training=True)
AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.
The same error is raised with the Wide and Deep model:
# ecommerce_data, wide_schema and deep_schema are assumed to be prepared beforehand
model = ml.WideAndDeepModel(
    ecommerce_data.schema,
    wide_schema=wide_schema,
    deep_schema=deep_schema,
    wide_preprocess=ml.CategoricalOneHot(wide_schema),
    deep_block=ml.MLPBlock([32, 16]),
    prediction_tasks=ml.BinaryClassificationTask("click"),
)
#first time compile
model.compile(optimizer="adam")
model.fit(ecommerce_data, batch_size=1024, epochs=1)
#second time compile
model.compile(optimizer="sgd")
model.fit(ecommerce_data, batch_size=1024, epochs=1)
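A temporary workaround (my own assumption, not a documented fix) is to rebuild the model before the second compile, using the first TwoTowerModel snippet for illustration, so the second fit does not reuse the train function cached during the first fit. Note this discards the weights learned in the first fit, so it is only a stopgap:
# Workaround sketch (assumption): re-create the model before the second compile.
# This avoids the stale cached train function, but discards the weights
# learned with the "adam" optimizer during the first fit.
model = ml.TwoTowerModel(
    schema,
    query_tower=ml.MLPBlock([512, 256]),
    item_tower=ml.MLPBlock([512, 256]),
)
model.compile(optimizer="sgd", metrics=[tf.keras.metrics.AUC()])
model.fit(train, batch_size=1024, epochs=1)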
Expected behavior
Like a Keras model, it is supposed to support compiling and fitting multiple times without any error:
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
#Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
#Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
#Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
#Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
#convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)
batch_size = 128
epochs = 1
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
model.compile(loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
Environment details
- Merlin version: using merlin-tensorflow:22.07 with the latest main branch pulled.
- Platform:
- Python version:
- PyTorch version (GPU?):
- Tensorflow version (GPU?):
@timmy00 , please triage this bug and assign priority and severity labels. Detailed instructions are available here
Fixed, thank you!
@marcromeyn this problem is still valid. Oliver has a WIP PR https://github.com/NVIDIA-Merlin/models/pull/787 to fix it.