Saving the model in one file every epoch

1 year ago

Please help me understand something. I modified the code slightly, and it works fine: better-quality images are generated every epoch. But how can I save the model and its weights in a single file (for example, .h5) and then load it in a new project without first redefining the model? If I load the weights into a newly built model, the output is just noise. To repeat: during training everything is fine and the generated pictures look normal. Here is the updated code:

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from keras import layers
import os
from PIL import Image
import os
import wandb
from wandb.keras import WandbCallback
from keras.callbacks import ReduceLROnPlateau
import pathlib
import shutil
import datetime
import pickle

# Time-zone offset (hours from UTC) used to timestamp this run's results folder.
UTC = +3
# Timezone-aware "now" in UTC+UTC, independent of the machine's local time zone.
time_now = datetime.datetime.now(
	datetime.timezone(datetime.timedelta(hours=UTC))
).strftime('%Y-%m-%d_%H-%M-%S')
root_path = "WinHost/"
results_path = root_path + 'results/' + time_now + '/'

# shutil.copy2 does not create missing directories, so create the run folder first.
os.makedirs(results_path, exist_ok=True)
# Keep a copy of the notebook next to the results it produced.
shutil.copy2(root_path + 'main.ipynb', results_path + 'main.ipynb')

wandb.init(project="FaceGan_DDIM", name=time_now)

# ---- data -------------------------------------------------------------
dataset_repetitions = 3       # how many times each dataset is repeated
num_epochs = 1000             # train for at least ~50 epochs for good results
height, width = 256, 192      # training image size in pixels

# ---- KID (Kernel Inception Distance) evaluation -------------------------
kid_image_size = height       # side of the square the KID encoder resizes to
kid_diffusion_steps = 5       # reverse-diffusion steps per KID sample
plot_diffusion_steps = 128    # reverse-diffusion steps per preview image

# ---- sampling schedule --------------------------------------------------
min_signal_rate, max_signal_rate = 0.02, 0.95

# ---- U-Net architecture -------------------------------------------------
embedding_dims = 128                    # size of the noise-variance embedding
embedding_max_frequency = 1000.0        # highest embedding frequency
widths = [32, 64, 128, 256, 512, 1024]  # channels per resolution level
block_depth = 2                         # residual blocks per level

# ---- optimisation -------------------------------------------------------
batch_size = 4
ema = 0.999             # decay of the exponential moving average of weights
learning_rate = 1e-4
weight_decay = 1e-4

def create_dataset_from_directory(directory, height, width):
	"""Build a batched, shuffled tf.data pipeline from the images in ``directory``.

	Only files directly inside ``directory`` with a ``.jpeg``, ``.jpg`` or
	``.png`` extension are used. Each image is decoded as RGB, resized to
	``(height, width)`` with antialiasing and scaled to [0, 1]. The pipeline
	then caches, shuffles, repeats ``dataset_repetitions`` times, applies a
	random left/right flip, batches by ``batch_size`` (dropping the last
	incomplete batch) and prefetches.

	Raises:
		ValueError: if ``directory`` is not a directory or holds no matching images.
	"""
	path = pathlib.Path(directory)
	if not path.is_dir():
		raise ValueError(f"{directory} is not a valid directory.")

	extensions = ("jpeg", "jpg", "png")
	# Count and list exactly the same files. A catch-all '*.*' pattern would also
	# pick up non-image files, and decode_image fails on them.
	files_count = sum(len(list(path.glob(f"*.{ext}"))) for ext in extensions)
	if files_count == 0:
		raise ValueError(f"No images found in the '{directory}' directory with '*.jpg', '*.jpeg' or '*.png' extensions.")
	file_patterns = [os.path.join(directory, f"*.{ext}") for ext in extensions]

	dataset = tf.data.Dataset.list_files(file_patterns, shuffle=True)
	dataset = dataset.map(
		lambda x: tf.image.decode_image(tf.io.read_file(x), channels=3, expand_animations=False)
	)
	dataset = dataset.map(lambda x: tf.image.resize(x, size=[height, width], antialias=True))
	dataset = dataset.map(lambda x: tf.clip_by_value(x / 255.0, 0.0, 1.0))
	# Cache only the deterministic preprocessing. A random op placed before
	# cache() runs once, and its result is replayed from the cache forever.
	dataset = dataset.cache()
	dataset = dataset.shuffle(files_count)
	dataset = dataset.repeat(dataset_repetitions)
	dataset = dataset.map(tf.image.random_flip_left_right)
	dataset = dataset.batch(batch_size, drop_remainder=True)
	dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
	return dataset

