Check for `Prediction` namedtuple in train/test steps
Based on comments in #685
Goals :soccer:
Check for the `Prediction` namedtuple in the train/test steps when unpacking data.
Implementation Details :construction:
Wrap `unpack_x_y_sample_weight` so that it also accepts a `Prediction` namedtuple as input.
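A minimal sketch of the idea, in plain Python so it runs without TensorFlow. It assumes a `Prediction` namedtuple with the fields `outputs`, `targets`, `sample_weight` and `features` (copied from the repr in the CI log, not from the library source), and restates the tuple handling of Keras' `unpack_x_y_sample_weight` as a hypothetical `keras_style_unpack`. Because a namedtuple is a `tuple` of length four, the unwrapped logic falls through to its `ValueError` branch; the hypothetical wrapper `unpack_with_prediction` maps the fields by name first.

```python
from collections import namedtuple

# Hypothetical stand-in for Merlin's `Prediction`; the field names are copied
# from the repr printed in the CI log, not from the library source.
Prediction = namedtuple("Prediction", ["outputs", "targets", "sample_weight", "features"])


def keras_style_unpack(data):
    """Plain-Python restatement of the tuple handling in
    keras.utils.unpack_x_y_sample_weight (no TensorFlow needed)."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        return (data, None, None)
    if len(data) == 1:
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    if len(data) == 3:
        return (data[0], data[1], data[2])
    raise ValueError(
        "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
        f"or `(x, y, sample_weight)`, found: {data}"
    )


def unpack_with_prediction(data):
    """Hypothetical wrapper: map a Prediction onto (x, y, sample_weight)
    by field name, and defer everything else to the Keras-style logic."""
    if isinstance(data, Prediction):
        return (data.outputs, data.targets, data.sample_weight)
    return keras_style_unpack(data)


if __name__ == "__main__":
    batch = Prediction(outputs={"item_id": [1, 2]}, targets={"click": [0.0, 1.0]},
                       sample_weight=None, features=None)
    print(unpack_with_prediction(batch))
    try:
        keras_style_unpack(batch)  # a 4-field namedtuple is a tuple of length 4
    except ValueError:
        print("unwrapped unpacking rejects the Prediction")
```

The sketch drops the `features` field when unpacking; whether the actual change forwards it anywhere is not stated in this PR.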
Testing Details :mag:
Updated the test for the negative sampling layer applied to a `BatchedDataset`, where the output is a `Prediction` tuple.
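The row-count bounds that the updated test asserts can be sketched in plain Python. This is an illustration rather than the repository's test: `Prediction` is again a hypothetical namedtuple mirroring the fields in the CI log, and `batch_size_bounds` and `check_negative_sampling_batch` are hypothetical helpers. With `batch_size = 10` and `n_per_positive = 5`, every feature and target column must have more than 10 and at most 10 + 10 × 5 = 60 rows; the batch printed in the CI log has 52.

```python
from collections import namedtuple

# Hypothetical stand-in; field names mirror the repr in the CI log.
Prediction = namedtuple("Prediction", ["outputs", "targets", "sample_weight", "features"])


def batch_size_bounds(batch_size, n_per_positive):
    """Exclusive lower bound and inclusive upper bound used by the test:
    it allows up to n_per_positive sampled negatives per positive row."""
    return batch_size, batch_size + batch_size * n_per_positive


def check_negative_sampling_batch(batch, batch_size, n_per_positive):
    """Hypothetical helper: read features/targets by field name (not by
    positional unpacking) and check every column's row count."""
    low, high = batch_size_bounds(batch_size, n_per_positive)
    columns = list(batch.outputs.values()) + list(batch.targets.values())
    return all(low < len(column) <= high for column in columns)


if __name__ == "__main__":
    rows = 52  # row count of the batch shown in the CI log
    batch = Prediction(
        outputs={"item_id": [0] * rows, "user_id": [0] * rows},
        targets={"click": [0.0] * rows, "like": [0.0] * rows},
        sample_weight=None,
        features=None,
    )
    print(batch_size_bounds(10, 5))
    print(check_negative_sampling_batch(batch, batch_size=10, n_per_positive=5))
```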
CI Results
GitHub pull request #702 of commit d6be41cc6409b661cae3a4a2e3cfe6fa2eb461f4, no merge conflicts.
Running as SYSTEM
Setting status of d6be41cc6409b661cae3a4a2e3cfe6fa2eb461f4 to PENDING with url https://10.20.13.93:8080/job/merlin_models/1100/console and message: 'Pending'
Using context: Jenkins
Building on master in workspace /var/jenkins_home/workspace/merlin_models
using credential nvidia-merlin-bot
> git rev-parse --is-inside-work-tree # timeout=10
Fetching changes from the remote Git repository
> git config remote.origin.url https://github.com/NVIDIA-Merlin/models/ # timeout=10
Fetching upstream changes from https://github.com/NVIDIA-Merlin/models/
> git --version # timeout=10
using GIT_ASKPASS to set credentials This is the bot credentials for our CI/CD
> git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/models/ +refs/pull/702/*:refs/remotes/origin/pr/702/* # timeout=10
> git rev-parse d6be41cc6409b661cae3a4a2e3cfe6fa2eb461f4^{commit} # timeout=10
Checking out Revision d6be41cc6409b661cae3a4a2e3cfe6fa2eb461f4 (detached)
> git config core.sparsecheckout # timeout=10
> git checkout -f d6be41cc6409b661cae3a4a2e3cfe6fa2eb461f4 # timeout=10
Commit message: "Check for `Prediction` namedtuple in train/test steps"
> git rev-list --no-walk 9a05026d5cc3528d108ce602f55c0d89eebba4c9 # timeout=10
[merlin_models] $ /bin/bash /tmp/jenkins8248793787395172768.sh
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: testbook in /usr/local/lib/python3.8/dist-packages (0.4.2)
Requirement already satisfied: nbformat>=5.0.4 in /usr/local/lib/python3.8/dist-packages (from testbook) (5.4.0)
Requirement already satisfied: nbclient>=0.4.0 in /usr/local/lib/python3.8/dist-packages (from testbook) (0.6.6)
Requirement already satisfied: traitlets>=5.1 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (5.3.0)
Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.9.1)
Requirement already satisfied: jupyter-core in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.11.1)
Requirement already satisfied: fastjsonschema in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (2.16.1)
Requirement already satisfied: jupyter-client>=6.1.5 in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (7.3.4)
Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (1.5.5)
Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (22.1.0)
Requirement already satisfied: importlib-resources>=1.4.0; python_version =2.6->nbformat>=5.0.4->testbook) (5.9.0)
Requirement already satisfied: pkgutil-resolve-name>=1.3.10; python_version =2.6->nbformat>=5.0.4->testbook) (1.3.10)
Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (0.18.1)
Requirement already satisfied: entrypoints in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (0.4)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (2.8.2)
Requirement already satisfied: pyzmq>=23.0 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (23.2.1)
Requirement already satisfied: tornado>=6.0 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (6.2)
Requirement already satisfied: zipp>=3.1.0; python_version =1.4.0; python_version jsonschema>=2.6->nbformat>=5.0.4->testbook) (3.8.1)
Requirement already satisfied: six>=1.5 in /var/jenkins_home/.local/lib/python3.8/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (1.15.0)
============================= test session starts ==============================
platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0
rootdir: /var/jenkins_home/workspace/merlin_models/models, configfile: pyproject.toml
plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-3.0.0
collected 678 items
tests/unit/config/test_schema.py .... [ 0%]
tests/unit/datasets/test_advertising.py .s [ 0%]
tests/unit/datasets/test_ecommerce.py ..sss [ 1%]
tests/unit/datasets/test_entertainment.py ....sss. [ 2%]
tests/unit/datasets/test_social.py . [ 2%]
tests/unit/datasets/test_synthetic.py ...... [ 3%]
tests/unit/implicit/test_implicit.py . [ 3%]
tests/unit/lightfm/test_lightfm.py . [ 4%]
tests/unit/tf/test_core.py ...... [ 5%]
tests/unit/tf/test_dataset.py ................ [ 7%]
tests/unit/tf/test_public_api.py . [ 7%]
tests/unit/tf/blocks/test_cross.py ........... [ 9%]
tests/unit/tf/blocks/test_dlrm.py .......... [ 10%]
tests/unit/tf/blocks/test_interactions.py . [ 10%]
tests/unit/tf/blocks/test_mlp.py ................................. [ 15%]
tests/unit/tf/blocks/test_optimizer.py s................................ [ 20%]
..................... [ 23%]
tests/unit/tf/blocks/retrieval/test_base.py . [ 23%]
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py .. [ 24%]
tests/unit/tf/blocks/retrieval/test_two_tower.py ........... [ 25%]
tests/unit/tf/blocks/sampling/test_cross_batch.py . [ 25%]
tests/unit/tf/blocks/sampling/test_in_batch.py . [ 25%]
tests/unit/tf/core/test_aggregation.py ......... [ 27%]
tests/unit/tf/core/test_base.py .. [ 27%]
tests/unit/tf/core/test_combinators.py s................... [ 30%]
tests/unit/tf/core/test_index.py ... [ 30%]
tests/unit/tf/core/test_prediction.py .. [ 31%]
tests/unit/tf/core/test_tabular.py .... [ 31%]
tests/unit/tf/core/test_transformations.py s............................ [ 36%]
.................. [ 38%]
tests/unit/tf/data_augmentation/test_misc.py . [ 38%]
tests/unit/tf/data_augmentation/test_negative_sampling.py .........F [ 40%]
tests/unit/tf/data_augmentation/test_noise.py ..... [ 41%]
tests/unit/tf/examples/test_01_getting_started.py . [ 41%]
tests/unit/tf/examples/test_02_dataschema.py . [ 41%]
tests/unit/tf/examples/test_03_exploring_different_models.py . [ 41%]
tests/unit/tf/examples/test_04_export_ranking_models.py . [ 41%]
tests/unit/tf/examples/test_05_export_retrieval_model.py . [ 41%]
tests/unit/tf/examples/test_06_advanced_own_architecture.py . [ 42%]
tests/unit/tf/examples/test_07_train_traditional_models.py . [ 42%]
tests/unit/tf/examples/test_usecase_ecommerce_session_based.py . [ 42%]
tests/unit/tf/examples/test_usecase_pretrained_embeddings.py . [ 42%]
tests/unit/tf/inputs/test_continuous.py ..... [ 43%]
tests/unit/tf/inputs/test_embedding.py ................................. [ 48%]
.. [ 48%]
tests/unit/tf/inputs/test_tabular.py .................. [ 51%]
tests/unit/tf/layers/test_queue.py .............. [ 53%]
tests/unit/tf/losses/test_losses.py ....................... [ 56%]
tests/unit/tf/metrics/test_metrics_popularity.py ..... [ 57%]
tests/unit/tf/metrics/test_metrics_topk.py ....................... [ 60%]
tests/unit/tf/models/test_base.py s................ [ 63%]
tests/unit/tf/models/test_benchmark.py .. [ 63%]
tests/unit/tf/models/test_ranking.py .............................. [ 67%]
tests/unit/tf/models/test_retrieval.py ................................ [ 72%]
tests/unit/tf/prediction_tasks/test_classification.py .. [ 72%]
tests/unit/tf/prediction_tasks/test_multi_task.py ................ [ 75%]
tests/unit/tf/prediction_tasks/test_next_item.py ..... [ 75%]
tests/unit/tf/prediction_tasks/test_regression.py .. [ 76%]
tests/unit/tf/prediction_tasks/test_retrieval.py . [ 76%]
tests/unit/tf/prediction_tasks/test_sampling.py ...... [ 77%]
tests/unit/tf/predictions/test_base.py ..... [ 78%]
tests/unit/tf/predictions/test_classification.py ....... [ 79%]
tests/unit/tf/predictions/test_dot_product.py ........ [ 80%]
tests/unit/tf/predictions/test_regression.py .. [ 80%]
tests/unit/tf/predictions/test_sampling.py .... [ 81%]
tests/unit/tf/utils/test_batch.py .... [ 81%]
tests/unit/tf/utils/test_tf_utils.py ..... [ 82%]
tests/unit/torch/test_dataset.py ......... [ 83%]
tests/unit/torch/test_public_api.py . [ 83%]
tests/unit/torch/block/test_base.py .... [ 84%]
tests/unit/torch/block/test_mlp.py . [ 84%]
tests/unit/torch/features/test_continuous.py .. [ 84%]
tests/unit/torch/features/test_embedding.py .............. [ 87%]
tests/unit/torch/features/test_tabular.py .... [ 87%]
tests/unit/torch/model/test_head.py ............ [ 89%]
tests/unit/torch/model/test_model.py .. [ 89%]
tests/unit/torch/tabular/test_aggregation.py ........ [ 90%]
tests/unit/torch/tabular/test_tabular.py ... [ 91%]
tests/unit/torch/tabular/test_transformations.py ....... [ 92%]
tests/unit/utils/test_schema_utils.py ................................ [ 97%]
tests/unit/xgb/test_xgboost.py .................... [100%]
=================================== FAILURES ===================================
___________ TestAddRandomNegativesToBatch.test_model_with_dataloader ___________
self = <tests.unit.tf.data_augmentation.test_negative_sampling.TestAddRandomNegativesToBatch object at 0x7f2a086349d0>
music_streaming_data = <merlin.io.dataset.Dataset object at 0x7f298427af40>
tf_random_seed = 1
def test_model_with_dataloader(self, music_streaming_data: Dataset, tf_random_seed: int):
    add_negatives = UniformNegativeSampling(music_streaming_data.schema, 5, seed=tf_random_seed)
    batch_size, n_per_positive = 10, 5
    dataset = BatchedDataset(music_streaming_data, batch_size=batch_size)
    dataset = dataset.map(add_negatives)
    batch_output = next(iter(dataset))
    features, targets = batch_output.outputs, batch_output.targets
    expected_batch_size = batch_size + batch_size * n_per_positive
    assert features["item_genres"].shape[0] > batch_size
    assert features["item_genres"].shape[0] <= expected_batch_size
    assert all(
        f.shape[0] > batch_size and f.shape[0] <= expected_batch_size for f in features.values()
    )
    assert all(
        f.shape[0] > batch_size and f.shape[0] <= expected_batch_size for f in targets.values()
    )
    model = mm.Model(
        mm.InputBlock(music_streaming_data.schema),
        mm.MLPBlock([64]),
        mm.BinaryClassificationTask("click"),
    )
    assert model(features).shape[0] > batch_size
    assert model(features).shape[0] <= expected_batch_size
    testing_utils.model_test(model, dataset)
tests/unit/tf/data_augmentation/test_negative_sampling.py:237:
merlin/models/tf/utils/testing_utils.py:89: in model_test
losses = model.fit(dataset, batch_size=50, epochs=epochs, steps_per_epoch=1)
merlin/models/tf/models/base.py:725: in fit
return super().fit(**fit_kwargs)
/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:60: in error_handler
return fn(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1358: in fit
data_handler = data_adapter.get_data_handler(
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:1401: in get_data_handler
return DataHandler(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:1151: in __init__
self._adapter = adapter_cls(
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:926: in __init__
super(KerasSequenceAdapter, self).__init__(
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:801: in __init__
peek = self._standardize_batch(peek)
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:845: in _standardize_batch
x, y, sample_weight = unpack_x_y_sample_weight(data)
data = Prediction(outputs={'item_genres': <tf.RaggedTensor [[98, 58, 94, ..., 16, 76, 94],
[94, 53, 57, ..., 87, 69, 67],
[... [0. ],
[0. ],
[0. ],
[0. ]])>}, sample_weight=None, features=None)
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
def unpack_x_y_sample_weight(data):
  """Unpacks user-provided data tuple.
  This is a convenience utility to be used when overriding
  `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
  This utility makes it easy to support data of the form `(x,)`,
  `(x, y)`, or `(x, y, sample_weight)`.
  Standalone usage:
  >>> features_batch = tf.ones((10, 5))
  >>> labels_batch = tf.zeros((10, 5))
  >>> data = (features_batch, labels_batch)
  >>> # `y` and `sample_weight` will default to `None` if not provided.
  >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
  >>> sample_weight is None
  True
  Example in overridden `Model.train_step`:
  ```python
  class MyModel(tf.keras.Model):
    def train_step(self, data):
      # If `sample_weight` is not provided, all samples will be weighted
      # equally.
      x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
      with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
      trainable_variables = self.trainable_variables
      gradients = tape.gradient(loss, trainable_variables)
      self.optimizer.apply_gradients(zip(gradients, trainable_variables))
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      return {m.name: m.result() for m in self.metrics}
  ```
  Args:
    data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
  Returns:
    The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
    provided.
  """
  if isinstance(data, list):
    data = tuple(data)
  if not isinstance(data, tuple):
    return (data, None, None)
  elif len(data) == 1:
    return (data[0], None, None)
  elif len(data) == 2:
    return (data[0], data[1], None)
  elif len(data) == 3:
    return (data[0], data[1], data[2])
  else:
    error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
                 "or `(x, y, sample_weight)`, found: {}").format(data)
    raise ValueError(error_msg)
E ValueError: Data is expected to be in format `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`, found: Prediction(outputs={'item_genres': <tf.RaggedTensor [[98, 58, 94, ..., 16, 76, 94],
E [94, 53, 57, ..., 87, 69, 67],
E [95, 60, 22, ..., 32, 92, 1],
E ...,
E [98, 58, 94, ..., 16, 76, 94],
E [2, 41, 14, ..., 5, 9, 78],
E [94, 53, 57, ..., 87, 69, 67]]>, 'user_genres': <tf.RaggedTensor [[22, 1, 77, ..., 71, 38, 34],
E [13, 38, 80, ..., 85, 44, 95],
E [82, 61, 70, ..., 58, 6, 77],
E ...,
E [84, 94, 19, ..., 38, 31, 41],
E [26, 74, 75, ..., 5, 12, 25],
E [26, 80, 94, ..., 38, 4, 26]]>, 'session_id': <tf.Tensor: shape=(52, 1), dtype=int64, numpy=
E array([[ 62],
E [ 91],
E [ 10],
E [ 27],
E [101],
E [105],
E [ 25],
E [ 32],
E [ 21],
E [ 40],
E [ 62],
E [ 62],
E [ 62],
E [ 62],
E [ 91],
E [ 91],
E [ 91],
E [ 91],
E [ 91],
E [ 10],
E [ 10],
E [ 10],
E [ 10],
E [ 10],
E [ 27],
E [ 27],
E [ 27],
E [101],
E [101],
E [101],
E [101],
E [101],
E [105],
E [105],
E [105],
E [ 25],
E [ 25],
E [ 25],
E [ 32],
E [ 32],
E [ 32],
E [ 32],
E [ 32],
E [ 21],
E [ 21],
E [ 21],
E [ 21],
E [ 21],
E [ 40],
E [ 40],
E [ 40],
E [ 40]])>, 'item_id': <tf.Tensor: shape=(52, 1), dtype=int64, numpy=
E array([[135],
E [ 33],
E [ 9],
E [ 5],
E [ 20],
E [ 50],
E [ 31],
E [ 14],
E [ 8],
E [ 39],
E [ 14],
E [ 8],
E [ 39],
E [ 5],
E [ 50],
E [ 9],
E [ 31],
E [ 50],
E [ 5],
E [ 14],
E [ 20],
E [ 14],
E [135],
E [ 20],
E [ 20],
E [ 31],
E [ 33],
E [ 8],
E [ 5],
E [ 39],
E [ 9],
E [ 39],
E [ 39],
E [ 33],
E [ 39],
E [ 5],
E [ 20],
E [ 50],
E [ 31],
E [ 8],
E [ 39],
E [135],
E [ 31],
E [ 39],
E [ 33],
E [ 20],
E [ 50],
E [ 50],
E [ 33],
E [135],
E [ 8],
E [ 33]])>, 'item_category': <tf.Tensor: shape=(52, 1), dtype=int64, numpy=
E array([[77],
E [18],
E [ 4],
E [ 2],
E [10],
E [28],
E [17],
E [ 7],
E [ 3],
E [21],
E [ 7],
E [ 3],
E [21],
E [ 2],
E [28],
E [ 4],
E [17],
E [28],
E [ 2],
E [ 7],
E [10],
E [ 7],
E [77],
E [10],
E [10],
E [17],
E [18],
E [ 3],
E [ 2],
E [21],
E [ 4],
E [21],
E [21],
E [18],
E [21],
E [ 2],
E [10],
E [28],
E [17],
E [ 3],
E [21],
E [77],
E [17],
E [21],
E [18],
E [10],
E [28],
E [28],
E [18],
E [77],
E [ 3],
E [18]])>, 'user_id': <tf.Tensor: shape=(52, 1), dtype=int64, numpy=
E array([[53],
E [24],
E [45],
E [14],
E [12],
E [44],
E [39],
E [32],
E [ 3],
E [57],
E [53],
E [53],
E [53],
E [53],
E [24],
E [24],
E [24],
E [24],
E [24],
E [45],
E [45],
E [45],
E [45],
E [45],
E [14],
E [14],
E [14],
E [12],
E [12],
E [12],
E [12],
E [12],
E [44],
E [44],
E [44],
E [39],
E [39],
E [39],
E [32],
E [32],
E [32],
E [32],
E [32],
E [ 3],
E [ 3],
E [ 3],
E [ 3],
E [ 3],
E [57],
E [57],
E [57],
E [57]])>, 'country': <tf.Tensor: shape=(52, 1), dtype=int64, numpy=
E array([[21],
E [10],
E [18],
E [ 6],
E [ 5],
E [18],
E [16],
E [13],
E [ 1],
E [23],
E [21],
E [21],
E [21],
E [21],
E [10],
E [10],
E [10],
E [10],
E [10],
E [18],
E [18],
E [18],
E [18],
E [18],
E [ 6],
E [ 6],
E [ 6],
E [ 5],
E [ 5],
E [ 5],
E [ 5],
E [ 5],
E [18],
E [18],
E [18],
E [16],
E [16],
E [16],
E [13],
E [13],
E [13],
E [13],
E [13],
E [ 1],
E [ 1],
E [ 1],
E [ 1],
E [ 1],
E [23],
E [23],
E [23],
E [23]])>, 'item_recency': <tf.Tensor: shape=(52, 1), dtype=float32, numpy=
E array([[0.43368974],
E [0.29481202],
E [0.713164 ],
E [0.02824262],
E [0.8301565 ],
E [0.3268307 ],
E [0.9639159 ],
E [0.35253257],
E [0.19819923],
E [0.6538402 ],
E [0.35253257],
E [0.19819923],
E [0.6538402 ],
E [0.02824262],
E [0.3268307 ],
E [0.713164 ],
E [0.9639159 ],
E [0.3268307 ],
E [0.02824262],
E [0.35253257],
E [0.8301565 ],
E [0.35253257],
E [0.43368974],
E [0.8301565 ],
E [0.8301565 ],
E [0.9639159 ],
E [0.29481202],
E [0.19819923],
E [0.02824262],
E [0.6538402 ],
E [0.713164 ],
E [0.6538402 ],
E [0.6538402 ],
E [0.29481202],
E [0.6538402 ],
E [0.02824262],
E [0.8301565 ],
E [0.3268307 ],
E [0.9639159 ],
E [0.19819923],
E [0.6538402 ],
E [0.43368974],
E [0.9639159 ],
E [0.6538402 ],
E [0.29481202],
E [0.8301565 ],
E [0.3268307 ],
E [0.3268307 ],
E [0.29481202],
E [0.43368974],
E [0.19819923],
E [0.29481202]], dtype=float32)>, 'user_age': <tf.Tensor: shape=(52, 1), dtype=float32, numpy=
E array([[11.],
E [ 5.],
E [ 9.],
E [ 3.],
E [ 3.],
E [ 9.],
E [ 8.],
E [ 7.],
E [ 1.],
E [11.],
E [11.],
E [11.],
E [11.],
E [11.],
E [ 5.],
E [ 5.],
E [ 5.],
E [ 5.],
E [ 5.],
E [ 9.],
E [ 9.],
E [ 9.],
E [ 9.],
E [ 9.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 3.],
E [ 9.],
E [ 9.],
E [ 9.],
E [ 8.],
E [ 8.],
E [ 8.],
E [ 7.],
E [ 7.],
E [ 7.],
E [ 7.],
E [ 7.],
E [ 1.],
E [ 1.],
E [ 1.],
E [ 1.],
E [ 1.],
E [11.],
E [11.],
E [11.],
E [11.]], dtype=float32)>, 'position': <tf.Tensor: shape=(52, 1), dtype=float32, numpy=
E array([[60.],
E [37.],
E [26.],
E [99.],
E [98.],
E [44.],
E [79.],
E [ 7.],
E [67.],
E [41.],
E [60.],
E [60.],
E [60.],
E [60.],
E [37.],
E [37.],
E [37.],
E [37.],
E [37.],
E [26.],
E [26.],
E [26.],
E [26.],
E [26.],
E [99.],
E [99.],
E [99.],
E [98.],
E [98.],
E [98.],
E [98.],
E [98.],
E [44.],
E [44.],
E [44.],
E [79.],
E [79.],
E [79.],
E [ 7.],
E [ 7.],
E [ 7.],
E [ 7.],
E [ 7.],
E [67.],
E [67.],
E [67.],
E [67.],
E [67.],
E [41.],
E [41.],
E [41.],
E [41.]], dtype=float32)>}, targets={'click': <tf.Tensor: shape=(52, 1), dtype=float64, numpy=
E array([[0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [1.],
E [1.],
E [0.],
E [1.],
E [1.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.]])>, 'like': <tf.Tensor: shape=(52, 1), dtype=float64, numpy=
E array([[1.],
E [1.],
E [1.],
E [1.],
E [1.],
E [1.],
E [0.],
E [0.],
E [1.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.],
E [0.]])>, 'play_percentage': <tf.Tensor: shape=(52, 1), dtype=float64, numpy=
E array([[0.76568612],
E [0.97869382],
E [0.81352658],
E [0.42749209],
E [0.55828658],
E [0.02428143],
E [0.30800947],
E [0.22286389],
E [0.99614144],
E [0.37629418],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ],
E [0. ]])>}, sample_weight=None, features=None)
/usr/local/lib/python3.8/dist-packages/keras/engine/data_adapter.py:1579: ValueError
=============================== warnings summary ===============================
../../../../../usr/lib/python3/dist-packages/requests/__init__.py:89
/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.11) or chardet (3.0.4) doesn't match a supported version!
warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:36
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:36: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
'nearest': pil_image.NEAREST,
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:37
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:37: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
'bilinear': pil_image.BILINEAR,
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:38
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:38: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
'bicubic': pil_image.BICUBIC,
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:39
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:39: DeprecationWarning: HAMMING is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.HAMMING instead.
'hamming': pil_image.HAMMING,
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:40
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:40: DeprecationWarning: BOX is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BOX instead.
'box': pil_image.BOX,
../../../../../usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:41
/usr/local/lib/python3.8/dist-packages/keras/utils/image_utils.py:41: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
'lanczos': pil_image.LANCZOS,
tests/unit/datasets/test_advertising.py: 1 warning
tests/unit/datasets/test_ecommerce.py: 2 warnings
tests/unit/datasets/test_entertainment.py: 4 warnings
tests/unit/datasets/test_social.py: 1 warning
tests/unit/datasets/test_synthetic.py: 6 warnings
tests/unit/implicit/test_implicit.py: 1 warning
tests/unit/lightfm/test_lightfm.py: 1 warning
tests/unit/tf/test_core.py: 6 warnings
tests/unit/tf/test_dataset.py: 1 warning
tests/unit/tf/blocks/test_cross.py: 5 warnings
tests/unit/tf/blocks/test_dlrm.py: 9 warnings
tests/unit/tf/blocks/test_mlp.py: 26 warnings
tests/unit/tf/blocks/test_optimizer.py: 30 warnings
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py: 2 warnings
tests/unit/tf/blocks/retrieval/test_two_tower.py: 10 warnings
tests/unit/tf/core/test_aggregation.py: 6 warnings
tests/unit/tf/core/test_base.py: 2 warnings
tests/unit/tf/core/test_combinators.py: 10 warnings
tests/unit/tf/core/test_index.py: 8 warnings
tests/unit/tf/core/test_prediction.py: 2 warnings
tests/unit/tf/core/test_transformations.py: 13 warnings
tests/unit/tf/data_augmentation/test_negative_sampling.py: 10 warnings
tests/unit/tf/data_augmentation/test_noise.py: 1 warning
tests/unit/tf/inputs/test_continuous.py: 4 warnings
tests/unit/tf/inputs/test_embedding.py: 19 warnings
tests/unit/tf/inputs/test_tabular.py: 18 warnings
tests/unit/tf/models/test_base.py: 17 warnings
tests/unit/tf/models/test_benchmark.py: 2 warnings
tests/unit/tf/models/test_ranking.py: 34 warnings
tests/unit/tf/models/test_retrieval.py: 60 warnings
tests/unit/tf/prediction_tasks/test_classification.py: 2 warnings
tests/unit/tf/prediction_tasks/test_multi_task.py: 16 warnings
tests/unit/tf/prediction_tasks/test_regression.py: 2 warnings
tests/unit/tf/prediction_tasks/test_retrieval.py: 1 warning
tests/unit/tf/predictions/test_base.py: 5 warnings
tests/unit/tf/predictions/test_classification.py: 7 warnings
tests/unit/tf/predictions/test_dot_product.py: 8 warnings
tests/unit/tf/predictions/test_regression.py: 2 warnings
tests/unit/tf/utils/test_batch.py: 9 warnings
tests/unit/torch/block/test_base.py: 4 warnings
tests/unit/torch/block/test_mlp.py: 1 warning
tests/unit/torch/features/test_continuous.py: 1 warning
tests/unit/torch/features/test_embedding.py: 4 warnings
tests/unit/torch/features/test_tabular.py: 4 warnings
tests/unit/torch/model/test_head.py: 12 warnings
tests/unit/torch/model/test_model.py: 2 warnings
tests/unit/torch/tabular/test_aggregation.py: 6 warnings
tests/unit/torch/tabular/test_transformations.py: 3 warnings
tests/unit/xgb/test_xgboost.py: 18 warnings
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].
warnings.warn(
tests/unit/datasets/test_ecommerce.py: 2 warnings
tests/unit/datasets/test_entertainment.py: 4 warnings
tests/unit/datasets/test_social.py: 1 warning
tests/unit/datasets/test_synthetic.py: 5 warnings
tests/unit/implicit/test_implicit.py: 1 warning
tests/unit/lightfm/test_lightfm.py: 1 warning
tests/unit/tf/test_core.py: 6 warnings
tests/unit/tf/test_dataset.py: 1 warning
tests/unit/tf/blocks/test_cross.py: 5 warnings
tests/unit/tf/blocks/test_dlrm.py: 9 warnings
tests/unit/tf/blocks/test_mlp.py: 26 warnings
tests/unit/tf/blocks/test_optimizer.py: 30 warnings
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py: 2 warnings
tests/unit/tf/blocks/retrieval/test_two_tower.py: 10 warnings
tests/unit/tf/core/test_aggregation.py: 6 warnings
tests/unit/tf/core/test_base.py: 2 warnings
tests/unit/tf/core/test_combinators.py: 10 warnings
tests/unit/tf/core/test_index.py: 3 warnings
tests/unit/tf/core/test_prediction.py: 2 warnings
tests/unit/tf/core/test_transformations.py: 10 warnings
tests/unit/tf/data_augmentation/test_negative_sampling.py: 10 warnings
tests/unit/tf/inputs/test_continuous.py: 4 warnings
tests/unit/tf/inputs/test_embedding.py: 19 warnings
tests/unit/tf/inputs/test_tabular.py: 18 warnings
tests/unit/tf/models/test_base.py: 17 warnings
tests/unit/tf/models/test_benchmark.py: 2 warnings
tests/unit/tf/models/test_ranking.py: 32 warnings
tests/unit/tf/models/test_retrieval.py: 32 warnings
tests/unit/tf/prediction_tasks/test_classification.py: 2 warnings
tests/unit/tf/prediction_tasks/test_multi_task.py: 16 warnings
tests/unit/tf/prediction_tasks/test_regression.py: 2 warnings
tests/unit/tf/predictions/test_base.py: 5 warnings
tests/unit/tf/predictions/test_classification.py: 7 warnings
tests/unit/tf/predictions/test_dot_product.py: 8 warnings
tests/unit/tf/predictions/test_regression.py: 2 warnings
tests/unit/tf/utils/test_batch.py: 7 warnings
tests/unit/torch/block/test_base.py: 4 warnings
tests/unit/torch/block/test_mlp.py: 1 warning
tests/unit/torch/features/test_continuous.py: 1 warning
tests/unit/torch/features/test_embedding.py: 4 warnings
tests/unit/torch/features/test_tabular.py: 4 warnings
tests/unit/torch/model/test_head.py: 12 warnings
tests/unit/torch/model/test_model.py: 2 warnings
tests/unit/torch/tabular/test_aggregation.py: 6 warnings
tests/unit/torch/tabular/test_transformations.py: 2 warnings
tests/unit/xgb/test_xgboost.py: 17 warnings
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.USER: 'user'>, <Tags.ID: 'id'>].
warnings.warn(
tests/unit/datasets/test_ecommerce.py::test_synthetic_aliccp_raw_data
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-True-10]
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-True-9]
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-True-8]
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-False-10]
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-False-9]
tests/unit/tf/test_dataset.py::test_tf_drp_reset[100-False-8]
tests/unit/tf/test_dataset.py::test_tf_catname_ordering
tests/unit/tf/test_dataset.py::test_tf_map
/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.
warnings.warn(
tests/unit/datasets/test_entertainment.py: 1 warning
tests/unit/implicit/test_implicit.py: 1 warning
tests/unit/lightfm/test_lightfm.py: 1 warning
tests/unit/tf/test_dataset.py: 1 warning
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py: 2 warnings
tests/unit/tf/blocks/retrieval/test_two_tower.py: 2 warnings
tests/unit/tf/core/test_combinators.py: 10 warnings
tests/unit/tf/core/test_prediction.py: 1 warning
tests/unit/tf/data_augmentation/test_negative_sampling.py: 9 warnings
tests/unit/tf/inputs/test_continuous.py: 2 warnings
tests/unit/tf/inputs/test_embedding.py: 9 warnings
tests/unit/tf/inputs/test_tabular.py: 8 warnings
tests/unit/tf/models/test_ranking.py: 16 warnings
tests/unit/tf/models/test_retrieval.py: 4 warnings
tests/unit/tf/prediction_tasks/test_multi_task.py: 16 warnings
tests/unit/xgb/test_xgboost.py: 12 warnings
/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.SESSION_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.SESSION: 'session'>, <Tags.ID: 'id'>].
warnings.warn(
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py::test_matrix_factorization_embedding_export
tests/unit/tf/blocks/retrieval/test_matrix_factorization.py::test_matrix_factorization_embedding_export
tests/unit/tf/blocks/retrieval/test_two_tower.py::test_matrix_factorization_embedding_export
tests/unit/tf/blocks/retrieval/test_two_tower.py::test_matrix_factorization_embedding_export
tests/unit/tf/inputs/test_embedding.py::test_embedding_features_exporting_and_loading_pretrained_initializer
/var/jenkins_home/workspace/merlin_models/models/merlin/models/tf/inputs/embedding.py:807: DeprecationWarning: This function is deprecated in favor of cupy.from_dlpack
embeddings_cupy = cupy.fromDlpack(to_dlpack(tf.convert_to_tensor(embeddings)))
tests/unit/tf/core/test_index.py: 4 warnings
tests/unit/tf/models/test_retrieval.py: 54 warnings
tests/unit/tf/prediction_tasks/test_next_item.py: 3 warnings
tests/unit/tf/predictions/test_classification.py: 12 warnings
tests/unit/tf/predictions/test_dot_product.py: 2 warnings
tests/unit/tf/utils/test_batch.py: 2 warnings
/tmp/__autograph_generated_filea9yt19zl.py:8: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
ag__.converted_call(ag__.ld(warnings).warn, ("The 'warn' method is deprecated, use 'warning' instead", ag__.ld(DeprecationWarning), 2), None, fscope)
tests/unit/tf/data_augmentation/test_noise.py::test_stochastic_swap_noise[0.1]
tests/unit/tf/data_augmentation/test_noise.py::test_stochastic_swap_noise[0.3]
tests/unit/tf/data_augmentation/test_noise.py::test_stochastic_swap_noise[0.5]
tests/unit/tf/data_augmentation/test_noise.py::test_stochastic_swap_noise[0.7]
tests/unit/tf/models/test_base.py::test_model_pre_post[True]
tests/unit/tf/models/test_base.py::test_model_pre_post[False]
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: tf.keras.backend.random_binomial is deprecated, and will be removed in a future version.Please use tf.keras.backend.random_bernoulli instead.
return dispatch_target(*args, **kwargs)
tests/unit/tf/models/test_base.py::test_freeze_parallel_block[True]
tests/unit/tf/models/test_base.py::test_freeze_sequential_block
tests/unit/tf/models/test_base.py::test_freeze_unfreeze
tests/unit/tf/models/test_base.py::test_unfreeze_all_blocks
/usr/local/lib/python3.8/dist-packages/keras/optimizers/optimizer_v2/gradient_descent.py:108: UserWarning: The lr argument is deprecated, use learning_rate instead.
super(SGD, self).__init__(name, **kwargs)
tests/unit/tf/models/test_ranking.py::test_wide_deep_model_wide_categorical_one_hot[False]
tests/unit/tf/models/test_ranking.py::test_wide_deep_model_hashed_cross[False]
tests/unit/tf/models/test_ranking.py::test_wide_deep_embedding_custom_inputblock[False]
/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model/parallel_block_2/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model/parallel_block_2/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Reshape:0", shape=(None, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model/parallel_block_2/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
warnings.warn(
tests/unit/tf/models/test_ranking.py::test_wide_deep_embedding_custom_inputblock[True]
tests/unit/tf/models/test_ranking.py::test_wide_deep_embedding_custom_inputblock[False]
/var/jenkins_home/workspace/merlin_models/models/merlin/models/tf/core/transformations.py:980: UserWarning: Please make sure input features to be categorical, detect user_age has no categorical tag
warnings.warn(
tests/unit/tf/models/test_ranking.py::test_wide_deep_embedding_custom_inputblock[False]
/usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/impl/api.py:371: UserWarning: Please make sure input features to be categorical, detect user_age has no categorical tag
return py_builtins.overload_of(f)(*args)
tests/unit/tf/models/test_ranking.py::test_wide_deep_model_wide_onehot_multihot_feature_interaction[False]
/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model/parallel_block_5/sequential_block_9/sequential_block_8/private__dense_3/dense_3/embedding_lookup_sparse/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model/parallel_block_5/sequential_block_9/sequential_block_8/private__dense_3/dense_3/embedding_lookup_sparse/Reshape:0", shape=(None, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model/parallel_block_5/sequential_block_9/sequential_block_8/private__dense_3/dense_3/embedding_lookup_sparse/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
warnings.warn(
tests/unit/tf/models/test_ranking.py::test_wide_deep_model_wide_feature_interaction_multi_optimizer[False]
/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model/parallel_block_4/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model/parallel_block_4/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Reshape:0", shape=(None, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model/parallel_block_4/sequential_block_6/sequential_block_5/private__dense_3/dense_3/embedding_lookup_sparse/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
warnings.warn(
tests/unit/torch/block/test_mlp.py::test_mlp_block
/var/jenkins_home/workspace/merlin_models/models/tests/unit/torch/_conftest.py:151: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:201.)
return {key: torch.tensor(value) for key, value in data.items()}
tests/unit/xgb/test_xgboost.py::test_without_dask_client
tests/unit/xgb/test_xgboost.py::TestXGBoost::test_music_regression
tests/unit/xgb/test_xgboost.py::test_gpu_hist_dmatrix[fit_kwargs0-DaskDeviceQuantileDMatrix]
tests/unit/xgb/test_xgboost.py::test_gpu_hist_dmatrix[fit_kwargs1-DaskDMatrix]
tests/unit/xgb/test_xgboost.py::TestEvals::test_multiple
tests/unit/xgb/test_xgboost.py::TestEvals::test_default
tests/unit/xgb/test_xgboost.py::TestEvals::test_train_and_valid
tests/unit/xgb/test_xgboost.py::TestEvals::test_invalid_data
/var/jenkins_home/workspace/merlin_models/models/merlin/models/xgb/__init__.py:335: UserWarning: Ignoring list columns as inputs to XGBoost model: ['item_genres', 'user_genres'].
warnings.warn(f"Ignoring list columns as inputs to XGBoost model: {list_column_names}.")
tests/unit/xgb/test_xgboost.py::TestXGBoost::test_unsupported_objective
/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py:350: DeprecationWarning: make_current is deprecated; start the event loop first
self.make_current()
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
SKIPPED [1] tests/unit/datasets/test_advertising.py:20: No data-dir available, pass it through env variable $INPUT_DATA_DIR
SKIPPED [1] tests/unit/datasets/test_ecommerce.py:62: ALI-CCP data is not available, pass it through env variable $DATA_PATH_ALICCP
SKIPPED [1] tests/unit/datasets/test_ecommerce.py:78: ALI-CCP data is not available, pass it through env variable $DATA_PATH_ALICCP
SKIPPED [1] tests/unit/datasets/test_ecommerce.py:92: ALI-CCP data is not available, pass it through env variable $DATA_PATH_ALICCP
SKIPPED [3] tests/unit/datasets/test_entertainment.py:44: No data-dir available, pass it through env variable $INPUT_DATA_DIR
SKIPPED [4] ../../../../../usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/test_util.py:2746: Not a test.
===== 1 failed, 666 passed, 11 skipped, 1011 warnings in 940.22s (0:15:40) =====
Build step 'Execute shell' marked build as failure
Performing Post build task...
Match found for : : True
Logical operation result is TRUE
Running script : #!/bin/bash
cd /var/jenkins_home/
CUDA_VISIBLE_DEVICES=1 python test_res_push.py "https://api.GitHub.com/repos/NVIDIA-Merlin/models/issues/$ghprbPullId/comments" "/var/jenkins_home/jobs/$JOB_NAME/builds/$BUILD_NUMBER/log"
[merlin_models] $ /bin/bash /tmp/jenkins2220967673859041873.sh
This approach doesn't work with the Keras data adapter implementation.
If we want to remove the awkward `return_tuple` param and have a single layer that works both inside the model and in the data loader, a few options are:
- Conditionally change the return type of a data augmentation layer based on the training/testing context (#703).
- Make the `BatchedDataset` aware of this potential return type and unpack the values after the transform is called.
- For a data augmentation layer that changes both inputs and targets, return a 2-tuple of `(transformed_inputs, transformed_targets)` instead, and update the model base implementation to handle this case.
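The unpacking options listed above can be sketched in plain Python, without TensorFlow, to show how a loader or train/test step could normalize the different return types. This is a minimal sketch under several assumptions: `Prediction` here is a hypothetical stand-in namedtuple with fields `(outputs, targets)`, not Merlin's actual class; `unpack_x_y_sample_weight` is a simplified re-implementation of the Keras helper's tuple-length logic, not the real function; and `unpack_with_prediction`, `apply_transform` and `add_negative` are hypothetical names used only for illustration.

```python
from collections import namedtuple
from typing import Any, Callable, Optional, Tuple

# Hypothetical stand-in for Merlin's `Prediction` namedtuple (fields assumed).
Prediction = namedtuple("Prediction", ["outputs", "targets"])


def unpack_x_y_sample_weight(data: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Simplified stand-in for Keras' unpack_x_y_sample_weight (length-based)."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(f"Expected (x,), (x, y) or (x, y, sample_weight), got {data!r}")


def unpack_with_prediction(data: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Wrap the generic unpacker so a Prediction is handled explicitly."""
    # Read fields by name instead of relying on positional tuple length,
    # so the result stays correct even if Prediction gains extra fields.
    if isinstance(data, Prediction):
        return data.outputs, data.targets, None
    return unpack_x_y_sample_weight(data)


def apply_transform(transform: Callable, inputs: Any, targets: Any) -> Tuple[Any, Any]:
    """Hypothetical loader hook: unpack whatever the transform returned."""
    result = transform(inputs, targets)
    if isinstance(result, Prediction):
        return result.outputs, result.targets
    if isinstance(result, tuple) and len(result) == 2:
        # Plain 2-tuple of (transformed_inputs, transformed_targets).
        return result[0], result[1]
    # The transform only changed the inputs; keep the original targets.
    return result, targets


def add_negative(inputs: dict, targets: dict) -> Prediction:
    """Toy 'negative sampling': append item 0 with a label of 0."""
    return Prediction(
        {"item_id": inputs["item_id"] + [0]},
        {"click": targets["click"] + [0]},
    )
```

For example, `apply_transform(add_negative, {"item_id": [3]}, {"click": [1]})` yields `({"item_id": [3, 0]}, {"click": [1, 0]})`. Checking the return type once, right after the transform, keeps the model's train/test steps agnostic to whether a layer emitted a `Prediction` or a plain tuple.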