keras
`CompileLoss` truncates multiple loss inputs
Environment
System: Apple M2 MacOS Sonoma 14.4.1
python: 3.11.9
tensorflow: 2.16.1
keras: 3.3.3
Issue
Upon first call of model.compute_loss, CompileLoss.call is triggered, which fails if y_true is e.g. a tuple (y1, y2) and the loss function expects y_true: tuple[tf.Tensor, tf.Tensor], as only the first element y1 gets passed to the loss function.
Details
Let's say I have a loss function
def my_loss(y_true: tuple[tf.Tensor, tf.Tensor], y_pred: tf.Tensor)
i.e. the ground truth consists of two tensors (I can't stack them, their size is incompatible) and the prediction from my model is a single tensor.
Let's also say I have a model that takes two inputs x1 and x2 and produces one output y_pred.
Due to the nature of my problem, the input is automatically the ground truth, i.e. y_true=(x1, x2).
So I would like to execute
# [...] model and data creation
model.compile(loss=my_loss)
y_pred = model((x1, x2))
loss = model.compute_loss(y=(x1, x2), y_pred=y_pred)
where the last call causes an exception in my_loss, because only x1 gets passed as y_true to my_loss.
I have pinpointed the problem to the call y_true = self._flatten_y(y_true) in CompileLoss.call and the subsequent zip iteration.
The problem seems to be that the zip iteration expects y_true and y_pred to be iterables of the same length as self.flat_losses (which in my case is just 1). Now y_pred = self._flatten_y(y_pred) wraps the single tensor y_pred into a single element list, which is correct. However, y_true = self._flatten_y(y_true) converts y_true from a 2 element tuple to a 2 element list, where it should be a nested list with a single 2-element list [[y1, y2]].
Consequently the zip iteration takes only y1 from y_true in its single iteration (since all other iterables are just length one) and passes it as y_true argument to my_loss.
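The truncation described above can be reproduced without Keras. The following is a minimal sketch, assuming a simplified stand-in for the flattening step: the helpers `flatten`, `dispatch` and `record_args` are hypothetical names for illustration, not Keras code, and plain strings stand in for tensors.

```python
def flatten(structure):
    """Hypothetical stand-in for _flatten_y: flatten nested tuples/lists into a flat list."""
    if isinstance(structure, (list, tuple)):
        flat = []
        for item in structure:
            flat.extend(flatten(item))
        return flat
    return [structure]


def dispatch(flat_losses, y_true, y_pred):
    """Pair each loss with one flattened y_true / y_pred entry, as a plain zip would."""
    y_true = flatten(y_true)
    y_pred = flatten(y_pred)
    # zip stops at the shortest iterable, so surplus y_true entries are silently dropped.
    return [loss(y_t, y_p) for loss, y_t, y_p in zip(flat_losses, y_true, y_pred)]


def record_args(y_true, y_pred):
    """Toy loss that just reports what it received."""
    return (y_true, y_pred)


received = dispatch([record_args], y_true=("y1", "y2"), y_pred="y_pred")
print(received)  # [('y1', 'y_pred')]
```

With one loss and a two-element y_true, the single zip iteration hands only the first element to the loss, which matches the behavior reported here.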
I imagine this behavior comes from the fact that in cases where one has multiple losses (which take single tensors as y_true and y_pred), the correct way of calling compute_loss is to pass sequences for y_true and y_pred to compute_loss, one for each loss function.
Isn't there a way to reconcile both cases? I.e. sequence inputs to single loss functions, but still supporting multiple loss functions? All the information is there, i.e. how many loss functions and what are their signatures. For starters, it is not checked if all elements in the zip iteration are of same length...
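As a rough illustration of the missing length check, here is a sketch of a pairing loop that fails loudly instead of truncating. `dispatch_checked` and `toy_loss` are hypothetical names, the inputs are assumed to be already flattened lists, and this is not a proposal for how Keras should implement it.

```python
def dispatch_checked(flat_losses, flat_y_true, flat_y_pred):
    """Hypothetical pairing loop that refuses inputs of mismatched length."""
    n = len(flat_losses)
    if len(flat_y_true) != n or len(flat_y_pred) != n:
        # Checked up front, so no loss is called on partially matched inputs.
        raise ValueError(
            f"Expected {n} y_true and {n} y_pred entries (one per loss), "
            f"got {len(flat_y_true)} and {len(flat_y_pred)}."
        )
    return [
        loss(y_t, y_p)
        for loss, y_t, y_p in zip(flat_losses, flat_y_true, flat_y_pred)
    ]


def toy_loss(y_true, y_pred):
    """Toy loss that just reports what it received."""
    return (y_true, y_pred)


error_message = None
try:
    # One loss, but two flattened y_true entries: rejected instead of truncated.
    dispatch_checked([toy_loss], ["y1", "y2"], ["y_pred"])
except ValueError as err:
    error_message = str(err)
    print("length mismatch detected:", error_message)
```

A check like this would not resolve the ambiguity between "one loss taking a tuple" and "several losses taking one tensor each", but it would turn the silent truncation into an explicit error.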
Reproducible Example
Here is a reproducible minimal example
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras


def my_loss(y_true: tuple[tf.Tensor, tf.Tensor], y_pred: tf.Tensor):
    y1, y2 = y_true
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(keras.ops.sum(y2) - pred_sum)


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))

    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)

    model = keras.Model(inputs=[input1, input2], outputs=out)
    model.compile(loss=my_loss)

    x1 = tf.random.uniform((10, 2))
    x2 = tf.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))

    loss1 = my_loss((x1, x2), y_pred)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
    print(loss1)
    print(loss2)


if __name__ == "__main__":
    main()
which produces the following error
/Users/valentin/miniconda3/envs/dif/bin/python /Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py
Traceback (most recent call last):
File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 34, in <module>
main()
File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 27, in main
loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/trainer.py", line 316, in compute_loss
loss = self._compile_loss(y, y_pred, sample_weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/compile_utils.py", line 609, in __call__
return self.call(y_true, y_pred, sample_weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/compile_utils.py", line 645, in call
loss(y_t, y_p, sample_weight), dtype=backend.floatx()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/losses/loss.py", line 43, in __call__
losses = self.call(y_true, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/losses/losses.py", line 22, in call
return self.fn(y_true, y_pred, **self._fn_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 7, in my_loss
y1, y2 = y_true
^^^^^^
Many thanks for all the great effort on keras, I greatly appreciate all the awesome features!
The same problem persists if you use dict inputs btw.
def my_loss(y_true: dict[str, tf.Tensor], y_pred: tf.Tensor):
    y1 = y_true["x1"]
    y2 = y_true["x2"]
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(keras.ops.sum(y2) - pred_sum)
...
model.compute_loss(y={"x1": x1, "x2": x2}, y_pred=y_pred)
...
...
There it's even worse, as the dict gets converted to a (sorted) list by y_true = self._flatten_y(y_true). So if you coded your loss expecting a dict input, it now receives a list, which will again lead to problems.
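To make the dict case concrete, here is a small sketch of what the loss ends up receiving. It assumes the flattening orders dict values by sorted key, as described above; `flatten_dict_sorted` is a hypothetical stand-in rather than the Keras helper, and strings stand in for tensors.

```python
def flatten_dict_sorted(structure):
    """Hypothetical stand-in: flatten a dict to a list of values ordered by sorted key."""
    return [structure[key] for key in sorted(structure)]


y_true = {"x2": "tensor_x2", "x1": "tensor_x1"}
flat = flatten_dict_sorted(y_true)
print(flat)  # ['tensor_x1', 'tensor_x2']

lookup_error = None
try:
    # A loss written for dict input would do this and fail on the flattened list.
    flat["x1"]
except TypeError as err:
    lookup_error = err
    print("lookup by key fails:", err)
```

The keys are gone after flattening, so a loss that indexes y_true by name can no longer find its inputs.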
Hi @Darkdragon84
Happy to hear your thoughts on #19879
With that PR and a few small changes, your code works:
import keras


class MyLoss(keras.Loss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, y_true, y_pred):
        y1, y2 = y_true
        pred_sum = keras.ops.sum(y_pred)
        return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(
            keras.ops.sum(y2) - pred_sum
        )


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))

    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)

    model = keras.Model(inputs=[input1, input2], outputs=out)
    my_loss = MyLoss()
    my_loss.set_specs([input1, input2], out)  # <-- newly introduced feature
    model.compile(loss=my_loss)

    x1 = keras.random.uniform((10, 2))
    x2 = keras.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))

    loss1 = my_loss((x1, x2), y_pred)
    print(loss1)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
    print(loss2)


if __name__ == "__main__":
    main()
tf.Tensor(68.81332, shape=(), dtype=float32)
tf.Tensor(68.81332, shape=(), dtype=float32)
Hi @james77777778
Wow, fantastic, that PR looks great. The changes are way above my understanding of the inner workings of Keras, but I'll sure have a look! Thanks a lot for the quick response in form of a fix PR :clap:
This is fixed with the new structured loss feature provided that y_true and y_pred have the same structure:
import numpy as np
import tensorflow as tf
import keras


def my_loss(y_true: tuple[tf.Tensor, tf.Tensor], y_pred: tuple[tf.Tensor, tf.Tensor]):
    y1, y2 = y_true
    y_pred, _ = y_pred
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(keras.ops.sum(y2) - pred_sum)


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))

    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)

    model = keras.Model(inputs=[input1, input2], outputs=(out, out))
    model.compile(loss=[my_loss])

    x1 = tf.random.uniform((10, 2))
    x2 = tf.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))

    loss1 = my_loss((x1, x2), y_pred)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
    print(loss1)
    print(loss2)


if __name__ == "__main__":
    main()
74.80203
74.80203
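For readers wondering what "the same structure" means in practice, the sketch below compares two nested containers leaf by leaf. `same_structure` is a hypothetical helper written for illustration, not the check Keras performs; it assumes lists and tuples count as interchangeable, and strings stand in for tensors.

```python
def same_structure(a, b):
    """Hypothetical check: do a and b nest lists/tuples/dicts in the same way?"""
    if isinstance(a, (list, tuple)):
        return (
            isinstance(b, (list, tuple))
            and len(a) == len(b)
            and all(same_structure(x, y) for x, y in zip(a, b))
        )
    if isinstance(a, dict):
        return (
            isinstance(b, dict)
            and a.keys() == b.keys()
            and all(same_structure(a[k], b[k]) for k in a)
        )
    # a is a leaf (a tensor stand-in), so b must be a leaf as well.
    return not isinstance(b, (list, tuple, dict))


# The example above: two ground-truth tensors against outputs=(out, out).
print(same_structure(("x1", "x2"), ("out", "out")))  # True
# The original report: two ground-truth tensors against a single output.
print(same_structure(("x1", "x2"), "out"))  # False
```

This is why the example above duplicates the output as (out, out): it makes the prediction structure match the two-element ground truth.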
awesome, thanks @nicolaspi :raised_hands: