Keras-Group-Normalization

Permute dimensions before K.reshape

Status: Open. xuannianc opened this issue 5 years ago (8 comments).

Hi, @titu1994. Thanks for your hard work. I have one question after reading your code. When using TF as the backend with axis=-1 and an input of shape [n, h, w, c], the current implementation reshapes the inputs to [n, g, h, w, c//g]. Don't we need to permute the input to [n, c, h, w] first, and then reshape it to [n, g, c//g, h, w]?

xuannianc, Aug 27 '19 06:08
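A quick NumPy check makes the question concrete. In the toy NHWC tensor below (sizes and variable names are made up for illustration; this is not the repository's code), every element stores its own channel index, so after each reshape we can list which channels land in each group.

```python
import numpy as np

n, h, w, c, g = 1, 2, 2, 4, 2  # toy sizes: g groups of c // g channels each

# NHWC tensor in which every element holds its channel index.
x = np.tile(np.arange(c), (n, h, w, 1))

# Reshape described in the question: [n, h, w, c] -> [n, g, h, w, c // g]
naive = x.reshape(n, g, h, w, c // g)

# Permute to [n, c, h, w] first, then reshape to [n, g, c // g, h, w]
permuted = np.transpose(x, (0, 3, 1, 2)).reshape(n, g, c // g, h, w)

for k in range(g):
    print(f"group {k}: naive {np.unique(naive[0, k]).tolist()}, "
          f"permuted {np.unique(permuted[0, k]).tolist()}")
# group 0: naive [0, 1, 2, 3], permuted [0, 1]
# group 1: naive [0, 1, 2, 3], permuted [2, 3]
```

With the naive reshape, each "group" is just a contiguous chunk of the flattened h*w*c block (here, one image row) and contains every channel, so its statistics are not per channel group.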

I'm following the original paper's TensorFlow implementation, which also used NHWC, so I don't think a permutation is needed first. If the original format is NCHW, then maybe we need to permute first.

titu1994, Aug 27 '19 12:08
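Whether a permutation is strictly required depends on how the reshape is written. NHWC data can be grouped without permuting, as long as the reshape splits the channel axis in place ([n, h, w, g, c // g]) instead of putting g at axis 1. A minimal NumPy sketch under that assumption (no learnable gamma/beta; `group_norm_nhwc` is a hypothetical name):

```python
import numpy as np

def group_norm_nhwc(x, groups, eps=1e-5):
    """Normalize NHWC input per (sample, group) without permuting axes."""
    n, h, w, c = x.shape
    # Split only the channel axis: [n, h, w, c] -> [n, h, w, g, c // g]
    xg = x.reshape(n, h, w, groups, c // groups)
    # Statistics over height, width and the channels inside each group
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 5, 5, 8))
y = group_norm_nhwc(x, groups=4)

# Each block of 2 consecutive channels should now have mean ~0 and variance ~1.
for k in range(4):
    block = y[..., 2 * k:2 * (k + 1)]
    print(k, block.mean(axis=(1, 2, 3)).round(6), block.var(axis=(1, 2, 3)).round(3))
```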

Hi, I debugged the group norm in your code and found that the output of the Keras group norm is not the same as that of PyTorch's GroupNorm. I guess the reshape operation causes this difference?

Yuxiang1990, Aug 28 '19 10:08
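One way to test that guess is to compare against a reference with the semantics of PyTorch's `torch.nn.GroupNorm` (NCHW layout, statistics per sample and group, biased variance, default eps=1e-5, affine weights left at 1 and 0). The NumPy sketch below assumes those semantics and contrasts them with the naive NHWC reshape from the original question; both function names are hypothetical.

```python
import numpy as np

def group_norm_reference(x_nchw, groups, eps=1e-5):
    """Per-(sample, group) normalization on NCHW input, biased variance."""
    n, c, h, w = x_nchw.shape
    xg = x_nchw.reshape(n, groups, c // groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

def group_norm_naive_nhwc(x_nhwc, groups, eps=1e-5):
    """The reshape from the issue: [n, h, w, c] -> [n, g, h, w, c // g]."""
    n, h, w, c = x_nhwc.shape
    xg = x_nhwc.reshape(n, groups, h, w, c // groups)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)

rng = np.random.default_rng(0)
# Give each channel a different scale so mixed groups are easy to spot.
x_nhwc = rng.normal(size=(2, 4, 4, 8)) * np.arange(1, 9)

x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))
ref = np.transpose(group_norm_reference(x_nchw, 4), (0, 2, 3, 1))  # back to NHWC
naive = group_norm_naive_nhwc(x_nhwc, 4)
print(np.abs(ref - naive).max())  # noticeably greater than 0
```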

Here is another TensorFlow implementation that follows the same process as the paper: https://github.com/shaohua0116/Group-Normalization-Tensorflow/blob/master/ops.py

titu1994, Aug 28 '19 12:08

Update: it seems you are right. For TF, the input needs to be transposed to NCHW first, then reshaped, and finally transposed back.

Thank you for catching it. I will update the code, but if you want credit, please send a PR to this repo.

titu1994, Aug 28 '19 12:08
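The fix described in the update (transpose to NCHW, reshape, normalize, reshape, transpose back) can be sketched in NumPy as follows. This is a hedged illustration, not the repository's eventual code: `gamma` and `beta` stand in for the layer's per-channel scale and offset, and the function name is hypothetical.

```python
import numpy as np

def group_norm_nhwc_via_transpose(x, gamma, beta, groups, eps=1e-5):
    """Group norm for NHWC input by temporarily switching to NCHW."""
    n, h, w, c = x.shape
    if c % groups != 0:
        raise ValueError("channels must be divisible by groups")
    # 1. NHWC -> NCHW
    x = np.transpose(x, (0, 3, 1, 2))
    # 2. [n, c, h, w] -> [n, g, c // g, h, w]
    x = x.reshape(n, groups, c // groups, h, w)
    # 3. Normalize over each group's channels and spatial positions
    mean = x.mean(axis=(2, 3, 4), keepdims=True)
    var = x.var(axis=(2, 3, 4), keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    # 4. Back to [n, c, h, w], then apply per-channel scale and offset
    x = x.reshape(n, c, h, w)
    x = x * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)
    # 5. NCHW -> NHWC
    return np.transpose(x, (0, 2, 3, 1))

rng = np.random.default_rng(1)
x = rng.normal(loc=-2.0, scale=5.0, size=(2, 6, 6, 16))
y = group_norm_nhwc_via_transpose(x, np.ones(16), np.zeros(16), groups=8)
print(y.shape)  # (2, 6, 6, 16)
```

The transposes are cheap to write but cost a memory copy in practice; splitting the channel axis in place avoids them while computing the same statistics.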

Hi, has the code been updated to address this issue yet?

Raazzta, Nov 11 '19 02:11

I completely forgot to make this change. I've mostly moved on from Keras to tf.keras, which already has Group Normalization in TF Addons.

If you'd like, could you submit a PR for this correction?

titu1994, Nov 11 '19 02:11

Hi. To be honest, I have no idea how to create a PR, sorry.

Raazzta, Nov 12 '19 04:11

Hi, I am trying to use the code in Keras. Would calling it look like this?

```python
x = GroupNormalization(groups=8, axis=-1)(x)
```

I used axis=-1 since Keras uses channels_last by default. Thank you.

rahulgomes19, Mar 06 '20 06:03
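As a rough check of what `axis=-1` means here: with channels_last input the channel axis is the last one, and `groups=8` only makes sense if the channel count is divisible by 8. The NumPy sketch below (a hypothetical `group_norm_axis` function, not the Keras layer itself) normalizes along an arbitrary channel axis by moving it to the end, and shows that channels_first input with `axis=1` gives the same result as channels_last input with `axis=-1`.

```python
import numpy as np

def group_norm_axis(x, groups, axis=-1, eps=1e-5):
    """Group norm along an arbitrary channel axis (no learnable weights)."""
    x = np.moveaxis(x, axis, -1)  # bring channels to the last position
    c = x.shape[-1]
    if c % groups != 0:
        raise ValueError(f"{c} channels cannot be split into {groups} groups")
    xg = x.reshape(x.shape[:-1] + (groups, c // groups))
    # Reduce over every axis except batch (0) and group (ndim - 2)
    red = tuple(range(1, xg.ndim - 2)) + (xg.ndim - 1,)
    mean = xg.mean(axis=red, keepdims=True)
    var = xg.var(axis=red, keepdims=True)
    y = ((xg - mean) / np.sqrt(var + eps)).reshape(x.shape)
    return np.moveaxis(y, -1, axis)  # restore the original layout

rng = np.random.default_rng(2)
x_last = rng.normal(size=(2, 7, 7, 32))         # channels_last
x_first = np.transpose(x_last, (0, 3, 1, 2))    # channels_first

y_last = group_norm_axis(x_last, groups=8, axis=-1)
y_first = group_norm_axis(x_first, groups=8, axis=1)
print(np.allclose(y_last, np.transpose(y_first, (0, 2, 3, 1))))  # True
```

With 32 channels and `groups=8`, each group holds 4 channels; a feature map with, say, 20 channels could not be split into 8 equal groups.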