hessian_penalty
Hessian Penalty decreasing rapidly
I plugged the hessian penalty into an autoencoder:
```python
h = encoder(E)
E_ = decoder(h)
L = Lambda(lambda x: tf.reduce_mean(tf.square(x[0] - x[1]), axis=(1, 2, 3)), name='mse')([E, E_])
f = lambda x: tf.reduce_mean(tf.abs(x), axis=(1, 2, 3))
hessian = Lambda(lambda x: hessian_penalty(decoder, x, reduction=f, return_separately=True)[0], name='hessian')(h)
```
```
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
encoder (Model)                 (None, 128)           2224768     net[7][2]
__________________________________________________________________________________________________
decoder (Model)                 (None, 218, 182, 2)   3462272     encoder[1][0]
__________________________________________________________________________________________________
mse (Lambda)                    (None,)               0           net[7][2]
                                                                  decoder[1][0]
__________________________________________________________________________________________________
hessian (Lambda)                (None,)               0           encoder[1][0]
```
And the hessian penalty is decreasing very, very rapidly: in the first epoch it has a value of about 4; by the second epoch it is around 1e-08.
Is this normal?
Thanks for the question. I have observed that it sometimes decreases quickly when training, but not to 1e-8. Is your MSE decreasing slowly? It may be that the hessian penalty is swamping MSE, and so you need to apply a loss weighting to balance the terms.
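To illustrate what a loss weighting does here, this is a toy sketch (the `lambda_hp` value and function names are made up for illustration, not a recommendation):

```python
# Hypothetical loss weighting: rescale the hessian penalty so it no longer
# swamps the reconstruction term in the combined objective.
def total_loss(mse, hp, lambda_hp=0.01):
    return mse + lambda_hp * hp

# With raw values like those reported above (small MSE, HP around 4 early in
# training), the unweighted sum is dominated by the penalty, while a small
# lambda_hp makes the two terms comparable.
print(total_loss(0.05, 4.0, lambda_hp=1.0))   # penalty dominates
print(total_loss(0.05, 4.0, lambda_hp=0.01))  # terms comparable
```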
This is what I get in terms of loss.
The blue is the MSE, the orange is the Hessian Penalty × 10^6.
It does make sense, I suppose, that it is very low at first, as at the beginning of training the features are random and therefore independent from each other? But it still stays rather low.
It does seem slightly odd to me that it drops so low so quickly. Does your plot include the value of the hessian penalty at initialization? Assuming it doesn't start out at 1e-8, my guess is that multiplying the hessian penalty by a weight of 0.1, 0.01, etc. might prevent this sort of behavior. It could also be a result of the type of initialization of your network/ the scale of the data you are training on (if weights are small then the hessian may be small as well). It may also depend on the depth of your decoder network. For example, a single-layer (i.e., linear) network will always have 0 hessian penalty, and empirically it seems that deeper networks commonly have larger hessian penalties. The main indicator I would watch out for is the MSE loss. If the reconstructions appear accurate and you have a small MSE then my intuition is that having a small hessian penalty should tend to be a good sign.
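The linear-network point can be checked concretely with a minimal NumPy sketch of the paper's finite-difference estimator (not the repo's TF implementation; `hp`, `linear`, and `nonlin` are made-up names). A purely linear decoder has exactly zero second differences along every direction, so the penalty vanishes, while a nonlinearity makes it positive:

```python
import numpy as np

rng = np.random.default_rng(0)

def hp(G, z, k=16, eps=0.1):
    # Variance, over k Rademacher directions, of the central second
    # difference of G -- reduced over outputs with max, as in the paper.
    diffs = []
    for _ in range(k):
        v = eps * rng.choice([-1.0, 1.0], size=z.shape)
        diffs.append((G(z + v) - 2 * G(z) + G(z - v)) / eps**2)
    return np.var(np.stack(diffs), axis=0).max()

W1 = rng.standard_normal((32, 32))
W2 = rng.standard_normal((32, 8))
linear = lambda z: (z @ W1) @ W2          # two matmuls, but still linear
nonlin = lambda z: np.tanh(z @ W1) @ W2   # one nonlinearity in between

z = rng.standard_normal(32)
print(hp(linear, z))  # ~0 (up to floating-point error)
print(hp(nonlin, z))  # strictly positive
```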
Also, I would recommend against using `tf.abs` in the reduction function (sorry for any confusion about this; it was only intended to be used as a test case). This may not be related to the low loss problem, but it might cause other issues. You may also find it beneficial to use the reduction we used in the paper (`tf.reduce_max`, but you can also try `tf.reduce_mean` without the abs). You can replace lines 4 and 5 in your code snippet with the following:
```python
f = lambda x: tf.reduce_max(x, axis=(1, 2, 3))
hessian = Lambda(lambda x: hessian_penalty(decoder, x, reduction=f), name='hessian')(h)
```
OK, by using `reduce_max` the result does get better, but it's still around 0.01 in magnitude.
I noticed that if I run the penalty on a trained autoencoder, the penalty starts at a higher value (~100) and decreases slowly, but if I start from random weights, it drops in less than one epoch and starts increasing right after.
Sorry for the delay in responding. I think the behavior you're describing is more typical of what I've seen applying it to ProGAN, and 0.01 magnitude can be reasonable depending on the depth of your decoder network. We found in our paper that resuming training of a pre-trained model with the HP gives good results and sometimes works better than training from scratch with the HP (possibly because an over-aggressive HP loss weighting when training from scratch can cause your model to get stuck in a local optimum early in training). Ultimately, I think the main thing to keep an eye on in your case is the reconstruction loss. If the HP drops too low, it's possible that it is doing something degenerate, like simply scaling down the output of the decoder. Such a degenerate solution should yield a worse reconstruction loss. But if your reconstruction loss is low (similar to what you get when you train without the HP), I think that should be a good sign that it is not doing anything degenerate.
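The scaling-down degeneracy is easy to check directly: the second differences the penalty is built from scale linearly with the decoder's output, so the variance-based penalty shrinks quadratically when the output is scaled down (a quick NumPy sketch with made-up names, assuming the finite-difference estimator from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def second_diff(G, z, v, eps=0.1):
    # Central second difference of G along direction v; the hessian penalty
    # is a reduction of the variance of this quantity over random v.
    return (G(z + eps * v) - 2 * G(z) + G(z - eps * v)) / eps**2

W = rng.standard_normal((16, 4))
G = lambda z: np.tanh(z @ W)    # toy nonlinear "decoder"
G_half = lambda z: 0.5 * G(z)   # same decoder, output scaled down

z = rng.standard_normal(16)
v = rng.choice([-1.0, 1.0], size=16)

d = second_diff(G, z, v)
d_half = second_diff(G_half, z, v)
# Halving the output halves every second difference, so the variance-based
# penalty drops 4x -- while reconstructions (and the MSE) get worse.
print(np.allclose(d_half, 0.5 * d))  # True
```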