dm-haiku
Wrong gradients in a Haiku network
Hey guys, I was trying to use Haiku to build a convolutional neural network architecture (to reproduce a paper whose reference implementation was in TF 2.0). The forward pass of the CNN works correctly; however, when I check the gradients with jax.test_util.check_grads, the check fails. The code is as follows:
import haiku as hk
import jax
import jax.numpy as np
import numpy as onp
from jax.test_util import check_grads
class CNNParameterization(hk.Module):
    def __init__(self):
        super().__init__()
        self.layers = self._build_layers()

    def _build_layers(self):
        activation = jax.nn.leaky_relu
        Nx = 64
        Ny = 64
        # Overall resize factor across the conv stack.
        total_resize = onp.prod((1, 2, 2, 2, 1))
        h = Nx // total_resize
        w = Ny // total_resize
        layers = []
        # Trainable latent vector that serves as the network input.
        self.latent_params = hk.get_parameter(
            "beta", shape=(128,),
            init=hk.initializers.RandomNormal())
        dense_output_size = 32 * w * h
        dense_init = hk.initializers.Orthogonal(
            scale=1.0 * onp.sqrt(onp.max((dense_output_size / 128, 1))))
        dense_layer = hk.Linear(output_size=dense_output_size,
                                name='dense_layer',
                                w_init=dense_init)
        layers.append(dense_layer)
        # Reshape preserves the batch dimension.
        layers.append(hk.Reshape((h, w, 32), name="reshape"))
        for resize, conv_filters in zip((1, 2, 2, 2, 1),
                                        (32, 16, 8, 4, 1)):
            layers.append(activation)
            layers.append(hk.Conv2D(output_channels=conv_filters,
                                    kernel_shape=(5, 5),
                                    padding='SAME',
                                    name='conv_layer',
                                    w_init=hk.initializers.VarianceScaling()))
        return layers
    def __call__(self, model_input: jax.Array = None):
        """Forward pass.

        The model input is unused.
        """
        del model_input
        x = self.latent_params
        for layer_no, layer in enumerate(self.layers):
            if layer_no == 1:  # Only for the reshaping layer.
                x = layer(x.reshape((1,) + layer.output_shape))
            else:
                x = layer(x)
        x = x.ravel()
        return x
# Test the gradients.
def mapping_fn(x):
    return CNNParameterization()(x)

rng_key = jax.random.PRNGKey(42)
model_input = np.ones((100, 3))
forward_pass_pure = hk.without_apply_rng(
    hk.transform_with_state(mapping_fn))
init_params, init_state = forward_pass_pure.init(x=model_input,
                                                 rng=rng_key)
forward_func = jax.jit(forward_pass_pure.apply)

def dummy_func(params, state, x):
    # apply returns (output, state); reduce the output to a scalar.
    return forward_func(params, state, x)[0].mean()

check_grads(dummy_func, (init_params, init_state, model_input),
            order=2, eps=1e-4)
The error is:
AssertionError:
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.04407668
Max relative difference: 0.01132658
x: array(-3.847362, dtype=float32)
y: array(-3.891438, dtype=float32)
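One thing worth checking is whether this is a genuine gradient bug or just float32 finite-difference noise: with eps=1e-4 in float32, the numerical JVP that check_grads computes is itself only accurate to roughly the observed ~1% relative error, and leaky_relu is not differentiable at zero, which can also throw off the comparison near kinks. A minimal sketch of a follow-up experiment (my addition, assuming a fresh process and the same snippet as above) is to enable 64-bit precision before any arrays are created and rerun the check with a smaller step:

from jax import config
config.update("jax_enable_x64", True)  # must run before any JAX arrays exist

# ... rebuild rng_key, model_input, init_params, init_state and
#     dummy_func exactly as in the snippet above ...

check_grads(dummy_func, (init_params, init_state, model_input),
            order=2, eps=1e-6)  # a smaller step is usable in float64

If the check passes in float64, the float32 discrepancy is likely numerical rather than a wrong gradient in the network.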