keras-io icon indicating copy to clipboard operation
keras-io copied to clipboard

Migrating Enhanced Deep Residual Networks for single-image super-resolution example to Keras 3 (TF-Only)

Open sitammeur opened this issue 1 year ago • 3 comments
trafficstars

This PR migrates the Enhanced Deep Residual Networks for single-image super-resolution example to Keras 3.0 (TensorFlow-only backend), as requested in https://github.com/keras-team/keras-cv/issues/2211.

For reference, here is the notebook link for the example: https://colab.research.google.com/drive/1Dab2RwWyZkdJB7fhDU7rfyzqCTvJH3sZ?usp=sharing

cc: @divyashreepathihalli @fchollet

The Git diff for the changed file is shown below:

Changes:
diff --git a/examples/vision/edsr.py b/examples/vision/edsr.py
index c4e38d63..96ef0443 100644
--- a/examples/vision/edsr.py
+++ b/examples/vision/edsr.py
@@ -41,13 +41,17 @@ Comparison Graph:
 ## Imports
 """
 
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
 import numpy as np
-import tensorflow as tf
 import tensorflow_datasets as tfds
 import matplotlib.pyplot as plt
 
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
 
 AUTOTUNE = tf.data.AUTOTUNE
 
@@ -81,7 +85,7 @@ def flip_left_right(lowres_img, highres_img):
     """Flips Images to left and right."""
 
     # Outputs random values from a uniform distribution in between 0 to 1
-    rn = tf.random.uniform(shape=(), maxval=1)
+    rn = keras.random.uniform(shape=(), maxval=1)
     # If rn is less than 0.5 it returns original lowres_img and highres_img
     # If rn is greater than 0.5 it returns flipped image
     return tf.cond(
@@ -98,7 +102,7 @@ def random_rotate(lowres_img, highres_img):
     """Rotates Images by 90 degrees."""
 
     # Outputs random values from uniform distribution in between 0 to 4
-    rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
+    rn = ops.cast(keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32")
     # Here rn signifies number of times the image(s) are rotated by 90 degrees
     return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)
 
@@ -110,14 +114,14 @@ def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
     high resolution images: 96x96
     """
     lowres_crop_size = hr_crop_size // scale  # 96//4=24
-    lowres_img_shape = tf.shape(lowres_img)[:2]  # (height,width)
+    lowres_img_shape = ops.shape(lowres_img)[:2]  # (height,width)
 
-    lowres_width = tf.random.uniform(
-        shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32
-    )
-    lowres_height = tf.random.uniform(
-        shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32
-    )
+    lowres_width = ops.cast(keras.random.uniform(
+        shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
+    ), dtype="int32")
+    lowres_height = ops.cast(keras.random.uniform(
+        shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
+    ), dtype="int32")
 
     highres_width = lowres_width * scale
     highres_height = lowres_height * scale
@@ -218,7 +222,7 @@ memory as the preceding convolutional layers.
 """
 
 
-class EDSRModel(tf.keras.Model):
+class EDSRModel(keras.Model):
     def train_step(self, data):
         # Unpack the data. Its structure depends on your model and
         # on what you pass to `fit()`.
@@ -242,16 +246,16 @@ class EDSRModel(tf.keras.Model):
 
     def predict_step(self, x):
         # Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast
-        x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)
+        x = ops.cast(ops.expand_dims(x, axis=0), dtype="float32")
         # Passing low resolution image to model
         super_resolution_img = self(x, training=False)
         # Clips the tensor from min(0) to max(255)
-        super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
+        super_resolution_img = ops.clip(super_resolution_img, 0, 255)
         # Rounds the values of a tensor to the nearest integer
-        super_resolution_img = tf.round(super_resolution_img)
+        super_resolution_img = ops.round(super_resolution_img)
         # Removes dimensions of size 1 from the shape of a tensor and converting to uint8
-        super_resolution_img = tf.squeeze(
-            tf.cast(super_resolution_img, tf.uint8), axis=0
+        super_resolution_img = ops.squeeze(
+            ops.cast(super_resolution_img, dtype="uint8"), axis=0
         )
         return super_resolution_img
 
@@ -265,12 +269,19 @@ def ResBlock(inputs):
 
 
 # Upsampling Block
-def Upsampling(inputs, factor=2, **kwargs):
-    x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
-    x = tf.nn.depth_to_space(x, block_size=factor)
-    x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
-    x = tf.nn.depth_to_space(x, block_size=factor)
-    return x
+class UpsamplingLayer(layers.Layer):
+    def __init__(self, factor=2, **kwargs):
+        super(UpsamplingLayer, self).__init__(**kwargs)
+        self.factor = factor
+        self.conv1 = layers.Conv2D(64 * (factor ** 2), 3, padding="same")
+        self.conv2 = layers.Conv2D(64 * (factor ** 2), 3, padding="same")
+
+    def call(self, inputs):
+        x = self.conv1(inputs)
+        x = tf.nn.depth_to_space(x, block_size=self.factor)
+        x = self.conv2(x)
+        x = tf.nn.depth_to_space(x, block_size=self.factor)
+        return x
 
 
 def make_model(num_filters, num_of_residual_blocks):
(END)

sitammeur avatar Jan 30 '24 15:01 sitammeur