# Training images and held-out validation images, both resized to (height, width).
train_dataset, val_dataset = (
	create_dataset_from_directory(image_dir, height, width)
	for image_dir in ("WinHost/faces", "WinHost/faces_val")
)

class KID(keras.metrics.Metric):
	"""Kernel Inception Distance between batches of real and generated images.

	Both image batches are embedded with an ImageNet-pretrained InceptionV3
	(no classification head, global-average pooled). The squared maximum mean
	discrepancy between the two feature sets is estimated with a cubic
	polynomial kernel. The per-batch estimates are averaged by ``kid_tracker``.
	"""

	def __init__(self, name, **kwargs):
		super().__init__(name=name, **kwargs)

		# KID is estimated per batch and averaged across batches.
		self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

		self.encoder = keras.Sequential(
			[
				keras.Input(shape=(height, width, 3)),
				# Callers pass pixel values in [0, 1] (DiffusionModel.denormalize clips
				# to that range). InceptionV3 preprocessing expects [0, 255].
				layers.Rescaling(255.0),
				layers.Resizing(height=kid_image_size, width=kid_image_size),
				layers.Lambda(keras.applications.inception_v3.preprocess_input),
				keras.applications.InceptionV3(
					include_top=False,
					input_shape=(kid_image_size, kid_image_size, 3),
					weights="imagenet",
				),
				layers.GlobalAveragePooling2D(),
			],
			name="inception_encoder",
		)

	def polynomial_kernel(self, features_1, features_2):
		"""Return the cubic polynomial kernel matrix (x·y / d + 1)^3 between two feature sets."""
		feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
		return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0

	def update_state(self, real_images, generated_images, sample_weight=None):
		"""Add one KID estimate for this pair of batches (``sample_weight`` is ignored).

		Both batches must contain the same number of images: the diagonal mask
		below is sized from ``real_images`` and applied to both kernels.
		"""
		real_features = self.encoder(real_images, training=False)
		generated_features = self.encoder(generated_images, training=False)

		# Compute polynomial kernels using the two sets of features.
		kernel_real = self.polynomial_kernel(real_features, real_features)
		kernel_generated = self.polynomial_kernel(
			generated_features, generated_features
		)
		kernel_cross = self.polynomial_kernel(real_features, generated_features)

		# Estimate the squared MMD; self-similarity terms exclude the diagonal.
		num_samples = tf.shape(real_features)[0]
		num_samples_f = tf.cast(num_samples, dtype=tf.float32)
		off_diagonal = 1.0 - tf.eye(num_samples)
		mean_kernel_real = tf.reduce_sum(kernel_real * off_diagonal) / (
			num_samples_f * (num_samples_f - 1.0)
		)
		mean_kernel_generated = tf.reduce_sum(kernel_generated * off_diagonal) / (
			num_samples_f * (num_samples_f - 1.0)
		)
		mean_kernel_cross = tf.reduce_mean(kernel_cross)
		kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

		# Update the running average of the KID estimate.
		self.kid_tracker.update_state(kid)

	def result(self):
		"""Return the KID averaged over all batches since the last reset."""
		return self.kid_tracker.result()

	def reset_state(self):
		"""Clear the running average (called by Keras at the start of each evaluation)."""
		self.kid_tracker.reset_state()

