
Hessian Penalty decreasing rapidly

Open HansLeonardVanBrueggemann opened this issue 2 years ago • 6 comments

I plugged the hessian penalty into an autoencoder:

import tensorflow as tf
from tensorflow.keras.layers import Lambda
from hessian_penalty_tf import hessian_penalty  # assuming the TF implementation module from the repo

h = encoder(E)   # latent code
E_ = decoder(h)  # reconstruction
L = Lambda(lambda x: tf.reduce_mean(tf.square(x[0] - x[1]), axis=(1, 2, 3)), name='mse')([E, E_])
f = lambda x: tf.reduce_mean(tf.abs(x), axis=(1, 2, 3))
hessian = Lambda(lambda x: hessian_penalty(decoder, x, reduction=f, return_separately=True)[0], name='hessian')(h)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
encoder (Model)                 (None, 128)          2224768     net[7][2]
__________________________________________________________________________________________________
decoder (Model)                 (None, 218, 182, 2)  3462272     encoder[1][0]
__________________________________________________________________________________________________
mse (Lambda)                    (None,)              0           net[7][2]
                                                                 decoder[1][0]
__________________________________________________________________________________________________
hessian (Lambda)                (None,)              0           encoder[1][0]

And the hessian penalty is decreasing very, very rapidly: in the first epoch it has a value of about 4, and by the second epoch it is down to about 1e-08.

Is this normal?

Thanks for the question. I have observed that it sometimes decreases quickly when training, but not to 1e-8. Is your MSE decreasing slowly? It may be that the hessian penalty is swamping MSE, and so you need to apply a loss weighting to balance the terms.
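
For a concrete example, here is a minimal sketch of one way to weight the two terms in Keras, reusing the E, E_, L and hessian tensors from your snippet; the 0.01 weight and the Adam optimizer are just placeholders to tune, not recommendations:

from tensorflow.keras.models import Model
import tensorflow as tf

hp_weight = 0.01                                   # down-weight the hessian penalty
autoencoder = Model(inputs=E, outputs=E_)
# add_loss combines the per-sample losses already computed in the graph into one weighted objective
autoencoder.add_loss(tf.reduce_mean(L) + hp_weight * tf.reduce_mean(hessian))
autoencoder.compile(optimizer='adam')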

wpeebles · Aug 31 '21

This is what I get in terms of loss [loss plot attached]: the blue curve is the MSE, the orange is the Hessian Penalty x 10^6.

It does make sense, I suppose, that it's very low at first, as at the beginning of training the features are random and therefore independent of each other? But it still stays rather low.

It does seem slightly odd to me that it drops so low so quickly. Does your plot include the value of the hessian penalty at initialization? Assuming it doesn't start out at 1e-8, my guess is that multiplying the hessian penalty by a weight of 0.1, 0.01, etc. might prevent this sort of behavior. It could also be a result of your network's initialization or the scale of the data you are training on (if the weights are small then the hessian may be small as well). It may also depend on the depth of your decoder network. For example, a single-layer (i.e., linear) network will always have a hessian penalty of 0, and empirically it seems that deeper networks commonly have larger hessian penalties. The main indicator I would watch is the MSE loss: if the reconstructions appear accurate and you have a small MSE, then my intuition is that a small hessian penalty should tend to be a good sign.
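
To illustrate the linear-network point, here is a rough NumPy sketch of the estimator (variance of second-order central differences along random Rademacher directions, roughly following the paper); it is not the repo implementation, and the toy decoders below are just for illustration:

import numpy as np

def hp_sketch(G, z, k=8, epsilon=0.1, reduction=np.max):
    # k Rademacher directions with the same shape as z, scaled by epsilon
    vs = epsilon * (2 * np.random.randint(0, 2, size=(k,) + z.shape) - 1)
    # second-order central differences approximate the curvature of G along each direction
    diffs = np.stack([(G(z + v) - 2 * G(z) + G(z - v)) / epsilon ** 2 for v in vs])
    # unbiased variance across directions, then reduce over output components
    return reduction(np.var(diffs, axis=0, ddof=1))

np.random.seed(0)
z = np.random.randn(8)
W1, W2 = np.random.randn(16, 8), np.random.randn(4, 16)

linear_dec = lambda x: W2 @ (W1 @ x)          # purely linear decoder: zero curvature
nonlin_dec = lambda x: W2 @ np.tanh(W1 @ x)   # nonlinearity introduces curvature

print(hp_sketch(linear_dec, z))   # ~0, up to floating-point error
print(hp_sketch(nonlin_dec, z))   # strictly positive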

wpeebles · Sep 02 '21

Also, I would recommend against using tf.abs in the reduction function (sorry for any confusion about this; it was only intended to be used as a test case). This may not be related to the low-loss problem, but it might cause other issues. You may also find it beneficial to use the reduction we used in the paper (tf.reduce_max, though you can also try tf.reduce_mean without the abs). You can replace the last two lines of your snippet (the reduction f and the hessian Lambda) with the following:

f = lambda x: tf.reduce_max(x, axis=(1, 2, 3))
hessian = Lambda(lambda x: hessian_penalty(decoder, x, reduction=f), name='hessian')(h)

wpeebles · Sep 02 '21

OK, by using reduce_max the result does indeed get better. But it's still around 0.01 in magnitude.

I noticed that if I apply the penalty to an already-trained autoencoder, it starts at a higher value (~100) and decreases slowly, but if I start from random weights, it drops within less than one epoch and starts increasing right after.

Sorry for the delay in responding. I think the behavior you're describing is more typical of what I've seen applying it to ProGAN. A magnitude of 0.01 can be reasonable depending on the depth of your decoder network. We found in our paper that resuming training of a pre-trained model with the HP gives good results and sometimes works better than training from scratch with the HP (possibly because an over-aggressive HP loss weighting when training from scratch can cause your model to get stuck in a local optimum early in training). Ultimately, I think the main thing to keep an eye on in your case is the reconstruction loss. If the HP drops too low, it's possible that it is doing something degenerate, like simply scaling down the output of the decoder. Such a degenerate solution should yield a worse reconstruction loss. But if your reconstruction loss is low (similar to what you get when you train without the HP), I think that is a good sign that it is not doing anything degenerate.
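
One hypothetical way to watch for that degenerate solution (the class and argument names here are placeholders, not part of the repo): since the penalty is built from differences of the decoder's outputs, scaling the output down by a factor alpha shrinks it by roughly alpha^2, so it is worth monitoring the spread of the decoder outputs alongside the reconstruction loss, e.g. with a small callback:

import tensorflow as tf

class OutputScaleMonitor(tf.keras.callbacks.Callback):
    """Logs the spread of decoder outputs on fixed latents; a collapsing std
    alongside a falling hessian penalty suggests the degenerate scaling solution."""
    def __init__(self, decoder, z_val):
        super().__init__()
        self.decoder = decoder
        self.z_val = z_val   # a fixed batch of validation latent codes

    def on_epoch_end(self, epoch, logs=None):
        out = self.decoder(self.z_val, training=False)
        tf.print('epoch', epoch, 'decoder output std:', tf.math.reduce_std(out))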

wpeebles · Sep 16 '21