STN.keras

How to use this code to recognize MNIST images of size 28 × 28?

Open · ghost opened this issue 6 years ago · 0 comments

Hi, thanks for providing the code. I ran it and the results looked good, but when I tried to adapt it to recognize MNIST digits of size 28 × 28, I ran into some problems. I'd really appreciate it if anyone could help. Here is my code:

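```python
import numpy as np
import matplotlib.pyplot as plt
import keras as k
import keras.backend as K
from keras.datasets import mnist
from keras.optimizers import Adam
from src.models import STN  # model definition from the STN.keras repository

num_classes = 10

# Load MNIST and keep the first 50,000 training images.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:50000]
y_train = y_train[:50000]

# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))

# One-hot encode the labels.
y_test = k.utils.to_categorical(y_test, num_classes)
y_train = k.utils.to_categorical(y_train, num_classes)

# Build the STN for 28x28 inputs, resampling to 14x14.
model = STN(input_shape=(28, 28, 1), sampling_size=(14, 14))
model.compile(loss='categorical_crossentropy', optimizer=Adam())

# Backend function returning the output of the sampling layer (the transformed image).
# The auto-generated layer name may differ, e.g. if the model is built twice in one session.
input_image = model.input
output_STN = model.get_layer('bilinear_interpolation_1').output
STN_function = K.function([input_image], [output_STN])

num_epochs = 3
batch_size = 10
model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs)

# Show two original images, each followed by its STN-transformed version.
image_result = STN_function([x_train[:10]])
for i in range(2):
    plt.imshow(x_train[i].reshape(28, 28), cmap='gray')
    plt.show()
    image = np.squeeze(image_result[0][i])  # (14, 14, 1) -> (14, 14)
    plt.imshow(image, cmap='gray')
    plt.show()
```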
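The script feeds raw `uint8` pixel values (0 to 255) straight into the network. Scaling MNIST pixels to `float32` in [0, 1] is a common first step; whether the repository's own example feeds already-scaled data is an assumption here, not something stated in the issue. A minimal sketch of that preprocessing, where `preprocess` is a hypothetical helper name:

```python
import numpy as np

def preprocess(images):
    """Hypothetical helper: uint8 images (N, 28, 28) -> float32 in [0, 1], shape (N, 28, 28, 1)."""
    images = images.astype('float32') / 255.0  # scale 0..255 to 0..1
    return images.reshape((images.shape[0], 28, 28, 1))  # add the channel axis

# Stand-in for the arrays returned by mnist.load_data(): random uint8 pixels.
raw = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
x = preprocess(raw)
print(x.shape, x.dtype)
```

In the script above, `x_train = preprocess(x_train)` and `x_test = preprocess(x_test)` would take the place of the two `reshape` calls.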

Here is the result: instead of a transformed image, I only get a completely black one.
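An all-black output can have more than one cause. One is that the predicted affine parameters move the sampling grid entirely outside the input, so every sample reads as zero. Another is that `plt.imshow` with default normalization renders any constant image as uniform black, so printing `image.min()` and `image.max()` helps tell the two apart. The sketch below is a plain NumPy re-implementation of standard STN affine sampling written for illustration; it is not the repository's `BilinearInterpolation` layer, and the functions `affine_grid` and `bilinear_sample` are hypothetical names. It shows that an identity transform reproduces the input while a large translation yields all zeros:

```python
import numpy as np

def affine_grid(theta, out_h, out_w):
    """Map each output pixel to source coordinates in normalized [-1, 1] space."""
    ys, xs = np.meshgrid(np.linspace(-1.0, 1.0, out_h),
                         np.linspace(-1.0, 1.0, out_w), indexing='ij')
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(xs.size)])  # (3, H*W)
    src = theta.reshape(2, 3) @ grid  # (2, H*W): source x and y
    return src[0], src[1]

def bilinear_sample(image, x, y):
    """Sample a 2-D image at normalized coordinates; points outside the image read as 0."""
    h, w = image.shape
    px = (x + 1.0) * (w - 1) / 2.0  # normalized -> pixel coordinates
    py = (y + 1.0) * (h - 1) / 2.0
    x0 = np.floor(px).astype(int)
    y0 = np.floor(py).astype(int)
    x1, y1 = x0 + 1, y0 + 1

    def pixel(yy, xx):
        inside = (xx >= 0) & (xx < w) & (yy >= 0) & (yy < h)
        vals = np.zeros(xx.shape, dtype=float)
        vals[inside] = image[yy[inside], xx[inside]]
        return vals

    # Weights of the four neighbouring pixels.
    wa = (x1 - px) * (y1 - py)
    wb = (x1 - px) * (py - y0)
    wc = (px - x0) * (y1 - py)
    wd = (px - x0) * (py - y0)
    return (wa * pixel(y0, x0) + wb * pixel(y1, x0)
            + wc * pixel(y0, x1) + wd * pixel(y1, x1))

img = np.zeros((28, 28))
img[8:20, 12:16] = 1.0  # a simple vertical bar as a stand-in digit

identity = np.array([1, 0, 0, 0, 1, 0], dtype=float)
x, y = affine_grid(identity, 28, 28)
same = bilinear_sample(img, x, y).reshape(28, 28)

shifted = np.array([1, 0, 5, 0, 1, 5], dtype=float)  # grid moved far outside [-1, 1]
x, y = affine_grid(shifted, 14, 14)
black = bilinear_sample(img, x, y).reshape(14, 14)

print(np.allclose(same, img), black.max())
```

In the Keras model, the analogous check would be inspecting the six values produced by the localization network for a few inputs and seeing whether they stay near the identity `[1, 0, 0, 0, 1, 0]`.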

What do I need to change when switching to other datasets? Also, the loss is stuck at around 2.3000 after 3 epochs of training.
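A loss of about 2.30 is exactly what categorical cross-entropy gives for a 10-class model that predicts the uniform distribution, since −ln(1/10) = ln 10 ≈ 2.3026. In other words, the network is doing no better than chance. A quick check of the arithmetic:

```python
import numpy as np

num_classes = 10
# A model that assigns probability 1/10 to every class, whatever the input.
uniform = np.full(num_classes, 1.0 / num_classes)
true_class = 3  # any label gives the same value
loss = -np.log(uniform[true_class])  # categorical cross-entropy for one sample
print(round(float(loss), 4))  # 2.3026
```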

ghost, Jan 24 '19 09:01