CancelOut

Question about Keras version

Open · nhattan214 opened this issue 2 years ago · 2 comments

Hi Dr. Vadim,

In your Keras example, the CancelOut loss does not contain the variance term:

```python
import tensorflow as tf

# Excerpt: call() method of the CancelOut Keras layer
def call(self, inputs):
    if self.cancelout_loss:
        self.add_loss(
            self.lambda_1 * tf.norm(self.w, ord=1)
            + self.lambda_2 * tf.norm(self.w, ord=2))
    return tf.math.multiply(inputs, self.activation(self.w))
```

May I ask why we need to sum up the L1-norm and L2-norm of the weight here? Also, in your PyTorch notebook, it is mentioned that the variance term is optional. So can we set that term to 0?
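
For reference, the returned value of the `call()` method above is the input scaled feature-by-feature by `activation(w)`. A minimal pure-Python sketch of that forward step, assuming the activation is a sigmoid (an assumption; check the layer's `activation` argument) and using hypothetical helper names `sigmoid` and `cancelout_forward`:

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function, assumed here to be the CancelOut activation."""
    return 1.0 / (1.0 + math.exp(-z))

def cancelout_forward(inputs: list[list[float]], w: list[float]) -> list[list[float]]:
    """Hypothetical helper: scale each feature column by sigmoid(w_j)."""
    gates = [sigmoid(wj) for wj in w]
    # Element-wise product broadcast over the batch, like tf.math.multiply(inputs, gates)
    return [[x * g for x, g in zip(row, gates)] for row in inputs]

# Two samples, three features; a large negative weight almost cancels feature 3.
batch = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
weights = [4.0, 0.0, -4.0]
print(cancelout_forward(batch, weights))
```

A weight of 0 gives a gate of 0.5, so the middle feature is halved rather than removed; only strongly negative weights push a feature toward zero.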

Regards, Tan

nhattan214 · Jun 21 '23 13:06

Hi Tan,

Thank you for opening the issue and your interest in our work.

> In your Keras example, the CancelOut loss does not contain the variance term

After publishing the paper, I reconsidered this term in the loss function. I do not think it is necessary, but it might help.

> May I ask why we need to sum up the L1-norm and L2-norm of the weight here?

This is a standard approach: the loss has to be a single number, so the two norms are added together into one penalty.
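
To make the "one number" point concrete: `tf.norm(w, ord=1)` and `tf.norm(w, ord=2)` each reduce the weight vector to a scalar, and their weighted sum is the single value passed to `add_loss()`. A pure-Python sketch of that arithmetic, with a hypothetical helper name `cancelout_penalty` and illustrative λ values (not recommended settings):

```python
import math

def cancelout_penalty(w, lambda_1, lambda_2):
    """Hypothetical helper: scalar lambda_1 * ||w||_1 + lambda_2 * ||w||_2."""
    l1 = sum(abs(x) for x in w)            # like tf.norm(w, ord=1)
    l2 = math.sqrt(sum(x * x for x in w))  # like tf.norm(w, ord=2) on a 1-D tensor
    return lambda_1 * l1 + lambda_2 * l2

w = [3.0, -4.0]
print(cancelout_penalty(w, lambda_1=0.01, lambda_2=0.1))
```

Here ||w||_1 = 7 and ||w||_2 = 5, so the penalty is 0.01 * 7 + 0.1 * 5, roughly 0.57.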

> Also, in your PyTorch notebook, it is mentioned that the variance term is optional. So can we set that term to 0?

It is a hyperparameter; you have to find the best value for your dataset and model.
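
As a sketch of what setting the variance weight to 0 does, assuming the PyTorch form quoted later in this thread (`criterion(outputs, labels) + lambda_1 * l1_norm - lambda_2 * var`): with `lambda_2 = 0` the variance term vanishes and the objective is just the task loss plus the L1 penalty. The helper name `cancelout_objective`, the task-loss value, and the use of the sample variance (which is what `torch.var` computes by default) are assumptions here:

```python
import statistics

def cancelout_objective(task_loss, w, lambda_1, lambda_2):
    """Hypothetical helper: task_loss + lambda_1 * ||w||_1 - lambda_2 * var(w)."""
    l1_norm = sum(abs(x) for x in w)
    var = statistics.variance(w)  # sample variance (n - 1 denominator)
    return task_loss + lambda_1 * l1_norm - lambda_2 * var

w = [0.5, -1.0, 2.0]
task_loss = 0.3  # hypothetical cross-entropy value for one batch
print(cancelout_objective(task_loss, w, lambda_1=0.01, lambda_2=0.0))   # variance term off
print(cancelout_objective(task_loss, w, lambda_1=0.01, lambda_2=0.05))  # variance term on
```

Because the variance is subtracted, a larger `lambda_2` rewards spreading the gate weights apart; at 0 it has no effect, so it can be tuned like any other hyperparameter.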

Hope this helps!

unnir · Jul 04 '23 14:07

Hi Dr. Vadim,

Thank you very much for your replies. I have some other questions:

  1. From the PyTorch code `loss = criterion(outputs, labels) + lambda_1 * l1_norm - lambda_2 * var`, I can clearly see that this corresponds to Equation 10 in the paper. However, the Keras version has the following:

     ```python
     def call(self, inputs):
         if self.cancelout_loss:
             self.add_loss(
                 self.lambda_1 * tf.norm(self.w, ord=1)
                 + self.lambda_2 * tf.norm(self.w, ord=2))
         return tf.math.multiply(inputs, self.activation(self.w))
     ```

     Where is the entropy loss in this case?
  2. Also, in the Keras version, the methods `add_weight()` and `add_loss()` are not defined inside the CancelOut class, but I can still run the code properly. I am a little confused about this; I hope you can help me clarify it as well.
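
Both questions come down to how Keras layers work. `add_weight()` and `add_loss()` are methods of the base class `tf.keras.layers.Layer`; assuming the `CancelOut` class subclasses it (which the working calls suggest), it inherits them without defining them. Losses registered with `add_loss()` are collected in `model.losses` and added during `fit()` to the loss passed to `model.compile()`, so the cross-entropy would come from `compile()`, not from the layer. The pure-Python toy below only mimics that pattern so it runs without TensorFlow; `ToyLayer`, `ToyCancelOut`, and `cross_entropy` are hypothetical names, not Keras APIs, and the sigmoid gate is an assumption.

```python
import math

class ToyLayer:
    """Stand-in for a framework base class (NOT the real tf.keras.layers.Layer)."""

    def __init__(self):
        self.losses = []

    def __call__(self, inputs):
        self.losses = []  # start each forward pass with no collected penalties
        return self.call(inputs)

    def add_weight(self, size, value=1.0):
        # Defined once on the base class; every subclass inherits it.
        return [value] * size

    def add_loss(self, value):
        # Extra penalty terms are collected here instead of being returned.
        self.losses.append(value)


class ToyCancelOut(ToyLayer):
    """Hypothetical subclass: uses add_weight/add_loss without defining them."""

    def __init__(self, n_features, lambda_1=0.01, lambda_2=0.01):
        super().__init__()
        self.w = self.add_weight(n_features, value=1.0)
        self.lambda_1, self.lambda_2 = lambda_1, lambda_2

    def call(self, inputs):
        l1 = sum(abs(x) for x in self.w)
        l2 = math.sqrt(sum(x * x for x in self.w))
        self.add_loss(self.lambda_1 * l1 + self.lambda_2 * l2)
        gates = [1.0 / (1.0 + math.exp(-x)) for x in self.w]  # assumed sigmoid gate
        return [x * g for x, g in zip(inputs, gates)]


def cross_entropy(probs, label):
    """Task loss for one sample, -log p(label); plays the role of the compiled loss."""
    return -math.log(probs[label])


layer = ToyCancelOut(n_features=4)
out = layer([0.2, 0.4, 0.6, 0.8])
task = cross_entropy([0.1, 0.7, 0.2], label=1)
total = task + sum(layer.losses)  # the quantity a training loop would minimize
print(task, layer.losses, total)
```

With all four weights at 1.0, the penalty is 0.01 * 4 + 0.01 * 2 = 0.06, which is added on top of the cross-entropy. Note that this Keras-style penalty uses an L2 norm where the PyTorch loss subtracts a variance, so the two versions are not identical objectives.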

Thank you, and I hope to hear from you soon.

nhattan214 · Jul 05 '23 09:07