models icon indicating copy to clipboard operation
models copied to clipboard

[BUG] batch-predict throws an error when inputs contain list features

Open sararb opened this issue 3 years ago • 0 comments

Bug description

Applying batch-predict to a data frame that contains multi-hot / sparse (list) features throws the following error:

funcname = '<merlin.models.tf.utils.batch_utils.TFModelEncode ', udf = True

    @contextmanager
    def raise_on_meta_error(funcname=None, udf=False):
        """Reraise errors in this block to show metadata inference failure.
    
        Parameters
        ----------
        funcname : str, optional
            If provided, will be added to the error message to indicate the
            name of the method that failed.
        """
        try:
>           yield

../../../anaconda3/envs/marlin-dev/lib/python3.8/site-packages/dask/dataframe/utils.py:195: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

func = <merlin.models.tf.utils.batch_utils.TFModelEncode object at 0x7f68e80a8760>
udf = True
args = (Dask DataFrame Structure:
              session_id user_id country user_age user_genres item_id item_category item_re...        ...          ...         ...      ...      ...      ...             ...
Dask Name: from_pandas, 1 graph layer,)
kwargs = {'filter_input_columns': ['user_id']}

    def _emulate(func, *args, udf=False, **kwargs):
        """
        Apply a function using args / kwargs. If arguments contain dd.DataFrame /
        dd.Series, using internal cache (``_meta``) for calculation
        """
        with raise_on_meta_error(funcname(func), udf=udf), check_numeric_only_deprecation():
>           return func(*_extract_meta(args, True), **_extract_meta(kwargs, True))

../../../anaconda3/envs/marlin-dev/lib/python3.8/site-packages/dask/dataframe/core.py:6582: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.utils.batch_utils.TFModelEncode object at 0x7f68e80a8760>
df =    session_id  user_id  country  ...  like position  play_percentage
0           1        1        1  ...   1.0        1              1.0
1           1        1        1  ...   1.0        1              1.0

[2 rows x 13 columns]
filter_input_columns = ['user_id'], filter_output_columns = None

    def __call__(
        self,
        df: DataFrameType,
        filter_input_columns: tp.Optional[tp.List[str]] = None,
        filter_output_columns: tp.Optional[tp.List[str]] = None,
    ) -> DataFrameType:
        # Set defaults
        iterator_func = self.data_iterator_func or (lambda x: [x])
        encode_func = self.model_encode_func or (lambda x, y: x(y))
        concat_func = self.output_concat_func or np.concatenate
    
        # Iterate over batches of df and collect predictions
>       outputs = concat_func([encode_func(self.model, batch) for batch in iterator_func(df)])

merlin/models/tf/utils/batch_utils.py:51: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

.0 = <generator object Sequence.__iter__ at 0x7f68e075d900>

>   outputs = concat_func([encode_func(self.model, batch) for batch in iterator_func(df)])

...

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>, idx = 0

    def __getitem__(self, idx):
        """
        implemented exclusively for consistency
        with Keras model.fit. Does not leverage
        passed idx in any way
        """
>       return DataLoader.__next__(self)

merlin/models/tf/loader.py:337: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>

    def __next__(self):
>       return self._get_next_batch()

merlin/models/loader/backend.py:356: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>

    def _get_next_batch(self):
        """
        adding this cheap shim so that we can call this
        step without it getting overridden by the
        framework-specific parent class's `__next__` method.
        TODO: can this be better solved with a metaclass
        implementation? My gut is that we don't actually
        necessarily *want*, in general, to be overriding
        __next__ and __iter__ methods
        """
        # we've never initialized, do that now
        # need this because tf.keras.Model.fit will
        # call next() cold
        if self._workers is None:
            DataLoader.__iter__(self)
    
        # get the first chunks
        if self._batch_itr is None:
>           self._fetch_chunk()

merlin/models/loader/backend.py:368: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.loader.backend.ChunkQueue object at 0x7f68e07142b0>
dev = 'cpu'

    @annotate("load_chunks", color="darkgreen", domain="nvt_python")
    def load_chunks(self, dev):
        try:
            itr = iter(self.itr)
            if self.dataloader.device != "cpu":
                with self.dataloader._get_device_ctx(dev):
                    self.chunk_logic(itr)
            else:
>               self.chunk_logic(itr)

merlin/models/loader/backend.py:158: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<merlin.models.loader.backend.ChunkQueue object at 0x7f68e07142b0>, <generator object DataFrameIter.__iter__ at 0x7f68e075db30>)
kwds = {}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)

../../../anaconda3/envs/marlin-dev/lib/python3.8/contextlib.py:75: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.loader.backend.ChunkQueue object at 0x7f68e07142b0>
itr = <generator object DataFrameIter.__iter__ at 0x7f68e075db30>

    @annotate("chunk_logic", color="darkgreen", domain="nvt_python")
    def chunk_logic(self, itr):
        spill = None
        for chunks in self.batch(itr):
            if self.stopped:
                return
    
            if spill is not None and not spill.empty:
                chunks.insert(0, spill)
    
            chunks = concat(chunks)
            chunks.reset_index(drop=True, inplace=True)
            chunks, spill = self.get_batch_div_chunk(chunks, self.dataloader.batch_size)
            if self.shuffle:
                chunks = shuffle_df(chunks)
    
            if len(chunks) > 0:
                chunks = self.dataloader.make_tensors(chunks, self.dataloader._use_nnz)
                # put returns True if buffer is stopped before
                # packet can be put in queue. Keeps us from
                # freezing on a put on a full queue
                if self.put(chunks):
                    return
            chunks = None
        # takes care final batch, which is less than batch size
        if not self.dataloader.drop_last and spill is not None and not spill.empty:
