Batch norm initialization of variance moving average in ResNet models
I believe you have a small bug when initializing the ResNet models. Suppose I want to compute a gradient with `is_training=False` and `test_local_stats=False`. It may happen that the moving average for the variance is zero, breaking the `jax.lax.rsqrt` call and producing huge floating point numbers. Initializing the moving average for the variance to one (which seems reasonable to me) fixes the issue on my side. I'm attaching a small reproducer at the end of this message (I'm using dm-haiku 0.0.6 from pip).
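
To make the failure mode concrete, here is a minimal sketch in plain JAX (independent of the ResNet code) of what happens when the variance moving average is exactly zero and only the epsilon, here a typical `1e-5`, keeps `rsqrt` finite:

```python
import jax
import jax.numpy as jnp

eps = 1e-5
var = jnp.zeros(())  # variance moving average stuck at zero

def inv_stddev(v):
    # The inverse-stddev factor batch norm multiplies activations by.
    return jax.lax.rsqrt(v + eps)

print(inv_stddev(var))            # ~3.2e2: already large
print(jax.grad(inv_stddev)(var))  # ~-1.6e7: gradients blow up from here
```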
I cannot give you a patch since I'm not sure how you would like to fix this issue, given that you access the current average value via a property method: https://github.com/deepmind/dm-haiku/blob/211b6ab7704784f1507a2575bea57798371eb5ce/haiku/_src/batch_norm.py#L182

My dummy fix is to add a bogus call to `hidden = hk.get_state("average", value.shape, value.dtype, init=self.init_avg)` here: https://github.com/deepmind/dm-haiku/blob/211b6ab7704784f1507a2575bea57798371eb5ce/haiku/_src/moving_averages.py#L121, where `init_avg` is an argument passed to the `ExponentialMovingAverage` constructor.
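
In case it is useful, here is a minimal sketch of the same idea done outside the library, as a subclass; `EMAWithInit` and `init_avg` are names I made up, not part of the Haiku API:

```python
import jax.numpy as jnp
import haiku as hk

class EMAWithInit(hk.ExponentialMovingAverage):
    """EMA whose "average" state can start at a custom value (e.g. ones)."""

    def __init__(self, decay, init_avg=jnp.ones, name=None):
        super().__init__(decay, name=name)
        self._init_avg = init_avg

    def __call__(self, value, update_stats=True):
        # Create the "average" state with our initializer before the base
        # class writes to it. Haiku remembers the *initial* value of each
        # state entry, which is what transform_with_state's init returns,
        # so the variance average comes out as ones instead of zeros.
        hk.get_state("average", value.shape, value.dtype, init=self._init_avg)
        return super().__call__(value, update_stats=update_stats)
```

Note that `hk.nets.ResNet`'s BatchNorm constructs its own EMAs internally, so plugging this in still requires touching the library; the sketch just shows the pattern at the EMA level.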
```python
import jax
import numpy as np
import jax.numpy as jnp
import haiku as hk

key = jax.random.PRNGKey(1)


def forward(images, is_training: bool):
    net = hk.nets.ResNet18(num_classes=10, bn_config={'decay_rate': 0.9})
    return net(images, is_training=is_training, test_local_stats=False)


model = hk.transform_with_state(forward)
sample_input = jnp.ones((1, 3, 32, 32))
image = jnp.array(np.random.rand(1, 3, 32, 32))
params, state = model.init(key, sample_input, is_training=True)


def test_fn(p):
    logits, state_new = model.apply(p, state, key, image, is_training=False)
    loss = logits.mean()
    return loss, state_new


(val, state), grads = jax.value_and_grad(test_fn, has_aux=True)(params)

# print('-----------params')
# print(params)
# print('-----------val')
# print(val)
# print('-------------state')
# print(state)
# If you uncomment the grads printing below, you should get something like:
# ...
# 'res_net18/~/logits': {'b': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
#  'w': DeviceArray([[2.0988107e+12, 2.0988107e+12, 2.0988107e+12, ...,
#                     2.0988107e+12, 2.0988107e+12, 2.0988107e+12],
#                    [5.7180153e+11, 5.7180153e+11, 5.7180153e+11, ...,
#                     5.7180153e+11, 5.7180153e+11, 5.7180153e+11],
#                    [9.7583841e+10, 9.7583841e+10, 9.7583841e+10, ...,
#                     9.7583841e+10, 9.7583841e+10, 9.7583841e+10],
#                    ...,
#                    [6.5526674e+11, 6.5526674e+11, 6.5526674e+11, ...,
#                     6.5526674e+11, 6.5526674e+11, 6.5526674e+11],
#                    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
#                     0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
#                    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
#                     0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], dtype=float32)}}
# print('--------------grads')
# print(grads)
```
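
As a quick sanity check, run right after `model.init` (and assuming the variance EMA modules are named `var_ema`, as they appear to be in this Haiku version), the following lists every variance moving average that comes out of init exactly zero; those are the entries that feed `jax.lax.rsqrt` in the `is_training=False`, `test_local_stats=False` path:

```python
# Continues the reproducer above; `state` is the dict returned by model.init.
for mod_name, mod_state in state.items():
    if 'var_ema' in mod_name:
        avg = mod_state['average']
        if bool(jnp.all(avg == 0)):
            print('all-zero variance average:', mod_name)
```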
I am experiencing this problem too. My use case is reinforcement learning, where the network has to be used to generate training data, and the policy raises errors because the values blow up.