models icon indicating copy to clipboard operation
models copied to clipboard

[BUG] Setting the `post` argument on `GPT2Block` or `XLNetBlock` raises an error

Open rnyak opened this issue 3 years ago • 0 comments

Bug description

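Setting `post='sequence_mean'` on `XLNetBlock` (the same happens with `GPT2Block`) makes `model.fit` fail with the error below. The traceback suggests the `post` string is never resolved to a layer. `TransformerBlock.build` passes the entries of `self.to_call_post` to `build_sequentially`, which then calls `.build()` on the raw string `'sequence_mean'`.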

AttributeError                            Traceback (most recent call last)
Cell In [40], line 39
     22 model = mm.Model(
     23     mm.InputBlockV2(
     24         schema,
   (...)
     35     ),
     36 )
     38 model.compile(run_eagerly=False, optimizer='adam')
---> 39 model.fit(loader, epochs=1)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:831, in BaseModel.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, train_metrics_steps, pre, **kwargs)
    828     self._reset_compile_cache()
    829     self.train_pre = pre
--> 831 out = super().fit(**fit_kwargs)
    833 if pre:
    834     del self.train_pre

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:1006, in Model._maybe_build(self, inputs)
   1004             child._feature_shapes = feature_shapes
   1005             child._feature_dtypes = feature_dtypes
-> 1006 super()._maybe_build(inputs)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/models/base.py:1024, in Model.build(self, input_shape)
   1022 for layer in self.blocks:
   1023     try:
-> 1024         layer.build(input_shape)
   1025     except TypeError:
   1026         t, v, tb = sys.exc_info()

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/transformers/block.py:113, in TransformerBlock.build(self, input_shape)
    105 def build(self, input_shape=None):
    106     """Builds the sequential block
    107 
    108     Parameters
   (...)
    111         The input shape, by default None
    112     """
--> 113     combinators.build_sequentially(
    114         self, [*list(self.to_call_pre), self.transformer, *list(self.to_call_post)], input_shape
    115     )

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:829, in build_sequentially(self, layers, input_shape)
    827 for layer in layers:
    828     try:
--> 829         layer.build(input_shape)
    830     except TypeError:
    831         t, v, tb = sys.exc_info()

AttributeError: 'str' object has no attribute 'build'

Steps/code to reproduce the bug

Run the code below to reproduce the error:

def classification_loader(sequence_testing_data: Dataset, num_classes: int = 63):
    """Build a Loader whose target is a one-hot encoded ``user_country``.

    The dataset schema is narrowed to ``item_id_seq``, ``categories`` and
    ``user_country``, and ``user_country`` gets the ``"target"`` tag.

    NOTE: this replaces ``sequence_testing_data.schema`` in place (the
    caller's dataset object is modified). Presumably this is done so that
    ``mm.Loader`` picks up the narrowed schema; confirm before changing it.

    Parameters
    ----------
    sequence_testing_data : Dataset
        Dataset to wrap. Its ``schema`` attribute is overwritten.
    num_classes : int, default 63
        Depth of the one-hot encoding applied to the target.

    Returns
    -------
    tuple
        ``(dataloader, schema)``: the ``mm.Loader`` and the narrowed schema.
    """

    def _target_to_onehot(inputs, targets):
        # Flatten to 1-D *before* one-hot encoding instead of squeezing after.
        # An axis-less ``tf.squeeze`` also removes the batch dimension when a
        # batch contains a single row, which would yield a rank-1 target.
        targets = tf.one_hot(tf.reshape(targets, [-1]), num_classes)
        return inputs, targets

    schema = sequence_testing_data.schema.select_by_name(
        ["item_id_seq", "categories", "user_country"]
    )
    # Mark user_country as the prediction target.
    schema["user_country"] = schema["user_country"].with_tags(
        schema["user_country"].tags + "target"
    )
    sequence_testing_data.schema = schema
    dataloader = mm.Loader(sequence_testing_data, batch_size=50, transform=_target_to_onehot)
    return dataloader, schema

from merlin.datasets.synthetic import generate_data

# Reproduction: train a sequence classifier whose transformer block is given
# post='sequence_mean', for a single epoch.
sequence_testing_data = generate_data("sequence-testing", num_rows=1000)
loader, schema = classification_loader(sequence_testing_data)

# Blocks are constructed in the same order a nested mm.Model(...) call would
# evaluate them: embeddings, input block, transformer, output.
seq_embeddings = mm.Embeddings(schema, sequence_combiner=None)
input_block = mm.InputBlockV2(schema, embeddings=seq_embeddings)

# Per the traceback in this issue, the `post` string below is what ends up
# reaching `.build()` and raising the AttributeError during fit.
transformer_block = mm.XLNetBlock(d_model=48, n_head=8, n_layer=2, post='sequence_mean')

output_block = mm.CategoricalOutput(to_call=schema["user_country"])

model = mm.Model(input_block, transformer_block, output_block)

model.compile(run_eagerly=False, optimizer='adam')
model.fit(loader, epochs=1)

rnyak avatar Oct 26 '22 19:10 rnyak