
Wrong gradients in a Haiku network

SNMS95 opened this issue 8 months ago • 0 comments

Hey guys, I was trying to use Haiku to create a convolutional neural network architecture (to reproduce a paper whose implementation was in TF 2.0). The CNN works correctly; however, when I use jax.test_util.check_grads, there seems to be an error. The code is as follows:

import jax
import jax.numpy as np
import numpy as onp
import haiku as hk
from jax.test_util import check_grads
from functools import partial


class CNNParameterization(hk.Module):

    def __init__(self):
        super().__init__()
        self.layers = self._build_layers()
        
    def _build_layers(self):
        activation = jax.nn.leaky_relu
        Nx = 64
        Ny = 64
        total_resize = onp.prod((1, 2, 2, 2, 1))
        h = Nx // total_resize
        w = Ny // total_resize

        layers = []
        self.latent_params = hk.get_parameter(
                        "beta", shape=(128, ),
                        init=hk.initializers.RandomNormal())
        dense_output_size = 32*w*h
        dense_init = hk.initializers.Orthogonal(
                        scale=1.0*onp.sqrt(onp.max(
                                (dense_output_size/128, 1)
                            )))
        dense_layer = hk.Linear(output_size=dense_output_size,
                                name='dense_layer',
                                w_init=dense_init)
        layers.append(dense_layer)
        # Reshape preserves batch dimension
        layers.append(hk.Reshape((h, w, 32),
                                 name="reshape"))
        counter = 0
        for resize, conv_filters in zip((1, 2, 2, 2, 1),
                                        (32, 16, 8, 4, 1)):
            layers.append(activation)
            layers.append(hk.Conv2D(output_channels=conv_filters,
                                    kernel_shape=(5, 5),
                                    padding='SAME',
                                    name='conv_layer',
                                    w_init=hk.initializers.VarianceScaling()))
            counter += 1
        return layers

    def __call__(self, model_input: jax.Array = None):
        """Forward pass.

        The model input is unused.
        """
        del model_input

        x = self.latent_params
        for layer_no, layer in enumerate(self.layers):
            if layer_no == 1:  # Only for reshaping layer
                x = layer(x.reshape((1, ) + layer.output_shape))
            else:
                x = layer(x)
        x = x.ravel()
        return x

# Test the gradients
def mapping_fn(x):
    result = CNNParameterization()(x)
    return result

rng_key = jax.random.PRNGKey(42)  # any seed works here
model_input = np.ones((100, 3))
forward_pass_pure = hk.without_apply_rng(
    hk.transform_with_state(mapping_fn))
init_params, init_state = forward_pass_pure.init(x=model_input,
                                                 rng=rng_key)
forward_func = jax.jit(forward_pass_pure.apply)

def dummy_func(params, state, x):
    # apply returns (output, new_state); take the output and reduce to a scalar
    return forward_func(params, state, x)[0].mean()

check_grads(dummy_func, (init_params, init_state, model_input),
            order=2, eps=1e-4)

The error is:

AssertionError: 
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.04407668
Max relative difference: 0.01132658
 x: array(-3.847362, dtype=float32)
 y: array(-3.891438, dtype=float32)
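
For reference, my understanding is that check_grads builds a random tangent direction and compares the autodiff tangent (via jax.jvp) against a numerical finite-difference estimate along that direction. Below is a rough sketch of that kind of comparison for my function; the tangent construction, seed, and eps value are my own choices for illustration, not what check_grads uses internally:

from jax.flatten_util import ravel_pytree

# Flatten params, state and input so they can be perturbed along one direction.
primals = (init_params, init_state, model_input)
flat_p, unravel = ravel_pytree(primals)
flat_t = jax.random.normal(jax.random.PRNGKey(1), flat_p.shape)
tangents = unravel(flat_t)

# Tangent from automatic differentiation.
_, ad_tangent = jax.jvp(dummy_func, primals, tangents)

# Central finite-difference estimate along the same direction.
eps = 1e-4
fd_tangent = (dummy_func(*unravel(flat_p + eps * flat_t))
              - dummy_func(*unravel(flat_p - eps * flat_t))) / (2 * eps)
print(ad_tangent, fd_tangent)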

SNMS95 · Oct 17 '23 19:10