
NotImplementedError: When I use stax.DotGeneral

kkeevin123456 opened this issue on Jul 26, 2021 · 4 comments

Hi, when I try to implement a two-layer coupling layer like the one in the image below, I get this error. Do you have any insight into how to fix it?

[image: two-layer coupling layer architecture]

The error looks like:

[image: error traceback]

Some directions I have tried:

  • Since a ResNet works, the problem must not be FanInSum.
  • I also tried setting is_gaussian to True.
  • I think the problem may occur when I try to kernelize my architecture, so I tried optimizers.sgd to train my network instead. It works, but I still need to kernelize it.

Here is some code that reproduces the error:

    from jax import random
    from neural_tangents import stax
    import jax.numpy as np
    import neural_tangents as nt

    def DenseBlock(neurons):
        return stax.serial(
            stax.Dense(neurons), stax.Relu()
        )
    
    def ReluNetwork(latent_dim, hidden_dim, num_layers):
        """Create the network which is embedd in flow_base model
        
        Args:
            latent_dim: input and output dim
            hidden_dim: the width dim of the ReluNetwork
            num_layers: depth of the ReluNetwork
        
        Returns:
            stax.serial(ReluNetwork)
        """
        blocks = [DenseBlock(hidden_dim)]
        for _ in range(num_layers):
            blocks += [DenseBlock(hidden_dim)]
        blocks += [stax.Dense(latent_dim)]
        
        return stax.serial(*blocks)
    
    def lower_path(input_dim):
        helf_dim = input_dim//2
        # pre_half's rhs
        rhs1 = np.identity(helf_dim)
        rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
        rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))
        
        # post_half's rhs
        rhs2 = np.identity(helf_dim)
        rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
        rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))
        
        
        rhs4 = np.identity(helf_dim)
        rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
        rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))
        
        blocks = [
            stax.DotGeneral(
                    rhs = rhs1,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]
        blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
        blocks += [
            stax.DotGeneral(
                    rhs = rhs4,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )]
        
        
        pre_half = stax.serial(
            *blocks
        )
        post_half = stax.serial(
            stax.DotGeneral(
                    rhs = rhs2,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )
        return stax.serial(stax.FanOut(2),
                           stax.parallel(pre_half, post_half),
                           stax.FanInSum()
                          )
    
    def AdditiveCouplingLayer(input_dim, order):
        """the additive couplinglayer in the paper
        
        Args:
            nonlinearity: the ReluNetwork
        
        Returns:
            stax.serial(AdditiveCouplingLayer)
        """
        helf_dim = input_dim//2
        
        rhs_matrix = np.identity(helf_dim)
        rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
        rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))
    
        upper_path = stax.serial(
            stax.DotGeneral(
                    rhs = rhs_matrix,
                    dimension_numbers = (((2,), (1,)), ((), ())),
                    channel_axis = 1
                ), 
            stax.DotGeneral(
                    rhs = np.array([1]),
                    dimension_numbers = (((3,), (0,)), ((), ())),
                    channel_axis = 1
                )
        )
        
        return stax.serial(stax.FanOut(2),
                           stax.parallel(upper_path, lower_path(input_dim)),
                           stax.FanInSum()
                          )
    def LogisticPriorLoss(fx, y):
        return np.mean((0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi)))

    # test
    x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
    x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))
    input_dim = x.shape[2]  # x.shape is (B, 1, 6); B is batch size
    helf_dim = input_dim//2
    
    init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)
    
    key = random.PRNGKey(1)
    _, params = init_fn(key, input_shape=x.shape)
    
    # z_train.dim = x_train.dim
    z_train = random.normal(key, x.shape)
    x_test = np.array([[1, 2, 3, 4, 5, 6]])
    x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))
    
    ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
    ntk_test_train = kernel_fn(x_test, x, 'ntk')
    predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)

Many thanks for your kind reply.

kkeevin123456 commented on Jul 26, 2021

Great question!

A few points.

  1. You were on the right track with setting is_gaussian=True. Notice that post_half doesn't have any dense layers, so if its inputs aren't Gaussian, its outputs won't be Gaussian either. Unfortunately, NT doesn't support explicitly setting is_gaussian=True in the kernel_fn (since the inputs to the network are assumed to be constants rather than Gaussian random variables). One way to solve this problem is to add a single dense layer at the top of your network.
  2. There were some inconsistencies in setting channel_axis. In particular, you also had to set channel_axis=1 in the Dense layers in DenseBlock(..) and ReluNetwork(..). Once that was done, there was a shape error because the number of channels differed between the two branches. To solve this I ended up setting the initial dense layer to project down to the latent dimension, but I'm not sure whether this was what you were going for.

In any case, here is a version of the code that should work. Let me know if you run into any trouble!

%pdb on

from jax import random
from neural_tangents import stax
import jax.numpy as np
import neural_tangents as nt

def DenseBlock(neurons):
    return stax.serial(
        stax.Dense(neurons, channel_axis=1), stax.Relu()
    )

def ReluNetwork(latent_dim, hidden_dim, num_layers):
    """Create the network which is embedd in flow_base model
    
    Args:
        latent_dim: input and output dim
        hidden_dim: the width dim of the ReluNetwork
        num_layers: depth of the ReluNetwork
    
    Returns:
        stax.serial(ReluNetwork)
    """
    blocks = [DenseBlock(hidden_dim)]
    for _ in range(num_layers):
        blocks += [DenseBlock(hidden_dim)]
    blocks += [stax.Dense(latent_dim, channel_axis=1)]
    
    return stax.serial(*blocks)

def lower_path(input_dim):
    helf_dim = input_dim//2
    # pre_half's rhs
    rhs1 = np.identity(helf_dim)
    rhs1 = np.pad(rhs1, ((0, 0), (0, helf_dim)))
    rhs1 = np.reshape(rhs1, (*rhs1.shape, 1))
    
    # post_half's rhs
    rhs2 = np.identity(helf_dim)
    rhs2 = np.pad(rhs2, ((helf_dim, 0), (helf_dim, 0)))
    rhs2 = np.reshape(rhs2, (*rhs2.shape, 1))
    
    
    rhs4 = np.identity(helf_dim)
    rhs4 = np.pad(rhs4, ((helf_dim, 0), (0, 0)))
    rhs4 = np.reshape(rhs4, (*rhs4.shape, 1))
    
    blocks = [
        stax.DotGeneral(
                rhs = rhs1,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )]
    blocks += [ReluNetwork(latent_dim=helf_dim, hidden_dim=helf_dim//2, num_layers=4)]
    blocks += [
        stax.DotGeneral(
                rhs = rhs4,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )]
    
    
    pre_half = stax.serial(
        *blocks
    )

    post_half = stax.serial(
        stax.DotGeneral(
                rhs = rhs2,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )
    )
    return stax.serial(stax.Dense(helf_dim, channel_axis=1),
                       stax.FanOut(2),
                       stax.parallel(pre_half, post_half),
                       stax.FanInSum()
                      )

def AdditiveCouplingLayer(input_dim, order):
    """the additive couplinglayer in the paper
    
    Args:
        nonlinearity: the ReluNetwork
    
    Returns:
        stax.serial(AdditiveCouplingLayer)
    """
    helf_dim = input_dim//2
    
    rhs_matrix = np.identity(helf_dim)
    rhs_matrix = np.pad(rhs_matrix, ((0, helf_dim), (0, helf_dim)))
    rhs_matrix = np.reshape(rhs_matrix, (*rhs_matrix.shape, 1))

    upper_path = stax.serial(
        stax.DotGeneral(
                rhs = rhs_matrix,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            )
    )
    
    return stax.serial(stax.FanOut(2),
                       stax.parallel(upper_path, lower_path(input_dim)),
                       stax.FanInSum()
                      )
def LogisticPriorLoss(fx, y):
    return np.mean((0.5*np.sum(np.power(fx, 2), axis=1) + fx.shape[1]*0.5*np.log(2*np.pi)))

# test
x = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
x = np.reshape(x, (x.shape[0], 1, *x.shape[1:]))
input_dim = x.shape[2]  # x.shape is (B, 1, 6); B is batch size
helf_dim = input_dim//2

init_fn, apply_fn, kernel_fn = lower_path(input_dim=input_dim)

key = random.PRNGKey(1)
_, params = init_fn(key, input_shape=x.shape)

# z_train.dim = x_train.dim
z_train = random.normal(key, x.shape)
x_test = np.array([[1, 2, 3, 4, 5, 6]])
x_test = np.reshape(x_test, (x_test.shape[0], 1, *x_test.shape[1:]))

ntk_train_train = kernel_fn(x, x, 'ntk', channel_axis=1, is_gaussian=True)
ntk_test_train = kernel_fn(x_test, x, 'ntk')
predictor = nt.predict.gradient_descent(LogisticPriorLoss, ntk_train_train, z_train)

sschoenholz commented on Jul 28, 2021

Hello @sschoenholz @romanngg, thanks for your kind reply!!

But I still have some questions. Is adding a Dense layer the only solution to this problem? If I add a Dense layer, the output and output.shape will change.

Based on the optimizers.sgd method, the result looks like this:

[image: optimizers.sgd result, showing only the bottom half of the output changing]

The image above shows that my architecture only changes the bottom half. Your reply has helped me make a lot of progress.

kkeevin123456 commented on Jul 28, 2021

Follow up

I observed one weird thing:

  • upper_path in AdditiveCouplingLayer also doesn't have any Dense layer, so why does it work normally?

kkeevin123456 commented on Jul 29, 2021

Sorry for the long delay, a few more observations:

  • AdditiveCouplingLayer is not called in the above code sample (only lower_path), so I imagine it would have the same problem.
  • In lower_path, which is tested above, you already have a dense layer in pre_half, but no dense layer in post_half. So you could add a common dense layer as @sschoenholz suggested, or add a dense layer only somewhere in post_half instead, e.g.
    post_half = stax.serial(
        stax.DotGeneral(
                rhs = rhs2,
                dimension_numbers = (((2,), (1,)), ((), ())),
                channel_axis = 1
            ), 
        stax.DotGeneral(
                rhs = np.array([1]),
                dimension_numbers = (((3,), (0,)), ((), ())),
                channel_axis = 1
            ),
        stax.Dense(helf_dim, channel_axis=1),
    )
  • Note that in the case above, the Dense layer will not change the output shape if you set out_dim = helf_dim equal to the number of channels in your input; it will not affect the pixel structure, but as you've mentioned it will indeed change the outputs themselves. However, out_dim must be equal to helf_dim, which is the number of channels output by pre_half; otherwise you would be asking FanInSum to add arrays of different shapes, and I think it only worked above because the input channels have size 1 and were silently broadcast to helf_dim when combined (see the small broadcasting demo further below).

  • Similarly, it's not clear how to avoid adding Dense in the infinite-width limit, since FanInSum must add arrays of the same shape. In that limit, which is invoked when you call kernel_fn, pre_half outputs arrays that are infinite along the channel axis (since it has dense layers), while post_half, without dense layers, outputs finite arrays with as many channels as the input. Therefore arguably the sum of the two is not well-defined.

  • In any case, for the NTK, if your post_half branch doesn't contain any trainable parameters, it should not influence the NTK (i.e. if f(params, x) = pre_half(params, x) + post_half(x), then NTK(f)(x1, x2, params) = NTK(pre_half)(x1, x2, params)), so IIUC as a workaround you could just compute the NTK of pre_half separately. Alternatively, you could use nt.empirical_kernel_fn to compute the empirical NTK, which should work for any function/architecture (see the sketch right below).
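
To make the second workaround concrete, here is a minimal sketch (not tested against your exact setup) of computing the empirical NTK of the finite-width network with nt.empirical_kernel_fn, reusing the lower_path, input_dim, x, x_test and key from the reproduction code above. trace_axes is left at its default (the last axis); depending on how you want to treat your channel axis, you may want to change it.

    import neural_tangents as nt

    # Finite-width network and parameters (same lower_path, x, x_test and key
    # as in the reproduction code above).
    init_fn, apply_fn, _ = lower_path(input_dim=input_dim)
    _, params = init_fn(key, input_shape=x.shape)

    # Empirical (finite-width) NTK: it differentiates apply_fn directly, so it
    # works even for architectures whose infinite-width kernel NT cannot derive.
    emp_kernel_fn = nt.empirical_kernel_fn(apply_fn)

    ntk_train_train = emp_kernel_fn(x, None, 'ntk', params)   # train vs. train
    ntk_test_train = emp_kernel_fn(x_test, x, 'ntk', params)  # test  vs. train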

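And here is a small self-contained demo of the silent broadcast mentioned above; the concrete channel counts (helf_dim = 3 channels vs. a single input channel) are only illustrative assumptions:

    import jax.numpy as np

    # Hypothetical branch outputs with channel_axis = 1:
    # pre_half ends in dense layers, so it has helf_dim (= 3 here) channels,
    # while post_half keeps the single input channel.
    pre = np.ones((2, 3, 4))    # (batch, channels=3, features)
    post = np.ones((2, 1, 4))   # (batch, channels=1, features)

    # FanInSum adds the branch outputs element-wise, so the channel axis of
    # `post` is silently broadcast from 1 to 3 instead of raising an error.
    print((pre + post).shape)   # (2, 3, 4)
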
Lmk if this helps!

romanngg commented on Aug 17, 2021