dm-haiku

batch norm initialization of variance moving average in resnet models

Open • stefanozampini opened this issue 2 years ago • 1 comment

I believe there is a small bug in how the ResNet models are initialized. Suppose I want to compute a gradient with is_training=False and test_local_stats=False. It may happen that the moving average of the variance is zero. The jax.lax.rsqrt call in the normalization then effectively sees only eps, and the result is huge floating point numbers. Initializing the moving average of the variance to one (which seems reasonable to me) fixes the issue for me. A small reproducer is attached at the end of this message (I'm using dm-haiku 0.0.6 installed with pip).
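To illustrate why a zero variance average is so harmful, here is a minimal sketch. It assumes that BatchNorm normalizes with `rsqrt(var + eps)` and that `eps = 1e-5` (our understanding of the default ResNet `bn_config`, not verified against 0.0.6). It also shows one way a per-channel variance can be exactly zero: the statistics cover a single value per channel. The helper `inv_std` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

# Assumption: eps = 1e-5, believed to be the default eps in the ResNet bn_config.
EPS = 1e-5


def inv_std(var, eps=EPS):
    """Hypothetical helper: per-channel normalization factor 1 / sqrt(var + eps)."""
    return lax.rsqrt(var + eps)


# A feature map with a single value per channel (batch 1, 1x1 spatial, 4 channels):
# its variance over the (N, H, W) axes is exactly zero.
x = jnp.array([[[[0.3, -1.2, 2.0, 0.7]]]], dtype=jnp.float32)  # shape (1, 1, 1, 4)
batch_var = jnp.var(x, axis=(0, 1, 2))

zero_factor = inv_std(batch_var)                    # about 316 per channel
unit_factor = inv_std(jnp.ones((4,), jnp.float32))  # about 1 per channel

# Ignoring scale, offset and mean subtraction, each normalized layer multiplies
# activations (and gradients) by this factor, so the blow-up can compound with depth.
print(zero_factor, unit_factor)
print(float(zero_factor[0]) ** 4)  # four such layers in a row: roughly 1e10
```

With a variance average of one, the factor stays close to 1, which is why initializing the average to one avoids the blow-up in this simplified picture.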

I can't give you a patch because I'm not sure how you would prefer to fix this: the current average value is accessed through a property method, https://github.com/deepmind/dm-haiku/blob/211b6ab7704784f1507a2575bea57798371eb5ce/haiku/_src/batch_norm.py#L182

My quick workaround is to add an extra call, hidden = hk.get_state("average", value.shape, value.dtype, init=self.init_avg), at https://github.com/deepmind/dm-haiku/blob/211b6ab7704784f1507a2575bea57798371eb5ce/haiku/_src/moving_averages.py#L121, where init_avg is a new argument passed to the EMA constructor.
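If patching Haiku is not an option, a user-side workaround in the same spirit is to overwrite the stored variance average after `model.init`. This is a minimal sketch under three assumptions: the state returned by `transform_with_state` is a mapping keyed by module name, each `ExponentialMovingAverage` stores its current value under `"average"`, and BatchNorm names its variance EMA `var_ema`. The helper `reset_var_ema_average` is hypothetical, and the module names in the toy state are illustrative only.

```python
from typing import Any, Dict, Mapping

import jax.numpy as jnp

State = Mapping[str, Mapping[str, Any]]


def reset_var_ema_average(state: State, suffix: str = "var_ema") -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: copy a Haiku-style state mapping, replacing the
    "average" entry of every module whose name ends with `suffix` by ones."""
    new_state = {}
    for module_name, entries in state.items():
        entries = dict(entries)  # shallow copy: the input state is left untouched
        if module_name.endswith(suffix) and "average" in entries:
            entries["average"] = jnp.ones_like(entries["average"])
        new_state[module_name] = entries
    return new_state


# Toy state with illustrative module names (not taken from a real ResNet18 run).
toy_state = {
    "block/~/batchnorm/~/mean_ema": {"average": jnp.zeros((1, 1, 1, 4))},
    "block/~/batchnorm/~/var_ema": {
        "average": jnp.zeros((1, 1, 1, 4)),
        "counter": jnp.array(1, jnp.int32),
    },
}
fixed = reset_var_ema_average(toy_state)
print(fixed["block/~/batchnorm/~/var_ema"]["average"].ravel())   # all ones
print(fixed["block/~/batchnorm/~/mean_ema"]["average"].ravel())  # unchanged zeros
```

In the reproducer this would be applied as `state = reset_var_ema_average(state)` right after `model.init(...)`. As far as we can tell from the linked moving_averages.py, a later training call recomputes `"average"` from `"hidden"`, so this override only matters for evaluation before any training update.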

```python
import jax
import numpy as np
import jax.numpy as jnp

import haiku as hk

key = jax.random.PRNGKey(1)

def forward(images, is_training: bool):
    net = hk.nets.ResNet18(num_classes=10, bn_config={'decay_rate': 0.9})
    return net(images, is_training=is_training, test_local_stats=False)


model = hk.transform_with_state(forward)
sample_input = jnp.ones((1, 3, 32, 32))
image = jnp.array(np.random.rand(1, 3, 32, 32))
params, state = model.init(key, sample_input, is_training=True)

def test_fn(p):
    logits, state_new = model.apply(
        p, state, key, image, is_training=False)
    loss = logits.mean()
    return loss, state_new

(val, state), grads = jax.value_and_grad(test_fn, has_aux=True)(params)


#print('-----------params')
#print(params)
#print('-----------val')
#print(val)
#print('-------------state')
#print(state)

# If you uncomment the grads printing below, you should get something like:
#      ...
#      'res_net18/~/logits': {'b': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32), 'w': DeviceArray([[2.0988107e+12, 2.0988107e+12, 2.0988107e+12, ...,
#                    2.0988107e+12, 2.0988107e+12, 2.0988107e+12],
#                   [5.7180153e+11, 5.7180153e+11, 5.7180153e+11, ...,
#                    5.7180153e+11, 5.7180153e+11, 5.7180153e+11],
#                   [9.7583841e+10, 9.7583841e+10, 9.7583841e+10, ...,
#                    9.7583841e+10, 9.7583841e+10, 9.7583841e+10],
#                   ...,
#                   [6.5526674e+11, 6.5526674e+11, 6.5526674e+11, ...,
#                    6.5526674e+11, 6.5526674e+11, 6.5526674e+11],
#                   [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
#                    0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
#                   [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
#                    0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)}}
#print('--------------grads')
#print(grads)
```
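To check which layers are affected in a run like the reproducer above, one could scan the state returned by `model.init` for variance averages containing exact zeros. This is a minimal sketch that assumes the same state layout as before (a mapping keyed by module name, an `"average"` entry per EMA, variance EMAs named `var_ema`). The helper `zero_variance_modules` is hypothetical, and the toy module names are illustrative.

```python
from typing import Any, List, Mapping

import jax.numpy as jnp


def zero_variance_modules(state: Mapping[str, Mapping[str, Any]],
                          suffix: str = "var_ema") -> List[str]:
    """Hypothetical helper: names of variance-EMA modules whose stored
    "average" has at least one channel that is exactly zero."""
    return [
        name
        for name, entries in state.items()
        if name.endswith(suffix)
        and "average" in entries
        and bool(jnp.any(entries["average"] == 0))
    ]


# Toy state with illustrative module names.
toy_state = {
    "layer_a/~/batchnorm/~/mean_ema": {"average": jnp.zeros((1, 1, 1, 3))},
    "layer_a/~/batchnorm/~/var_ema": {"average": jnp.array([[[[0.5, 0.0, 2.0]]]])},
    "layer_b/~/batchnorm/~/var_ema": {"average": jnp.ones((1, 1, 1, 3))},
}
print(zero_variance_modules(toy_state))  # ['layer_a/~/batchnorm/~/var_ema']
```

In the reproducer, `print(zero_variance_modules(state))` right after `model.init(...)` would list the suspect layers. Very small but nonzero variances are also problematic, so comparing against a small threshold instead of `0` may be more useful in practice.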

stefanozampini, Apr 11 '22 08:04

I am experiencing this problem too. My use case is reinforcement learning, where the network has to be used to generate training data, and the policy raises an error because the values blow up.

uduse, Jul 13 '22 21:07