keras-io
keras-io copied to clipboard
Migrating supervised contrastive learning example to Keras 3
trafficstars
This PR migrates the Supervised Contrastive Learning example to Keras 3.0 (as a TensorFlow-only example), as requested in https://github.com/keras-team/keras-cv/issues/2211.
A Colab notebook running the migrated example is available here: https://colab.research.google.com/drive/1QGiil-RpO55UNESBkilNtYETuX3B_4ZF?usp=sharing
cc: @divyashreepathihalli @fchollet
The Git diff for the changed files follows:
Changes:
diff --git a/examples/vision/supervised-contrastive-learning.py b/examples/vision/supervised-contrastive-learning.py
index 4803e671..6b45d568 100644
--- a/examples/vision/supervised-contrastive-learning.py
+++ b/examples/vision/supervised-contrastive-learning.py
@@ -20,22 +20,20 @@ Learning is performed in two phases:
that representations of images in the same class will be more similar compared to
representations of images in different classes.
2. Training a classifier on top of the frozen encoder.
+"""
-Note that this example requires [TensorFlow Addons](https://www.tensorflow.org/addons),
-which you can install using the following command:
-
-```python
-pip install tensorflow-addons
-```
-
+"""
## Setup
"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-import tensorflow_addons as tfa
import numpy as np
-from tensorflow import keras
-from tensorflow.keras import layers
"""
## Prepare the data
@@ -159,6 +157,19 @@ softmax are optimized.
"""
def npairs_loss(y_true, y_pred):
    """Compute the N-pairs loss (Keras 3 port of `tfa.losses.npairs_loss`).

    Each row of `y_pred` is treated as the logits of a softmax over the batch.
    The target distribution for row i spreads probability mass uniformly over
    every sample j whose label equals sample i's label (including i itself).

    Args:
        y_true: 1-D tensor of class labels, one per sample.
        y_pred: `[batch_size, batch_size]` tensor of pairwise similarity
            logits (its shape must match the same-label mask built below).

    Returns:
        Scalar tensor: the mean softmax cross-entropy over the batch.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # Cast labels to the logits' floating dtype. An integer (e.g. uint32)
    # pair mask cannot be row-normalized in place without changing dtype,
    # and the cross-entropy below requires labels and logits to share a dtype.
    y_true = ops.cast(y_true, y_pred.dtype)

    # Expand to [batch_size, 1] so comparing against its transpose broadcasts
    # to a [batch_size, batch_size] mask of same-label pairs.
    y_true = ops.expand_dims(y_true, -1)
    y_true = ops.cast(ops.equal(y_true, ops.transpose(y_true)), y_pred.dtype)
    # Normalize each row to sum to 1; the diagonal guarantees a non-zero sum.
    y_true = y_true / ops.sum(y_true, axis=1, keepdims=True)

    # Softmax cross-entropy over the last axis, computed from raw logits.
    loss = keras.losses.categorical_crossentropy(y_true, y_pred, from_logits=True)
    return ops.mean(loss)
+
+
class SupervisedContrastiveLoss(keras.losses.Loss):
def __init__(self, temperature=1, name=None):
super().__init__(name=name)
@@ -174,7 +185,7 @@ class SupervisedContrastiveLoss(keras.losses.Loss):
),
self.temperature,
)
- return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
+ return npairs_loss(ops.squeeze(labels), logits)
def add_projection_head(encoder):
(END)