keras-cv
Add Dice Loss
Closes https://github.com/keras-team/keras-cv/issues/296
P.S. Tested on actual training; it works fine so far for 2D and 3D cases. It now needs some polish.
Thanks. Will take a look. (Do you still want to keep it as a draft, or make it a formal PR?)
@qlzh727 This PR can be reviewed now. Though I will still add some more test cases to dice_test.py, feel free to provide some initial comments. Thanks.
I can see axis=[1,2] in this and some other classes in this repository. I don't know if this has been discussed elsewhere before, but I would like to suggest that we assume all axes contain positional information, except for the batch and channels axes. This way, every single method/procedure would handle multi-dimensional cases (2D, 3D, etc.) transparently. This can be achieved by reshaping the n-d signal into a (batch, positions, channels) one and reducing over positions:
import tensorflow as tf

def some_metric(y_true, y_pred, axis=None):
    if axis is None:
        y_true = squeeze(y_true, dims=3)  # 2d, 3d, 4d, ..., nd -> 3d
        y_pred = squeeze(y_pred, dims=3)  # 2d, 3d, 4d, ..., nd -> 3d
        axis = 1
    # calculate metric result
    # ...
    return tf.reduce_mean(result, axis=axis)

def squeeze(y, dims: int = 3):
    shape = tf.shape(y)
    if dims not in (2, 3):
        raise ValueError(f'Illegal value for parameter dims=`{dims}`. Can only squeeze '
                         'positional signal, resulting in a tensor with rank 2 or 3.')
    new_shape = [shape[0], -1]
    if dims == 3:  # keep channels.
        new_shape += [shape[-1]]
    return tf.reshape(y, new_shape)
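Applied to the Dice loss this PR adds, the flattening idea might look like the following sketch (the function name, argument names, and smoothing term are illustrative, not the PR's actual implementation):

```python
import tensorflow as tf

def dice_loss(y_true, y_pred, smooth=1e-6):
    # Flatten any n-d input to (batch, positions, channels) so the same
    # code handles 2D, 3D, ... segmentation targets transparently.
    y_true = tf.reshape(y_true, [tf.shape(y_true)[0], -1, tf.shape(y_true)[-1]])
    y_pred = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1, tf.shape(y_pred)[-1]])
    # Dice coefficient per sample and channel, reduced over positions (axis 1).
    intersection = tf.reduce_sum(y_true * y_pred, axis=1)
    union = tf.reduce_sum(y_true, axis=1) + tf.reduce_sum(y_pred, axis=1)
    dice = (2.0 * intersection + smooth) / (union + smooth)
    # Average over channels; return one loss value per sample.
    return 1.0 - tf.reduce_mean(dice, axis=-1)
```

With this shape convention, the same function accepts (batch, H, W, C) and (batch, D, H, W, C) inputs without any axis argument.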
Thanks for your patience @innat; we wanted to merge a Loss first before this. I'll revisit this now that we have done so.
This is the oldest open PR. What is its status?
To be honest, it simply hasn't been prioritized because it is not part of one of our flagship workflows (object detection with RetinaNet, classification training in a reusable template with SOTA augmentation, and segmentation-map prediction with model TBD).
One thing to double-check is that the loss reduction matches the convention set forth in the rest of the losses (see SmoothL1 and Focal as examples).
We can revisit this, but at the moment we are focusing our effort on delivering those flagship workflows.
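For reference, the reduction convention in Keras losses is that `call()` returns an unreduced per-sample loss and the `tf.keras.losses.Loss` base class applies the configured `reduction`. A minimal sketch of a Dice loss following that pattern (class name, defaults, and smoothing term are hypothetical, not this PR's code):

```python
import tensorflow as tf

class DiceLoss(tf.keras.losses.Loss):
    """Hypothetical Dice loss following the Keras reduction convention."""

    def __init__(self, smooth=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.smooth = smooth

    def call(self, y_true, y_pred):
        # Return one unreduced loss value per sample; the base class
        # applies the configured `reduction` when the loss is invoked.
        y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), [tf.shape(y_true)[0], -1])
        y_pred = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])
        intersection = tf.reduce_sum(y_true * y_pred, axis=1)
        union = tf.reduce_sum(y_true, axis=1) + tf.reduce_sum(y_pred, axis=1)
        return 1.0 - (2.0 * intersection + self.smooth) / (union + self.smooth)
```

With `model.compile(loss=DiceLoss())`, Keras would then reduce the per-sample values according to the `reduction` setting, matching how the other losses in the repository behave.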
That said, @innat, feel free to rebase and re-request a review. It's almost ready, so I'm happy to prioritize this.
@LukeWood I may not be able to work on it due to a tight schedule. If it's prioritized, someone may need to take it from here. (The same goes for the Jaccard PR.)
Can I take over this one and #449?
Awesome, thanks! I'll be working on these in the upcoming days then. What's left is the documentation, test cases and implementing the suggestions from the comments, right?
Sort of.
@innat Could you share the training script you mentioned if you still have it?
I am not sure which one you're referring to. Could you please be more precise? (I'm on leave, so if I have anything related, I will share it in the upcoming days.)
Oh, I meant this:
P.S. Tested on actual training; it works fine so far for 2D and 3D cases. It now needs some polish.
Sorry to bug you while you're on leave - enjoy your time off! :)
Of course! I'll get an example up and running as soon as I finish the test cases for the other PR
@tanzhenyu added an example run in the new PR #968 based on your training script for DeepLabV3