
tfrs.layers.dcn.Cross - missed compute_output_shape method

Open schistyakov opened this issue 2 years ago • 0 comments

I am implementing a retrieval model on sequential query data. Each historical record contains several categorical variables, which are converted to embeddings, concatenated, and fed into a GRU to create a single query embedding.
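To make the shapes concrete, here is a minimal NumPy sketch of the input preparation described above. The number of features, the vocabulary sizes and the embedding width are hypothetical, and the real model would use Keras Embedding layers rather than raw lookup tables.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: two categorical features per historical step, with
# vocabularies of 10 and 7 values, each embedded into 5 dimensions.
VOCAB_SIZES = (10, 7)
EMBED_DIM = 5
tables = [rng.normal(size=(v, EMBED_DIM)) for v in VOCAB_SIZES]

def embed_sequence(ids: np.ndarray) -> np.ndarray:
    """Map integer ids of shape (batch, time, n_features) to concatenated
    embeddings of shape (batch, time, n_features * EMBED_DIM)."""
    per_feature = [table[ids[..., i]] for i, table in enumerate(tables)]
    return np.concatenate(per_feature, axis=-1)

# 3 sequences, 4 historical steps each, 2 categorical ids per step.
ids = np.stack([rng.integers(0, v, size=(3, 4)) for v in VOCAB_SIZES], axis=-1)
x = embed_sequence(ids)
print(x.shape)  # (3, 4, 10)
```

A (batch, time, features) tensor like x is what the Cross layer has to process step by step before the GRU collapses the time axis into one query embedding.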

I use the tfrs.layers.dcn.Cross layer to improve the embeddings before the GRU. To apply it at every time step, I wrap it in a tf.keras.layers.TimeDistributed layer, which raises the following error:

 raise NotImplementedError(
    828         'Please run in eager mode or implement the `compute_output_shape` '
--> 829         'method on your layer (%s).' % self.__class__.__name__)
    830 
    831   @doc_controls.for_subclass_implementers

NotImplementedError: Exception encountered when calling layer "time_distributed_5" (type TimeDistributed).

Please run in eager mode or implement the `compute_output_shape` method on your layer (Cross).

Call arguments received by layer "time_distributed_5" (type TimeDistributed):
  • inputs=tf.Tensor(shape=(None, None, 10), dtype=float32)
  • training=None
  • mask=None
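As far as I can tell from the Keras source for this TF version, when the batch size is unknown TimeDistributed folds the time axis into the batch axis, runs the wrapped layer, and then asks the wrapped layer's compute_output_shape for the per-step output shape so it can restore the time axis. Outside eager mode the base Layer.compute_output_shape cannot infer that shape and raises the NotImplementedError above. The plain-Python sketch below mirrors that shape bookkeeping in simplified form; the function names are hypothetical and this is not the actual Keras implementation.

```python
from typing import Callable, Optional, Tuple

Shape = Tuple[Optional[int], ...]

def time_distributed_output_shape(
    input_shape: Shape,
    inner_compute_output_shape: Callable[[Shape], Shape],
) -> Shape:
    """Simplified model of how a TimeDistributed wrapper derives its output
    shape: drop the time axis, ask the inner layer, re-insert the time axis."""
    batch, timesteps, *step_dims = input_shape
    child_input_shape = (batch, *step_dims)  # time axis removed
    child_output_shape = inner_compute_output_shape(child_input_shape)
    return (child_output_shape[0], timesteps, *child_output_shape[1:])

def shape_preserving(input_shape: Shape) -> Shape:
    """A layer whose output shape equals its input shape, as Cross's does."""
    return input_shape

def not_implemented(input_shape: Shape) -> Shape:
    """Stand-in for a layer that cannot report its output shape in graph mode."""
    raise NotImplementedError("implement compute_output_shape")

print(time_distributed_output_shape((None, None, 10), shape_preserving))
# (None, None, 10)
```

Passing not_implemented instead of shape_preserving makes the call raise, which corresponds to the situation the traceback reports for Cross.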

Here is a short example that reproduces the error:

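import tensorflow as tf
import tensorflow_recommenders as tfrs

# Embedding outputs (batch, time, 10); Cross is applied to each time step.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10, output_dim=10),
    tf.keras.layers.TimeDistributed(tfrs.layers.dcn.Cross())
])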

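Since the Cross layer's output shape is the same as its input shape, I worked around the error by subclassing Cross and overriding compute_output_shape to return the input shape unchanged. This works for me, but it would be nice if Cross implemented compute_output_shape itself so that no workaround is needed.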
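Why returning input_shape is safe: the DCN-v2 cross update, which is what I understand tfrs.layers.dcn.Cross to compute, is x0 * (x @ W + b) + x with a square kernel W, so every term keeps the input's last dimension. (With projection_dim set, W is factored into two smaller matrices, but as far as I know the result is still projected back to the input width.) The NumPy sketch below uses hypothetical sizes and random weights, and applies the update both the TimeDistributed way (time folded into batch) and directly on the 3-D tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, time, d = 3, 4, 10  # hypothetical sizes; d matches the 10 features above

W = rng.normal(size=(d, d))  # square kernel: d inputs -> d outputs
b = rng.normal(size=(d,))

def cross(x0: np.ndarray, x: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One cross step: x0 * (x @ W + b) + x, with an elementwise product by x0."""
    return x0 * (x @ W + b) + x

seq = rng.normal(size=(batch, time, d))

# TimeDistributed-style: fold time into batch, apply the step, unfold.
flat = seq.reshape(-1, d)
per_step = cross(flat, flat, W, b).reshape(batch, time, d)

# Direct application: matmul acts on the last axis, so the result is the same.
direct = cross(seq, seq, W, b)

print(per_step.shape)                 # (3, 4, 10)
print(np.allclose(per_step, direct))  # True
```

Because the output shape never depends on anything but the input shape, an override like the one below can simply return input_shape.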

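class MyCross(tfrs.layers.dcn.Cross):
    # Cross returns a tensor with the same shape as its input, so the
    # output shape can be reported without running the layer.
    def compute_output_shape(self, input_shape):
        return input_shape

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10, output_dim=10),
    tf.keras.layers.TimeDistributed(MyCross())
])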

My environment:

import tensorflow as tf
import tensorflow_recommenders as tfrs

print(f'{tf.__version__}')
print(f'{tfrs.__version__}')

2.9.2
v0.7.2

schistyakov, Nov 10 '22 19:11