def sinusoidal_embedding(x):
	"""Embed noise variances as sines and cosines at log-spaced frequencies.

	``x`` is the ``(batch, 1, 1, 1)`` noise-variance input built in get_network.
	It is multiplied by ``embedding_dims // 2`` angular speeds whose frequencies
	are spaced evenly in log-space between 1.0 and ``embedding_max_frequency``.
	Returns ``(batch, 1, 1, embedding_dims)``: all sines, then all cosines.
	"""
	embedding_min_frequency = 1.0
	frequencies = tf.exp(
		tf.linspace(
			tf.math.log(embedding_min_frequency),
			tf.math.log(embedding_max_frequency),
			embedding_dims // 2,
		)
	)
	angular_speeds = 2.0 * math.pi * frequencies
	embeddings = tf.concat(
		[tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
	)
	return embeddings

def ResidualBlock(width):
	"""Return a callable that applies one residual block with ``width`` output channels.

	Main path: BatchNormalization (no learned shift/scale) -> 3x3 conv with swish
	-> 3x3 conv. The input is added back at the end. When its channel count
	differs from ``width``, it first goes through a 1x1 convolution.
	"""
	def apply(x):
		input_width = x.shape[3]
		if input_width == width:
			residual = x
		else:
			# 1x1 projection so the skip path matches the new channel count.
			residual = layers.Conv2D(width, kernel_size=1)(x)
		x = layers.BatchNormalization(center=False, scale=False)(x)
		x = layers.Conv2D(
			width, kernel_size=3, padding="same", activation=keras.activations.swish
		)(x)
		x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
		x = layers.Add()([x, residual])
		return x

	return apply

def DownBlock(width, block_depth):
	"""Return an encoder stage: ``block_depth`` residual blocks, then 2x average pooling.

	The returned callable takes ``[x, skips]``. It appends each residual
	block's output to ``skips`` so that the matching UpBlock can pop them as
	skip connections.
	"""
	def apply(inputs):
		x, skips = inputs
		for _ in range(block_depth):
			x = ResidualBlock(width)(x)
			# UpBlock pops block_depth tensors per level. Without this append,
			# its skips.pop() raises IndexError on an empty list.
			skips.append(x)
		x = layers.AveragePooling2D(pool_size=2)(x)
		return x

	return apply

def UpBlock(width, block_depth):
	"""Return a decoder stage: 2x bilinear upsampling, then ``block_depth`` residual blocks.

	The returned callable takes ``[x, skips]``. Before each residual block it
	pops the most recently pushed tensor from ``skips`` and concatenates it onto
	the features along the channel axis.
	"""
	def apply(inputs):
		features, skip_stack = inputs
		features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
		for _ in range(block_depth):
			skip = skip_stack.pop()
			features = layers.Concatenate()([features, skip])
			features = ResidualBlock(width)(features)
		return features

	return apply

def get_network(height, width, widths, block_depth):
	"""Assemble the residual U-Net that predicts the noise in a noisy image.

	Inputs: noisy images ``(height, width, 3)`` and their noise variances
	``(1, 1, 1)``. Output: a 3-channel tensor at the image resolution.
	``widths[:-1]`` gives one DownBlock/UpBlock level each, ``widths[-1]`` is
	the bottleneck width, and ``block_depth`` residual blocks are used per level.
	"""
	image_input = keras.Input(shape=(height, width, 3))
	variance_input = keras.Input(shape=(1, 1, 1))

	# Sinusoidal embedding of the variance, repeated over every pixel.
	embedding = layers.Lambda(sinusoidal_embedding)(variance_input)
	embedding = layers.UpSampling2D(size=(height, width), interpolation="nearest")(embedding)

	features = layers.Conv2D(widths[0], kernel_size=1)(image_input)
	features = layers.Concatenate()([features, embedding])

	level_widths = widths[:-1]
	# Shared skip-connection list: passed to every DownBlock and UpBlock; UpBlock pops from it.
	skip_stack = []
	for level_width in level_widths:
		features = DownBlock(level_width, block_depth)([features, skip_stack])

	for _ in range(block_depth):
		features = ResidualBlock(widths[-1])(features)

	for level_width in level_widths[::-1]:
		features = UpBlock(level_width, block_depth)([features, skip_stack])

	# Zero-initialised output layer: before training the network predicts zero noise.
	prediction = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(features)
	return keras.Model([image_input, variance_input], prediction, name="residual_unet")

class DiffusionModel(keras.Model):
	"""Denoising diffusion implicit model (DDIM) built around a residual U-Net.

	The model keeps two U-Nets with the same architecture. ``network`` is
	optimised in ``train_step``. ``ema_network`` holds an exponential moving
	average of its weights. ``generate`` samples with ``ema_network`` and maps
	the result back to pixel values with the ``normalizer`` statistics; this
	path is also used by ``plot_images`` and the KID evaluation in ``test_step``.

	Restoring a trained model for sampling therefore needs both the EMA weights
	and the normalizer mean/variance. ``ema_network`` is created with
	``clone_model``, which creates new weights rather than copying them. So
	restoring only ``network`` leaves sampling with untrained EMA weights.

	Args:
		height: image height in pixels.
		width: image width in pixels.
		widths: number of channels at each U-Net resolution level.
		block_depth: number of residual blocks per U-Net level.
		**kwargs: forwarded to ``keras.Model`` (e.g. ``name``).
	"""

	def __init__(self, height, width, widths, block_depth, **kwargs):
		super().__init__(**kwargs)
		self.height = height
		self.width = width
		self.widths = widths
		self.block_depth = block_depth

		# Maps pixels to zero mean / unit variance; its statistics are set by
		# adapt() and reused by denormalize().
		self.normalizer = layers.Normalization()
		# U-Net optimised by gradient descent in train_step.
		self.network = get_network(height, width, widths, block_depth)
		# Same architecture with its own weights; train_step moves them towards
		# ``network`` as an exponential moving average.
		self.ema_network = keras.models.clone_model(self.network)

		# Not read by this class; kept for code that sets or reads them.
		self.mean = None
		self.variance = None

	def call(self, inputs, *, training=None):
		"""Predict the noise component; ``inputs`` is ``(noisy_images, noise_variances)``."""
		noisy_images, noise_variances = inputs
		return self.network([noisy_images, noise_variances], training=training)

	def compile(self, **kwargs):
		"""Compile as usual and create the loss trackers and the KID metric."""
		super().compile(**kwargs)

		self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
		self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
		self.kid = KID(name="kid")

	@property
	def metrics(self):
		"""Metrics Keras resets every epoch; KID must stay last (train_step skips it)."""
		return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]

	def get_config(self):
		"""Return the constructor arguments so Keras can re-create the model from its config."""
		return {
			"height": self.height,
			"width": self.width,
			"widths": list(self.widths),
			"block_depth": self.block_depth,
		}

	def denormalize(self, images):
		"""Undo the normalizer and clip to pixel values in [0, 1]."""
		images = self.normalizer.mean + images * self.normalizer.variance**0.5
		return tf.clip_by_value(images, 0.0, 1.0)

	def diffusion_schedule(self, diffusion_times):
		"""Return ``(noise_rates, signal_rates)`` for diffusion times in [0, 1].

		Times map linearly to angles between acos(max_signal_rate) and
		acos(min_signal_rate). The signal rate is cos(angle) and the noise rate
		is sin(angle), so their squares always sum to 1.
		"""
		start_angle = tf.acos(max_signal_rate)
		end_angle = tf.acos(min_signal_rate)

		diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

		signal_rates = tf.cos(diffusion_angles)
		noise_rates = tf.sin(diffusion_angles)

		return noise_rates, signal_rates

	def denoise(self, noisy_images, noise_rates, signal_rates, training):
		"""Split noisy images into predicted noise and predicted clean images.

		Uses ``network`` when ``training`` is true and ``ema_network`` otherwise.
		"""
		network = self.network if training else self.ema_network

		# Predict the noise component, then solve the mixing equation for the image.
		pred_noises = network([noisy_images, noise_rates**2], training=training)
		pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

		return pred_noises, pred_images

	def reverse_diffusion(self, initial_noise, diffusion_steps):
		"""Run ``diffusion_steps`` deterministic DDIM steps from pure noise; return normalized images."""
		num_images = initial_noise.shape[0]
		step_size = 1.0 / diffusion_steps

		next_noisy_images = initial_noise
		for step in range(diffusion_steps):
			noisy_images = next_noisy_images

			# Separate the current noisy image into its components (EMA network, eval mode).
			diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
			noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
			pred_noises, pred_images = self.denoise(
				noisy_images, noise_rates, signal_rates, training=False
			)

			# Remix the predicted components using the next step's signal and noise rates.
			next_diffusion_times = diffusion_times - step_size
			next_noise_rates, next_signal_rates = self.diffusion_schedule(
				next_diffusion_times
			)
			next_noisy_images = (
				next_signal_rates * pred_images + next_noise_rates * pred_noises
			)

		return pred_images

	def generate(self, num_images, diffusion_steps):
		"""Sample ``num_images`` images with pixel values in [0, 1]."""
		initial_noise = tf.random.normal(shape=(num_images, self.height, self.width, 3))
		generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
		return self.denormalize(generated_images)

	def _noisy_batch(self, images):
		"""Mix ``images`` with Gaussian noise at uniformly sampled diffusion times.

		Returns ``(noises, noisy_images, noise_rates, signal_rates)``.
		"""
		noises = tf.random.normal(shape=tf.shape(images))
		diffusion_times = tf.random.uniform(
			shape=(tf.shape(images)[0], 1, 1, 1), minval=0.0, maxval=1.0
		)
		noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
		noisy_images = signal_rates * images + noise_rates * noises
		return noises, noisy_images, noise_rates, signal_rates

	def train_step(self, images):
		"""One optimisation step on the noise loss, followed by the EMA weight update."""
		# Normalize images to unit variance, like the noise.
		images = self.normalizer(images, training=True)
		noises, noisy_images, noise_rates, signal_rates = self._noisy_batch(images)

		with tf.GradientTape() as tape:
			pred_noises, pred_images = self.denoise(
				noisy_images, noise_rates, signal_rates, training=True
			)
			noise_loss = self.loss(noises, pred_noises)  # optimised
			image_loss = self.loss(images, pred_images)  # reported only

		gradients = tape.gradient(noise_loss, self.network.trainable_weights)
		self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

		self.noise_loss_tracker.update_state(noise_loss)
		self.image_loss_tracker.update_state(image_loss)

		# Track the exponential moving average of the weights.
		for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
			ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

		# KID (last metric) is skipped during training because it is expensive.
		return {m.name: m.result() for m in self.metrics[:-1]}

	def test_step(self, images):
		"""Evaluate the losses on a noisy copy of ``images`` and update KID with fresh samples."""
		images = self.normalizer(images, training=False)
		noises, noisy_images, noise_rates, signal_rates = self._noisy_batch(images)

		pred_noises, pred_images = self.denoise(
			noisy_images, noise_rates, signal_rates, training=False
		)
		noise_loss = self.loss(noises, pred_noises)
		image_loss = self.loss(images, pred_images)

		self.image_loss_tracker.update_state(image_loss)
		self.noise_loss_tracker.update_state(noise_loss)

		# KID compares the real batch with an equally sized generated batch, both in [0, 1].
		images = self.denormalize(images)
		generated_images = self.generate(
			num_images=batch_size, diffusion_steps=kid_diffusion_steps
		)
		self.kid.update_state(images, generated_images)

		return {m.name: m.result() for m in self.metrics}

	def plot_images(self, epoch=None, folder="", num_rows=3, num_cols=4):
		"""Save a ``num_rows`` x ``num_cols`` grid of generated images as a PNG.

		The file is ``image_<epoch+1:05>.png`` inside ``folder`` (or ``image.png``
		when ``epoch`` is None). ``folder`` is created if needed. Returns the path.
		"""
		# os.makedirs("") raises, so only create a folder when one was given.
		if folder:
			os.makedirs(folder, exist_ok=True)

		generated_images = self.generate(
			num_images=num_rows * num_cols,
			diffusion_steps=plot_diffusion_steps,
		)

		# Paste every image into one large canvas, row by row.
		img_rows, img_cols, _ = generated_images[0].shape
		concat_img = Image.new('RGB', (img_cols * num_cols, img_rows * num_rows))
		for index in range(num_rows * num_cols):
			row, col = divmod(index, num_cols)
			img = (generated_images[index].numpy() * 255).astype('uint8')
			concat_img.paste(Image.fromarray(img), (col * img_cols, row * img_rows))

		file_name = 'image.png' if epoch is None else f'image_{epoch+1:05}.png'
		img_path = os.path.join(folder, file_name)
		concat_img.save(img_path)
		return img_path

