STN.keras
How can I use this code to recognize 28×28 MNIST images?
Hi, thanks for providing the code. I ran it and the results worked out well, but when I try to adapt it to recognize 28×28 MNIST images I run into some problems. I would really appreciate it if anyone could help. Here is my code:
import numpy as np
import keras.backend as K
from keras.datasets import mnist
from keras.optimizers import Adam
from src.models import STN
import matplotlib.pyplot as plt
import keras as k

num_classes = 10

# Load MNIST (uint8 images of shape (N, 28, 28)) and keep the first 50,000 training samples.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:50000]
y_train = y_train[:50000]

# Add the channel axis expected by the model: (N, 28, 28, 1).
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))

# One-hot encode the labels for categorical_crossentropy.
y_test = k.utils.to_categorical(y_test, num_classes)
y_train = k.utils.to_categorical(y_train, num_classes)

model = STN(input_shape=(28, 28, 1), sampling_size=(14, 14))
model.compile(loss='categorical_crossentropy', optimizer=Adam())

# Backend function returning the output of the spatial transformer's sampling layer.
# 'bilinear_interpolation_1' appears to be the auto-generated layer name; check model.summary() if it differs.
input_image = model.input
output_STN = model.get_layer('bilinear_interpolation_1').output
STN_function = K.function([input_image], [output_STN])

num_epochs = 3
batch_size = 10
model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs)

# Show the first two inputs, each followed by its transformed version.
image_result = STN_function([x_train[:10]])
for i in range(2):
    plt.imshow(x_train[i].reshape(28, 28), cmap='gray')
    plt.show()
    image = np.squeeze(image_result[0][i])
    plt.imshow(image, cmap='gray')
    plt.show()
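One thing worth checking: `mnist.load_data()` returns `uint8` pixel values in 0..255, and the script feeds them to the model unscaled. Below is a minimal NumPy-only sketch of the preprocessing I would try first, assuming the model expects float inputs in [0, 1]. The helpers `prepare_images` and `one_hot` are hypothetical names, not part of STN.keras or Keras. The sketch also shows which shape values come from the data and would need to change for another dataset.

```python
import numpy as np

def prepare_images(x):
    """Return float32 images shaped (N, H, W, C) with values in [0, 1].

    Assumes uint8 input in 0..255, as returned by keras.datasets.mnist.
    Grayscale arrays of shape (N, H, W) get a trailing channel axis.
    """
    x = np.asarray(x)
    if x.ndim == 3:                      # (N, H, W) -> (N, H, W, 1)
        x = x[..., np.newaxis]
    return x.astype("float32") / 255.0

def one_hot(labels, num_classes):
    """One-hot encode integer labels (same idea as keras.utils.to_categorical)."""
    labels = np.asarray(labels, dtype=int)
    out = np.zeros((labels.shape[0], num_classes), dtype="float32")
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out

# Fake MNIST-sized batch so the sketch runs without downloading anything.
rng = np.random.default_rng(0)
x_raw = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
y_raw = np.array([3, 1, 4, 1])

x = prepare_images(x_raw)
y = one_hot(y_raw, num_classes=10)

# For a different dataset, read these from the data instead of hard-coding them.
input_shape = x.shape[1:]                # (28, 28, 1) here; (H, W, 3) for RGB data
print(input_shape, x.dtype, y.shape)     # prints: (28, 28, 1) float32 (4, 10)
```

In the script, something like `x_train = prepare_images(x_train[:50000])` would replace the two `reshape` lines. For another dataset, `input_shape`, `num_classes`, `sampling_size`, and the layer name passed to `get_layer` are the values that may need to change.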
Here is the result: instead of the transformed image, I only get a completely black one.
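A possible reason for the all-black output: a spatial transformer samples the input at coordinates produced by an affine grid. If the predicted transform maps that grid outside the image, every sample lands on the image border (clipped) or on zero padding, depending on the implementation. MNIST digits sit on a black border, so either way the result is black. The NumPy sketch below illustrates the affine grid plus bilinear sampling with border clipping. It is written only for illustration and is not the sampling layer from STN.keras; `affine_grid` and `bilinear_sample` are hypothetical names.

```python
import numpy as np

def affine_grid(theta, out_h, out_w):
    """Map a regular output grid in normalized [-1, 1] coords through a 2x3 affine matrix."""
    ys, xs = np.meshgrid(np.linspace(-1, 1, out_h), np.linspace(-1, 1, out_w), indexing="ij")
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(xs.size)])  # (3, out_h*out_w)
    src = theta @ grid                                           # (2, out_h*out_w): source x, y
    return src[0].reshape(out_h, out_w), src[1].reshape(out_h, out_w)

def bilinear_sample(img, x_norm, y_norm):
    """Bilinearly sample a 2-D image at normalized coords, clipping to the border."""
    h, w = img.shape
    # Normalized [-1, 1] -> pixel coordinates [0, w-1] and [0, h-1], clipped to the image.
    x = np.clip((x_norm + 1.0) * (w - 1) / 2.0, 0, w - 1)
    y = np.clip((y_norm + 1.0) * (h - 1) / 2.0, 0, h - 1)
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)
    wx, wy = x - x0, y - y0
    top = img[y0, x0] * (1 - wx) + img[y0, x1] * wx
    bottom = img[y1, x0] * (1 - wx) + img[y1, x1] * wx
    return top * (1 - wy) + bottom * wy

# Toy 28x28 "digit": a bright square in the middle, black border like MNIST.
img = np.zeros((28, 28))
img[10:18, 10:18] = 1.0

identity = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
shifted = np.array([[1.0, 0.0, 3.0], [0.0, 1.0, 3.0]])  # translation far outside [-1, 1]

out_identity = bilinear_sample(img, *affine_grid(identity, 28, 28))
out_shifted = bilinear_sample(img, *affine_grid(shifted, 14, 14))
print(out_identity.max(), out_shifted.max())
```

To see which case applies, it may help to print `image_result[0].min()` and `image_result[0].max()` before plotting. With default settings, `plt.imshow` scales each image to its own min/max, so a constant image is drawn entirely in the lowest colour of the colormap (black with `cmap='gray'`), whatever its actual value.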
Also, what do I need to change when switching to other datasets? And the loss stays stuck at around 2.3000 after 3 epochs of training.
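A loss of about 2.30 is a recognizable value: it is ln(10), the categorical cross-entropy of a model that predicts the uniform distribution over 10 classes. In other words, the classifier is still at chance level. A quick NumPy check (not part of the original code):

```python
import numpy as np

num_classes = 10
y_true = np.eye(num_classes)[[3, 1, 4, 1, 5]]      # one-hot labels for 5 samples
y_pred = np.full_like(y_true, 1.0 / num_classes)   # uniform predictions

# Categorical cross-entropy averaged over the batch.
loss = -np.mean(np.sum(y_true * np.log(y_pred), axis=1))
print(round(loss, 4), round(np.log(num_classes), 4))  # 2.3026 2.3026
```

When the loss stays at this value, common things to try are scaling the inputs to [0, 1], lowering the learning rate (`Adam(lr=...)` in older Keras, `learning_rate=...` in newer versions), and checking that the labels match the `num_classes` the model was built with. None of these is confirmed here as the fix for this particular setup.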