
User model as a `tf.keras.Model` subclass instead of `tf.keras.Sequential`

Open · naarkhoo opened this issue 2 years ago · 1 comment

To keep it short: I have a working setup in which the user and content models are `tf.keras.Sequential` models:

```python
from typing import Dict, Text

import tensorflow as tf
import tensorflow_recommenders as tfrs

# unique_user_ids, contents_df, contents_ds, embedding_dimension, learning_rate,
# the batch sizes, view_size, train, test and epochs are defined elsewhere in
# the notebook (not shown).

# Each tower maps raw IDs to integer indices, then to embeddings.
user_model = tf.keras.Sequential([
  tf.keras.layers.IntegerLookup(vocabulary=unique_user_ids, mask_token=None),
  tf.keras.layers.Embedding(input_dim=len(unique_user_ids) + 1, output_dim=embedding_dimension)
])

content_model = tf.keras.Sequential([
  tf.keras.layers.IntegerLookup(vocabulary=contents_df, mask_token=None),
  tf.keras.layers.Embedding(input_dim=len(contents_df) + 1, output_dim=embedding_dimension)
])

# Candidate embeddings used by the top-K retrieval metric.
candidates = contents_ds.batch(metrics_batchsize).map(content_model)

metrics = tfrs.metrics.FactorizedTopK(
  candidates=candidates
)

task = tfrs.tasks.Retrieval(
  metrics=metrics
)

class ContentModel(tfrs.Model):

  def __init__(self, user_model, content_model):
    super().__init__()
    self.content_model: tf.keras.Model = content_model
    self.user_model: tf.keras.Model = user_model
    self.task: tf.keras.layers.Layer = task  # the module-level Retrieval task above

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    content_embeddings = self.content_model(features["content_id"])
    user_embeddings = self.user_model(features["user_id"])

    return self.task(user_embeddings, content_embeddings)

model = ContentModel(user_model, content_model)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=learning_rate))

cached_train = train.shuffle(view_size).batch(train_batchsize).cache()
cached_test = test.batch(test_batchsize).cache()

model.fit(cached_train, epochs=epochs)
```

This works perfectly, but if I instead define the user model as a subclass of `tf.keras.Model`:


```python
class user_model(tf.keras.Model):

  def __init__(self): # use_timestamps
    super().__init__()

    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.IntegerLookup(vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(input_dim=len(unique_user_ids) + 1, output_dim=embedding_dimension),
    ])

  def call(self, inputs):
    return self.user_embedding(inputs["user_id"])
```

training fails with an error about the input:


```
<ipython-input-14-4f8ec54bf9e2> in <module>
      2 cached_test = test.batch(test_batchsize).cache()
      3 
----> 4 model.fit(cached_train, epochs=epochs)

<ipython-input-11-b3c45f667882> in compute_loss(self, features, training)
     29   def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
     30     content_embeddings = self.content_model(features["content_id"])
---> 31     user_embeddings = self.user_model(features["user_id"])
     32 
     33     return self.task(user_embeddings, content_embeddings)

TypeError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1051, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1040, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1030, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_recommenders/models/base.py", line 68, in train_step
        loss = self.compute_loss(inputs, training=True)
    File "<ipython-input-11-b3c45f667882>", line 31, in compute_loss
        user_embeddings = self.user_model(features["user_id"])

    TypeError: __init__() takes 1 positional argument but 2 were given
```

I would like to subclass `tf.keras.Model` instead of using `tf.keras.Sequential` for my `user_model` and `content_model`. Thanks!

naarkhoo · Nov 01 '22 21:11

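`self.user_model` is your `user_model` class, not an instance of that class. In `compute_loss()`, you think you're calling `user_model.call()`, but you are actually calling the constructor, `user_model.__init__()`. Your `__init__` accepts only `self`, so passing `features["user_id"]` gives it one argument too many, which is exactly the `__init__() takes 1 positional argument but 2 were given` error in the traceback.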
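The mix-up can be reproduced without TensorFlow. The sketch below is a hypothetical plain-Python stand-in: `__call__` plays the role that Keras routes to `call()`, and `call_model` is an illustrative helper mimicking the line in `compute_loss()`, not part of any library. Calling the class runs the constructor; calling an instance runs `__call__`. (Newer Python versions may prefix the message with the class name, e.g. `user_model.__init__()`.)

```python
class user_model:
    """Stand-in for the subclassed Keras model; its __init__ takes only self."""

    def __init__(self):
        pass

    def __call__(self, inputs):
        # Keras forwards model(inputs) to call(inputs); plain Python uses __call__.
        return f"embedding for {inputs}"


def call_model(model, inputs):
    """Mimics `self.user_model(features["user_id"])` inside compute_loss()."""
    try:
        return model(inputs)
    except TypeError as err:
        return f"TypeError: {err}"


# Passing the class: model(inputs) means user_model(inputs), which runs
# __init__(new_object, inputs) and has one positional argument too many.
class_result = call_model(user_model, "user_42")
print(class_result)

# Passing an instance: model(inputs) goes to the instance's __call__.
instance_result = call_model(user_model(), "user_42")
print(instance_result)
```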

You'll be able to spot it much more easily if you rename your class to `UserModel` (PEP 8 uses CapWords for class names) and create an instance with `user_model = UserModel()`.
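A minimal sketch of the corrected user tower follows, assuming small hypothetical stand-ins for `unique_user_ids` and `embedding_dimension` (the real values live elsewhere in the notebook). Besides creating an instance, it also changes `call()` to take the ID tensor directly: the question's `call()` indexes `inputs["user_id"]`, but `compute_loss()` already passes `features["user_id"]`, so that second lookup would most likely fail once the instantiation bug is fixed.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for values defined elsewhere in the original notebook.
unique_user_ids = np.array([101, 102, 103], dtype=np.int64)
embedding_dimension = 8


class UserModel(tf.keras.Model):
    """User tower written as a tf.keras.Model subclass instead of a bare Sequential."""

    def __init__(self):
        super().__init__()
        self.user_embedding = tf.keras.Sequential([
            tf.keras.layers.IntegerLookup(vocabulary=unique_user_ids, mask_token=None),
            # One extra row for the out-of-vocabulary index IntegerLookup reserves.
            tf.keras.layers.Embedding(input_dim=len(unique_user_ids) + 1,
                                      output_dim=embedding_dimension),
        ])

    def call(self, user_ids):
        # compute_loss() already passes features["user_id"], so the input is
        # the ID tensor itself, not the whole features dict.
        return self.user_embedding(user_ids)


# Create an instance; this object (not the class) is what ContentModel receives,
# e.g. model = ContentModel(user_model, content_model).
user_model = UserModel()

# 999 is not in the vocabulary, so IntegerLookup maps it to the OOV row.
user_embeddings = user_model(tf.constant([101, 103, 999], dtype=tf.int64))
print(user_embeddings.shape)  # (3, 8)
```

The content tower can follow the same pattern. If you prefer to keep `inputs["user_id"]` inside `call()`, the alternative is to pass the whole dict from `compute_loss()` (`self.user_model(features)`); either works as long as both sides agree.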

patrickorlando · Nov 02 '22 05:11