model = DiffusionModel(height, width, widths, block_depth)

reduce_lr_on_plateau = ReduceLROnPlateau(
	monitor="val_kid",  # watch the validation KID metric
	factor=0.5,         # new learning rate = old learning rate * factor
	patience=5,         # epochs without improvement before reducing
	verbose=1,          # print a message when the rate is reduced
	mode="min",         # a lower val_kid is better
	min_lr=1e-8,        # lower bound on the learning rate
)
class LRTensorBoard(keras.callbacks.Callback):
	"""Record the optimizer's current learning rate at the end of every epoch.

	The value is stored in the epoch ``logs`` dict under ``"lr"`` and sent to
	Weights & Biases with ``wandb.log``.
	"""

	def on_epoch_end(self, epoch, logs=None):
		# ``logs or {}`` would replace an empty-but-shared dict; only default when missing.
		logs = logs if logs is not None else {}
		# get_value works for optimizer variables in every TF 2.x version.
		lr = float(keras.backend.get_value(self.model.optimizer.learning_rate))
		logs['lr'] = lr
		wandb.log({"lr": lr})

lr_logger = LRTensorBoard()

# AdamW with decoupled weight decay; the diffusion model is trained on mean absolute error.
model.compile(
	optimizer=keras.optimizers.experimental.AdamW(
		learning_rate=learning_rate, weight_decay=weight_decay
	),
	loss=keras.losses.mean_absolute_error,
)

