
Batch normalization

Open • fischersci opened this issue 7 years ago • 7 comments

Is there a reason why no batch normalization is applied? I have read that batch normalization can help improve convergence during training.

fischersci · Nov 09 '17 09:11

The U-Net is often trained with a batch size of one. However, it would be relatively easy to extend the data providers to normalize the input data.
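
A minimal sketch of that idea, assuming tf_unet's ImageDataProvider and its _process_data hook (the exact names may differ between versions):

import numpy as np
from tf_unet.image_util import ImageDataProvider

class StandardizingDataProvider(ImageDataProvider):
    # hypothetical subclass: replaces the default min-max scaling
    def _process_data(self, data):
        # standardize each image to zero mean / unit variance
        data = data.astype(np.float32)
        data -= np.mean(data)
        data /= (np.std(data) + 1e-8)  # epsilon guards against constant images
        return data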

jakeret · Nov 14 '17 14:11

Thanks for your answer. Yes, I know, but in my case the results are much better with a batch size of 32. Batch normalization is often applied after the convolution and before the ReLU. I know TensorFlow provides an op for that (tf.nn.batch_normalization), but I am not sure how to implement it.

fischersci · Nov 15 '17 08:11

I see. I had a quick look at the BN; the data provider is obviously the wrong place for it.

Couldn't you adapt the layers.conv2d method to call tf.nn.moments and then tf.nn.batch_normalization?
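
A minimal sketch of that suggestion (TF 1.x; the name conv2d_bn, the skipped learned offset/scale, and the epsilon value are illustrative choices, and a full implementation would also need moving averages for inference):

import tensorflow as tf

def conv2d_bn(x, W, keep_prob_):
    conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
    # per-channel statistics over batch, height and width
    mean, variance = tf.nn.moments(conv_2d, axes=[0, 1, 2])
    # normalize with the batch statistics (no learned offset/scale here)
    normalized = tf.nn.batch_normalization(conv_2d, mean, variance,
                                           offset=None, scale=None,
                                           variance_epsilon=1e-5)
    return tf.nn.dropout(normalized, keep_prob_)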

jakeret · Nov 15 '17 08:11

@fischersci Did you get any performance improvement after batch normalization?

myway0101 · Jan 03 '18 12:01

I am not sure if I did it correctly, but the training results got really bad when I used it.

fischersci · Jan 14 '18 10:01

Would you share your code/implementation of the batch normalization? I think it would be really helpful for getting started...

fschi · Jan 15 '18 09:01

I know it's bad style, but here is a version which seems to work and to improve the results. When using only TensorFlow I got some errors I didn't understand, so I ended up with this Keras/TensorFlow mixture in the layers.conv2d method...

import tensorflow as tf
from keras.layers import BatchNormalization
from keras import backend as K
...

def conv2d(x, W, keep_prob_):
    conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')

    # learning phase 0 = inference mode: the Keras BN layers use their
    # moving statistics rather than the current batch statistics
    K.set_learning_phase(0)
    # batch norm after the convolution, dropout, then batch norm again
    conv_2d_withBN = BatchNormalization()(conv_2d)
    conv_2d_withBNDropout = tf.nn.dropout(conv_2d_withBN, keep_prob_)
    conv_2d_withBNDropoutBN = BatchNormalization()(conv_2d_withBNDropout)

    return conv_2d_withBNDropoutBN

Comments are welcome!

fschi · Jan 18 '18 16:01
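
A side note for anyone who wants to stay pure TensorFlow: tf.layers.batch_normalization (TF 1.x) maintains the moving statistics itself, and a common source of confusing errors with it is forgetting the UPDATE_OPS dependency. A minimal sketch; the is_training flag and the optimizer/loss names are placeholders for the surrounding training setup:

import tensorflow as tf

def conv2d_bn(x, W, keep_prob_, is_training):
    conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
    # `training` switches between batch statistics (train)
    # and the maintained moving statistics (test)
    bn = tf.layers.batch_normalization(conv_2d, training=is_training)
    return tf.nn.dropout(bn, keep_prob_)

# the moving-average update ops must run together with the train step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)  # optimizer/loss: existing setup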