Migrating Enhanced Deep Residual Networks for single-image super-resolution example to Keras 3 (TF-Only)
This PR migrates the Enhanced Deep Residual Networks for single-image super-resolution example to Keras 3 [TF-only backend], as requested in https://github.com/keras-team/keras-cv/issues/2211.
Here is a Colab notebook demonstrating the migrated example: https://colab.research.google.com/drive/1Dab2RwWyZkdJB7fhDU7rfyzqCTvJH3sZ?usp=sharing
cc: @divyashreepathihalli @fchollet
The git diff for the changed files follows:
diff --git a/examples/vision/edsr.py b/examples/vision/edsr.py
index c4e38d63..96ef0443 100644
--- a/examples/vision/edsr.py
+++ b/examples/vision/edsr.py
@@ -41,13 +41,17 @@ Comparison Graph:
## Imports
"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
-import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
AUTOTUNE = tf.data.AUTOTUNE
@@ -81,7 +85,7 @@ def flip_left_right(lowres_img, highres_img):
"""Flips Images to left and right."""
# Outputs random values from a uniform distribution in between 0 to 1
- rn = tf.random.uniform(shape=(), maxval=1)
+ rn = keras.random.uniform(shape=(), maxval=1)
# If rn is less than 0.5 it returns original lowres_img and highres_img
# If rn is greater than 0.5 it returns flipped image
return tf.cond(
@@ -98,7 +102,7 @@ def random_rotate(lowres_img, highres_img):
"""Rotates Images by 90 degrees."""
# Outputs random values from uniform distribution in between 0 to 4
- rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
+ rn = ops.cast(keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32")
# Here rn signifies number of times the image(s) are rotated by 90 degrees
return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)
@@ -110,14 +114,14 @@ def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
high resolution images: 96x96
"""
lowres_crop_size = hr_crop_size // scale # 96//4=24
- lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)
+ lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)
- lowres_width = tf.random.uniform(
- shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32
- )
- lowres_height = tf.random.uniform(
- shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32
- )
+ lowres_width = ops.cast(keras.random.uniform(
+ shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
+ ), dtype="int32")
+ lowres_height = ops.cast(keras.random.uniform(
+ shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
+ ), dtype="int32")
highres_width = lowres_width * scale
highres_height = lowres_height * scale
@@ -218,7 +222,7 @@ memory as the preceding convolutional layers.
"""
-class EDSRModel(tf.keras.Model):
+class EDSRModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
@@ -242,16 +246,16 @@ class EDSRModel(tf.keras.Model):
def predict_step(self, x):
# Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast
- x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)
+ x = ops.cast(ops.expand_dims(x, axis=0), dtype="float32")
# Passing low resolution image to model
super_resolution_img = self(x, training=False)
# Clips the tensor from min(0) to max(255)
- super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
+ super_resolution_img = ops.clip(super_resolution_img, 0, 255)
# Rounds the values of a tensor to the nearest integer
- super_resolution_img = tf.round(super_resolution_img)
+ super_resolution_img = ops.round(super_resolution_img)
# Removes dimensions of size 1 from the shape of a tensor and converting to uint8
- super_resolution_img = tf.squeeze(
- tf.cast(super_resolution_img, tf.uint8), axis=0
+ super_resolution_img = ops.squeeze(
+ ops.cast(super_resolution_img, dtype="uint8"), axis=0
)
return super_resolution_img
@@ -265,12 +269,19 @@ def ResBlock(inputs):
# Upsampling Block
-def Upsampling(inputs, factor=2, **kwargs):
- x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
- x = tf.nn.depth_to_space(x, block_size=factor)
- x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
- x = tf.nn.depth_to_space(x, block_size=factor)
- return x
+class UpsamplingLayer(layers.Layer):
+ def __init__(self, factor=2, **kwargs):
+ super().__init__(**kwargs)
+ self.factor = factor
+ self.conv1 = layers.Conv2D(64 * (factor ** 2), 3, padding="same")
+ self.conv2 = layers.Conv2D(64 * (factor ** 2), 3, padding="same")
+
+ def call(self, inputs):
+ x = self.conv1(inputs)
+ x = tf.nn.depth_to_space(x, block_size=self.factor)
+ x = self.conv2(x)
+ x = tf.nn.depth_to_space(x, block_size=self.factor)
+ return x
def make_model(num_filters, num_of_residual_blocks):
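
The diff is truncated at `make_model`, so for review context, here is a minimal sketch of how the new `UpsamplingLayer` would be wired in, assuming the rest of the function stays as in the original example:

```python
def make_model(num_filters, num_of_residual_blocks):
    # Flexible input: height/width left unspecified so any LR image size works
    input_layer = layers.Input(shape=(None, None, 3))
    # Scale pixel values to [0, 1]
    x = layers.Rescaling(scale=1.0 / 255)(input_layer)
    x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)

    # Residual blocks
    for _ in range(num_of_residual_blocks):
        x_new = ResBlock(x_new)

    x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
    x = layers.Add()([x, x_new])

    # The old functional call `Upsampling(x)` becomes a layer instance call;
    # factor=2 applied twice inside the layer yields the overall 4x upscale.
    x = UpsamplingLayer(factor=2)(x)
    x = layers.Conv2D(3, 3, padding="same")(x)

    output_layer = layers.Rescaling(scale=255)(x)
    return EDSRModel(input_layer, output_layer)
```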
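One small side note, not part of this diff: Keras 3 also provides `keras.random.randint`, which would make the uniform-then-cast pattern in `random_rotate` and `random_crop` a bit more direct. A possible follow-up sketch:

```python
# Equivalent draw of an integer in [0, 4), without the float round-trip:
rn = keras.random.randint(shape=(), minval=0, maxval=4)
```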