def save_model_callback(epoch, logs):
	"""Save the model state after 0-based ``epoch`` into ``results_path`` (``logs`` is unused).

	Files written, all named with the 1-based stem ``epoch_NNNNN``:
	  * ``<stem>.h5``               -- weights of the whole DiffusionModel
	                                   (network, ema_network, normalizer), HDF5.
	  * ``<stem>/``                 -- full model in TensorFlow SavedModel format.
	  * ``<stem>_ema.h5``           -- the EMA U-Net alone, the network used for sampling.
	  * ``<stem>_normalizer.pkl``   -- normalizer mean and variance as NumPy arrays.

	NOTE(review): the EMA U-Net contains a Lambda wrapping sinusoidal_embedding,
	which reads module-level constants; loading ``<stem>_ema.h5`` elsewhere
	presumably needs those definitions available -- verify.
	"""
	stem = os.path.join(results_path, f"epoch_{epoch + 1:05}")

	model.save_weights(stem + ".h5")
	model.save(stem, save_format='tf')
	# Separate file name: writing this to "<stem>.h5" would overwrite the weights above.
	model.ema_network.save(stem + "_ema.h5")

	# Normalizer statistics are needed to map generated samples back to pixels.
	with open(stem + "_normalizer.pkl", "wb") as f:
		pickle.dump(
			{'mean': model.normalizer.mean.numpy(), 'variance': model.normalizer.variance.numpy()},
			f,
		)

def generate_callback(epoch, logs):
	"""Save a grid of sample images for 0-based ``epoch`` into ``results_path``.

	``logs`` is accepted for the Keras callback signature but not used. The former
	``epoch % 1 == 0`` guard was always true for integer epochs and is removed.
	"""
	model.plot_images(epoch, results_path)
def _on_epoch_end(epoch, logs):
	"""Render preview images, then save the model for the finished epoch."""
	generate_callback(epoch, logs)
	save_model_callback(epoch, logs)


training_callback = keras.callbacks.LambdaCallback(on_epoch_end=_on_epoch_end)

# Keep only the weights with the best (lowest) validation KID.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
	filepath=results_path + "checkpoints/checkpoint",
	save_weights_only=True,
	monitor="val_kid",
	mode="min",
	save_best_only=True,
)

# The Normalization layer creates its mean/variance only when adapted (or built),
# so adapt it before reading them.
model.normalizer.adapt(train_dataset)
model.mean = model.normalizer.mean
model.variance = model.normalizer.variance


