Loss and accuracy differ when using a custom fit (overriding train_step) compared to normal training
Please go to TF Forum for help and support:
https://discuss.tensorflow.org/tag/keras
If you open a GitHub issue, here is our policy:
It must be a bug, a feature request, or a significant problem with the documentation (for small docs fixes please send a PR instead). The form below must be filled out.
Here's why we have that policy:
Keras developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GPU model and memory:
- Exact command to reproduce:
You can collect some of this information using our environment capture script:
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
You can obtain the TensorFlow version with: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
Describe the problem.
When I train a model on CIFAR-10 with only 4000 labeled examples, the loss converges and the accuracy improves during normal training. But when I use my own custom fit with an overridden train_step, the loss and accuracy do not behave the same way.
Describe the current behavior. Loss decreases and accuracy improves for the normal model, but the custom model produces different results.
Describe the expected behavior. Loss and accuracy should behave the same way for both models.
- Do you want to contribute a PR? (yes/no):
- If yes, please read this page for instructions
- Briefly describe your candidate solution (if contributing):
Standalone code to reproduce the issue.
#preprocessing
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

(trainX,trainY), (testX,testY) = tf.keras.datasets.cifar10.load_data()
import sklearn
#trainX, valX, trainY,valY = sklearn.model_selection.train_test_split(trainX,trainY,test_size=0.2)
# keep 400 labeled images per class; the remainder are treated as unlabeled
trainX1 = []
trainY1 = []
trainuns = []
trainunsY = []
for c in range(0,10):
    indices = tf.where(trainY==c)
    subtrain = trainX[indices[:,0]]
    withlab = subtrain[0:400,:,:,:]
    print(subtrain[400:,:,:,:].shape)
    trainX1.append(np.array(withlab))
    trainY1.append(trainY[indices[:,0]][0:400])
    trainuns.append(np.array(subtrain[400:,:,:,:]))
    trainunsY.append(trainY[indices[:,0]][400:])
trainX1 = np.stack(trainX1,0)
print(trainX1.shape)
trainY1 = np.array(trainY1)
trainuns = np.concatenate(trainuns,axis=0)
trainunsY = np.array(trainunsY)
print(trainuns.shape)
trainX1 = tf.reshape(trainX1,(4000,32,32,3))
trainuns = tf.reshape(trainuns,(-1,32,32,3))
trainY1 = tf.reshape(trainY1,[4000,-1])

def _input_fn(X, Xuns, y):
    dataset = tf.data.Dataset.from_tensor_slices((X,y))
    dataset = dataset.batch(128, drop_remainder=False)
    dataset2 = tf.data.Dataset.from_tensor_slices((Xuns))
    dataset2 = dataset2.batch(128, drop_remainder=False)
    dataset = tf.data.Dataset.zip((dataset, dataset2)).prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
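For reference, each element yielded by this zipped dataset is a nested tuple, which is the structure the custom train_step below unpacks (variable names in this sketch are illustrative only):

# Illustrative sanity check: one element of the zipped dataset.
# tf.data.Dataset.zip stops at the shorter dataset, so the 4000 labeled
# images give ceil(4000/128) = 32 batches per epoch (matching the 32/32
# steps in the logs below).
sample_ds = _input_fn(trainX1, trainuns, trainY1)
(x_batch, y_batch), x_uns_batch = next(iter(sample_ds))
print(x_batch.shape, y_batch.shape, x_uns_batch.shape)
# expected: (128, 32, 32, 3) (128, 1) (128, 32, 32, 3)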
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomCrop(32, 32),
        layers.Normalization(mean=(0.491, 0.482, 0.446), variance=(0.247, 0.243, 0.261)),
    ]
)
#Model
def VGG16():
    weight_decay = 0.0005
    inputs = layers.Input(shape=(32, 32, 3))
    #inputs = data_augmentation(inputs)
    x = data_augmentation(inputs)
    x = layers.GaussianNoise(stddev=0.15)(inputs)
    x1 = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    '''x = layers.BatchNormalization()(x1)
    x = layers.MaxPool2D(pool_size=(2,2),strides=(2,2))(x)'''
    #x = layers.Add()([x, layers.Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding="same", activation="relu")(x1)])
    x = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x1)
    x = layers.BatchNormalization()(x)
    x2 = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x2)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    xlow = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    #xlow_dense = layers.Dense(11, activation='softmax')(xlow)
    xi = layers.GlobalAveragePooling2D()(xlow)
    xi = layers.BatchNormalization()(xi)
    x1 = layers.Dense(128, activation='relu')(xi)
    x1 = layers.BatchNormalization()(x1)
    x = layers.Dense(10, activation='softmax')(x1)
    return tf.keras.Model(inputs=inputs, outputs=x)
#Normal model and training
model = VGG16()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics='sparse_categorical_accuracy')
model.fit(trainX1, trainY1, epochs=200)
#My custom model
class CustomModel(keras.Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        #self.encoder, self.mid_model = VGG16()
        self.encoder = VGG16()
        self.accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    def compile(self, optimizer, loss, metrics):
        super().compile(optimizer)
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metrics

    def call(self, inputs):
        return self.encoder(inputs)

    def train_step(self, data):
        # each element of the zipped dataset is ((labeled images, labels), unlabeled images)
        (data1, label), data_uns = data
        with tf.GradientTape() as tape:
            y_pred = self(data1, training=True)
            # forward pass on an augmented view (not used in the loss below)
            y_pred_aug = self.encoder(data_augmentation(data1), training=True)
            loss_value = keras.losses.sparse_categorical_crossentropy(
                label, y_pred, from_logits=True)
        acc = self.accuracy(label, y_pred)
        grads = tape.gradient(loss_value, self.encoder.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.encoder.trainable_weights))
        return {m.name: m.result() for m in self.metrics}
model2 = CustomModel()
model2.compile(optimizer=tf.keras.optimizers.Adam(),
               loss='sparse_categorical_crossentropy',
               metrics='sparse_categorical_accuracy')
model2.fit(_input_fn(trainX1, trainuns, trainY1), epochs=200)
Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Source code / logs.

Result during normal training:

(None, 32, 32, 3)
Epoch 1/200
125/125 [==============================] - 7s 30ms/step - loss: 3.9305 - sparse_categorical_accuracy: 0.2280
Epoch 2/200
125/125 [==============================] - 4s 31ms/step - loss: 3.5681 - sparse_categorical_accuracy: 0.3268
Epoch 3/200
125/125 [==============================] - 4s 28ms/step - loss: 3.3872 - sparse_categorical_accuracy: 0.3728
Epoch 4/200
125/125 [==============================] - 3s 28ms/step - loss: 3.1754 - sparse_categorical_accuracy: 0.3938
Epoch 5/200
125/125 [==============================] - 3s 28ms/step - loss: 2.9871 - sparse_categorical_accuracy: 0.4283
Epoch 6/200
125/125 [==============================] - 4s 28ms/step - loss: 2.7658 - sparse_categorical_accuracy: 0.4703
Epoch 7/200
125/125 [==============================] - 4s 28ms/step - loss: 2.6248 - sparse_categorical_accuracy: 0.4723
Epoch 8/200
125/125 [==============================] - 4s 30ms/step - loss: 2.4352 - sparse_categorical_accuracy: 0.5192
Epoch 9/200
125/125 [==============================] - 4s 28ms/step - loss: 2.3484 - sparse_categorical_accuracy: 0.5145
Epoch 10/200
125/125 [==============================] - 3s 28ms/step - loss: 2.2068 - sparse_categorical_accuracy: 0.5443
Epoch 11/200
125/125 [==============================] - 3s 28ms/step - loss: 2.0758 - sparse_categorical_accuracy: 0.5658
Epoch 12/200
125/125 [==============================] - 3s 28ms/step - loss: 1.9753 - sparse_categorical_accuracy: 0.5838
Epoch 13/200
125/125 [==============================] - 4s 28ms/step - loss: 1.8631 - sparse_categorical_accuracy: 0.6208
Epoch 14/200
125/125 [==============================] - 4s 28ms/step - loss: 1.7671 - sparse_categorical_accuracy: 0.6363
Epoch 15/200
125/125 [==============================] - 3s 28ms/step - loss: 1.7040 - sparse_categorical_accuracy: 0.6590
Epoch 16/200
125/125 [==============================] - 4s 33ms/step - loss: 1.6649 - sparse_categorical_accuracy: 0.6680
Epoch 17/200
103/125 [=======================>......] - ETA: 0s - loss: 1.6199 - sparse_categorical_accuracy: 0.6799
Result during training with the CustomModel() model:
(None, 32, 32, 3)
Epoch 1/200
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: "sparse_categorical_crossentropy received from_logits=True, but the output argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?"
return dispatch_target(*args, **kwargs)
32/32 [==============================] - 47s 115ms/step - sparse_categorical_accuracy: 0.1192
Epoch 2/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0780
Epoch 3/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0785
Epoch 4/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0635
Epoch 5/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0540
Epoch 6/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0640
Epoch 7/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0610
Epoch 8/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0775
Epoch 9/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0857
Epoch 10/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0852
Epoch 11/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1085
Epoch 12/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1025
Epoch 13/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1053
Epoch 14/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1085
Epoch 15/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1182
Epoch 16/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1258
Epoch 17/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1182
Epoch 18/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1433
Epoch 19/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1098
Epoch 20/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1200
Epoch 21/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1332
Epoch 22/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1495
Epoch 23/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1530
Epoch 24/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.1472
Epoch 25/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1260
Epoch 26/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1550
Epoch 27/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1412
Epoch 28/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1493
Epoch 29/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1490
Epoch 30/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.1450
Epoch 31/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1315
Epoch 32/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1737
Epoch 33/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.1918
Epoch 34/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.1865
Epoch 35/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1875
Epoch 36/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1928
Epoch 37/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.2007
Epoch 38/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2002
Epoch 39/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.2163
Epoch 40/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2230
Epoch 41/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2578
Epoch 42/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2520
Epoch 43/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2548
Epoch 44/200
32/32 [==============================] - 4s 119ms/step - sparse_categorical_accuracy: 0.2582
Epoch 45/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2850
Epoch 46/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.2785
Epoch 47/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2948
Epoch 48/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2803
Epoch 49/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2957
Epoch 50/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2693
Epoch 51/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2150
Epoch 52/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2360
Epoch 53/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2855
Epoch 54/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2982
Epoch 55/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2993
Epoch 56/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3075
Epoch 57/200
32/32 [==============================] - 4s 119ms/step - sparse_categorical_accuracy: 0.3298
Epoch 58/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.3277
Epoch 59/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3225
Epoch 60/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.3440
Epoch 61/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3273
Epoch 62/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3355
Epoch 63/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3318
Epoch 64/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3252
Epoch 65/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3410
Epoch 66/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3363
Epoch 67/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3293
Epoch 68/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3433
Epoch 69/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3395
Epoch 70/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3347
Epoch 71/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3363
Epoch 72/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3402
Epoch 73/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3295
Epoch 74/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3365
Epoch 75/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3237
Epoch 76/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3372
Epoch 77/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3338
Epoch 78/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3275
Epoch 79/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3332
Epoch 80/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3288
Epoch 81/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3322
Epoch 82/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3232
Epoch 83/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3343
Epoch 84/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3290
Epoch 85/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3300
Epoch 86/200
32/32 [==============================] - 4s 121ms/step - sparse_categorical_accuracy: 0.3275
Epoch 87/200
32/32 [==============================] - 4s 128ms/step - sparse_categorical_accuracy: 0.3313
Epoch 88/200
32/32 [==============================] - 4s 129ms/step - sparse_categorical_accuracy: 0.3245
Epoch 89/200
32/32 [==============================] - 4s 123ms/step - sparse_categorical_accuracy: 0.3313
Epoch 90/200
32/32 [==============================] - 4s 121ms/step - sparse_categorical_accuracy: 0.3327
Epoch 91/200
32/32 [==============================] - 4s 121ms/step - sparse_categorical_accuracy: 0.3360
Epoch 92/200
32/32 [==============================] - 4s 123ms/step - sparse_categorical_accuracy: 0.3240
Epoch 93/200
32/32 [==============================] - 4s 124ms/step - sparse_categorical_accuracy: 0.3400
Epoch 94/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3335
Epoch 95/200
32/32 [==============================] - 4s 118ms/step - sparse_categorical_accuracy: 0.3305
Epoch 96/200
32/32 [==============================] - 4s 123ms/step - sparse_categorical_accuracy: 0.3270
Epoch 97/200
32/32 [==============================] - 4s 118ms/step - sparse_categorical_accuracy: 0.3363
Epoch 98/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3330
Epoch 99/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3345
Epoch 100/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3338
Epoch 101/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3307
Epoch 102/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3300
Epoch 103/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3332
Epoch 104/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3268
Epoch 105/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3275
Epoch 106/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3280
Epoch 107/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3255
Epoch 108/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3365
Epoch 109/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3288
Epoch 110/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3320
Epoch 111/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3220
Epoch 112/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3358
Epoch 113/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3390
Epoch 114/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3360
Epoch 115/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.2083
Epoch 116/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0805
Epoch 117/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0432
Epoch 118/200
32/32 [==============================] - 4s 119ms/step - sparse_categorical_accuracy: 0.0422
Epoch 119/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0550
Epoch 120/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0693
Epoch 121/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.0780
Epoch 122/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0885
Epoch 123/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.0972
Epoch 124/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.0980
Epoch 125/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1130
Epoch 126/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1125
Epoch 127/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1353
Epoch 128/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.1458
Epoch 129/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1605
Epoch 130/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1622
Epoch 131/200
32/32 [==============================] - 4s 119ms/step - sparse_categorical_accuracy: 0.1400
Epoch 132/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1562
Epoch 133/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1618
Epoch 134/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2023
Epoch 135/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1902
Epoch 136/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.1815
Epoch 137/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2120
Epoch 138/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.1558
Epoch 139/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2070
Epoch 140/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2387
Epoch 141/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.2470
Epoch 142/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.2663
Epoch 143/200
32/32 [==============================] - 4s 118ms/step - sparse_categorical_accuracy: 0.2882
Epoch 144/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2898
Epoch 145/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.2848
Epoch 146/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3097
Epoch 147/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3075
Epoch 148/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3228
Epoch 149/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3050
Epoch 150/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3330
Epoch 151/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3250
Epoch 152/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3315
Epoch 153/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3388
Epoch 154/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3255
Epoch 155/200
32/32 [==============================] - 4s 117ms/step - sparse_categorical_accuracy: 0.3173
Epoch 156/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3422
Epoch 157/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3347
Epoch 158/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3385
Epoch 159/200
32/32 [==============================] - 4s 118ms/step - sparse_categorical_accuracy: 0.3445
Epoch 160/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3400
Epoch 161/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3395
Epoch 162/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3415
Epoch 163/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3338
Epoch 164/200
32/32 [==============================] - 4s 115ms/step - sparse_categorical_accuracy: 0.3397
Epoch 165/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3383
Epoch 166/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3395
Epoch 167/200
32/32 [==============================] - 4s 116ms/step - sparse_categorical_accuracy: 0.3355
Epoch 168/200
5/32 [===>..........................] - ETA: 3s - sparse_categorical_accuracy: 0.5516
Please help me identify why the results are so different!
Colab notebook: https://colab.research.google.com/drive/1hCZK1XO5l9LPEQEQKRE0uotMv_ap7xbK?authuser=2#scrollTo=lQZmuw5enmuO

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.
@arpit196 In order to expedite the trouble-shooting process, please provide a code snippet to reproduce the issue reported here. Thank you!
(trainX,trainY), (testX,testY) = tf.keras.datasets.cifar10.load_data()
import sklearn
#trainX, valX, trainY,valY = sklearn.model_selection.train_test_split(trainX,trainY,test_size=0.2)
trainX1 = []
trainY1 = []
trainuns = []
trainunsY = []
for c in range(0,10):
    indices = tf.where(trainY==c)
    subtrain = trainX[indices[:,0]]
    withlab = subtrain[0:400,:,:,:]
    print(subtrain[400:,:,:,:].shape)
    trainX1.append(np.array(withlab))
    trainY1.append(trainY[indices[:,0]][0:400])
    trainuns.append(np.array(subtrain[400:,:,:,:]))
    trainunsY.append(trainY[indices[:,0]][400:])
trainX1 = np.stack(trainX1,0)
print(trainX1.shape)
trainY1 = np.array(trainY1)
trainuns = np.concatenate(trainuns,axis=0)
trainunsY = np.array(trainunsY)
def VGG16():
    weight_decay = 0.0005
    inputs = layers.Input(shape=(32, 32, 3))
    #inputs = data_augmentation(inputs)
    x = data_augmentation(inputs)
    x = layers.GaussianNoise(stddev=0.15)(inputs)
    x1 = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    '''x = layers.BatchNormalization()(x1)
    x = layers.MaxPool2D(pool_size=(2,2),strides=(2,2))(x)'''
    #x = layers.Add()([x, layers.Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding="same", activation="relu")(x1)])
    x = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x1)
    x = layers.BatchNormalization()(x)
    x2 = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x2)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    x = layers.Conv2D(filters=512, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu",
                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    xlow = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
    #xlow_dense = layers.Dense(11, activation='softmax')(xlow)
    xi = layers.GlobalAveragePooling2D()(xlow)
    xi = layers.BatchNormalization()(xi)
    x1 = layers.Dense(128, activation='relu')(xi)
    x1 = layers.BatchNormalization()(x1)
    x = layers.Dense(10, activation='softmax')(x1)
    return tf.keras.Model(inputs=inputs, outputs=x)
trainX1 = tf.reshape(trainX1,(4000,32,32,3))
trainuns = tf.reshape(trainuns,(-1,32,32,3))
trainY1 = tf.reshape(trainY1,[4000,-1])

def _input_fn(X, Xuns, y):
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.batch(128, drop_remainder=False)
    dataset2 = tf.data.Dataset.from_tensor_slices((Xuns))
    dataset2 = dataset2.batch(128, drop_remainder=False)
    dataset = tf.data.Dataset.zip((dataset, dataset2))
    return dataset
class CustomModel(keras.Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.encoder = VGG16()
        self.accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    def compile(self, optimizer, loss, metrics):
        super().compile(optimizer)
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metrics

    def call(self, inputs):
        return self.encoder(inputs)

    def train_step(self, data):
        # each element of the zipped dataset is ((labeled images, labels), unlabeled images)
        (data1, label), data_uns = data
        with tf.GradientTape() as tape:
            y_pred = self(data1, training=True)
            # forward pass on an augmented view (not used in the loss below)
            y_pred_aug = self.encoder(data_augmentation(data1), training=True)
            loss_value = keras.losses.sparse_categorical_crossentropy(
                label, y_pred, from_logits=True)
        acc = self.accuracy(label, y_pred)
        grads = tape.gradient(loss_value, self.encoder.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.encoder.trainable_weights))
        return {m.name: m.result() for m in self.metrics}
model2 = CustomModel()
model2.compile(optimizer=tf.keras.optimizers.Adam(0.003),
               loss='sparse_categorical_crossentropy',
               metrics='sparse_categorical_accuracy')
model2.fit(_input_fn(trainX1, trainuns, trainY1), epochs=200)
@arpit196 Could you please share the notebook link or Colab gist so we can reproduce the issue correctly. I am getting a different error while replicating the above code. Thank you!
Dear @sushreebarsa, can you access this notebook: https://colab.research.google.com/drive/1BG7c7Cn-GST4KuKPtTW-YmTLlCSjf0Vg?authuser=1#scrollTo=zL0gBSujg-CZ? Note that normal training of the VGG16 model defined at the top does not lead to exploding gradients, but the CustomModelSemi one below gives a NaN loss.
@gadagashwini I was able to replicate the issue on Colab; please find the gist here. Thank you!
This is a fairly large training script, with quite a bit going on in the train_step. If this is a bug report, it would help to get a much more concise reproduction of the issue.
It looks like this issue may be fairly simple, though. In the custom train_step, you are computing the loss with keras.losses.sparse_categorical_crossentropy(label, y_pred, from_logits=True). But y_pred is produced by a softmax in the VGG16 model. This is incorrect: if your predictions come from a softmax or sigmoid, you need to pass from_logits=False, since they are probabilities, not logits.
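For example, here is a minimal sketch of the two possible fixes (apply one, not both; this assumes the rest of the train_step stays as-is):

# Option 1: keep the softmax output in VGG16 and treat y_pred as probabilities.
loss_value = keras.losses.sparse_categorical_crossentropy(
    label, y_pred, from_logits=False)

# Option 2: drop the softmax so the model outputs raw logits ...
# x = layers.Dense(10)(x1)  # no activation, so the outputs are logits
# ... and then keep from_logits=True in the loss call.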