>           spill = self.dataloader.make_tensors(spill, self.dataloader._use_nnz)
 
....
../../../anaconda3/envs/marlin-dev/lib/python3.8/contextlib.py:75: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>
gdf =    session_id  user_age  item_id  ...  like  position play_percentage
0           1         1        1  ...   1.0         1             1.0
1           1         1        1  ...   1.0         1             1.0

[2 rows x 10 columns]
use_nnz = True

    @annotate("make_tensors", color="darkgreen", domain="nvt_python")
    def make_tensors(self, gdf, use_nnz=False):
        split_idx = self._get_segment_lengths(len(gdf))
    
        # map from big chunk to framework-specific tensors
>       chunks = self._create_tensors(gdf)

../../../anaconda3/envs/marlin-dev/lib/python3.8/contextlib.py:75: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>
gdf =    session_id  user_age  item_id  ...  like  position play_percentage
0           1         1        1  ...   1.0         1             1.0
1           1         1        1  ...   1.0         1             1.0

[2 rows x 10 columns]

    @annotate("_create_tensors", color="darkgreen", domain="nvt_python")
    def _create_tensors(self, gdf):
        """
        Breaks a dataframe down into the relevant
        categorical, continuous, and label tensors.
        Can be overrideen
        """
        workflow_nodes = (self.cat_names, self.cont_names, self.label_names)
        dtypes = (self._LONG_DTYPE, self._FLOAT32_DTYPE, self._FLOAT32_DTYPE)
        tensors = []
        offsets = make_df(device=self.device)
        for column_names, dtype in zip(workflow_nodes, dtypes):
            ...
            gdf_i = gdf[column_names]
            gdf.drop(columns=column_names, inplace=True)
    
            scalars, lists = self._separate_list_columns(gdf_i)
    
            x = None
            if scalars:
                # should always return dict column_name: values, offsets (optional)
>               x = self._to_tensor(gdf_i[scalars], dtype)

merlin/models/loader/backend.py:587: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>
gdf =    user_id  country user_genres
0        1        1         foo
1        1        1         foo
dtype = tf.int64

    def _to_tensor(self, gdf, dtype=None):
         .... 
            dlpack = self._pack(gdf.values.T)
        # catch error caused by tf eager context
        # not being initialized
        try:
>           x = self._unpack(dlpack)

merlin/models/tf/loader.py:439: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <merlin.models.tf.loader.Loader object at 0x7f68e0714490>
gdf = array([[1, 1],
       [1, 1],
       ['foo', 'foo']], dtype=object)

    def _unpack(self, gdf):
        if hasattr(gdf, "shape"):
>           return tf.convert_to_tensor(gdf)

merlin/models/tf/loader.py:420: 
... 
            if udf:
                msg += (
                    "You have supplied a custom function and Dask is unable to \n"
                    "determine the type of output that that function returns. \n\n"
                    "To resolve this please provide a meta= keyword.\n"
                    "The docstring of the Dask function you ran should have more information.\n\n"
                )
            msg += (
                "Original error is below:\n"
                "------------------------\n"
                "{1}\n\n"
                "Traceback:\n"
                "---------\n"
                "{2}"
            )
            msg = msg.format(f" in `{funcname}`" if funcname else "", repr(e), tb)
>           raise ValueError(msg) from e
E           ValueError: Metadata inference failed in `<merlin.models.tf.utils.batch_utils.TFModelEncode `.
E           
E           You have supplied a custom function and Dask is unable to 
E           determine the type of output that that function returns. 
E           
E           To resolve this please provide a meta= keyword.
E           The docstring of the Dask function you ran should have more information.
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

Steps/Code to reproduce bug

  1. Run the following test inside test_retrieval.py:
def test_two_tower_v2_export_embeddings_sparse_inputs(
    music_streaming_data: Dataset,
):
    user_schema = music_streaming_data.schema.select_by_tag(Tags.USER)
    candidate_schema = music_streaming_data.schema.select_by_tag(Tags.ITEM)

    query = mm.Encoder(user_schema, mm.MLPBlock([8]))
    candidate = mm.Encoder(candidate_schema, mm.MLPBlock([8]))
    model = mm.TwoTowerModelV2(
        query_tower=query, candidate_tower=candidate, negative_samplers=["in-batch"]
    )

    model, _ = testing_utils.model_test(model, music_streaming_data, reload_model=False)

    queries = model.query_embeddings(music_streaming_data, batch_size=10, index=Tags.USER_ID).compute()
    _check_embeddings(queries, 100, 8, "user_id")

Expected behavior

Batch-predict should support sparse input features. This is required, for example, by YoutubeDNN retrieval and session-based models in order to export the query embeddings.

sararb avatar Oct 11 '22 20:10 sararb