
Draw Phase Diagram for CNTK

Open · hhorace opened this issue 1 year ago · 1 comment

I'm curious about the initialization for the CNTK, so I replaced the kernel_fn in the c_map(W_var, b_var) function in the Colab with:

# Original single-layer network from the notebook: an affine transformation
# composed with an Erf nonlinearity.
# kernel_fn = stax.serial(stax.Dense(1024, W_std, b_std), stax.Erf())[2]
# Replaced with a Conv + Relu + Flatten + Dense network:
kernel_fn = stax.serial(
      stax.Conv(out_chan=1024, filter_shape=(3, 3), strides=None, padding='SAME', W_std=W_std, b_std=b_std),
      stax.Relu(),
      stax.Flatten(),
      stax.Dense(10, W_std=W_std, b_std=b_std, parameterization='ntk')
)[2]

However, an error is raised from deep in the stack when I try to plot; the error message is as follows:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_54123/684423119.py in <module>
----> 1 plt.contourf(W_var, b_var, c_star(W_var, b_var))
      2 plt.colorbar()
      3 plt.title('$C^*$ as a function of weight and bias variance', fontsize=14)
      4 
      5 format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')

    [... skipping hidden 18 frame]

/tmp/ipykernel_54123/2333898834.py in <lambda>(W_var, b_var)
     51   return c_map_fn
     52 
---> 53 c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
     54 chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
     55 chi_1 = partial(chi, 1.)

/tmp/ipykernel_54123/2333898834.py in c_map(W_var, b_var)
     42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
     43 
---> 44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)
     45 
     46   def c_map_fn(c):

/tmp/ipykernel_54123/3420146269.py in fixed_point(f, initial_value, threshold)
     38     return x - g(x) / dg(x), x
     39 
---> 40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]

    [... skipping hidden 12 frame]

/tmp/ipykernel_54123/3420146269.py in body_fn(x)
     36   def body_fn(x):
     37     x, _ = x
---> 38     return x - g(x) / dg(x), x
     39 
     40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]

/tmp/ipykernel_54123/3420146269.py in <lambda>(x)
     27 def fixed_point(f, initial_value, threshold):
     28   """Find fixed-points of a function f:R->R using Newton's method."""
---> 29   g = lambda x: f(x) - x
     30   dg = grad(g)
     31 

/tmp/ipykernel_54123/2333898834.py in q_map_fn(q)
     40   def q_map_fn(q):
     41     print(q)
---> 42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
     43 
     44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
    208                                                           len(args)])
    209 
--> 210       fn_out = fn(*canonicalized_args, **kwargs)
    211 
    212       @nt_tree_fn()

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
   4293     """
   4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295       return kernel_fn_kernel(x1_or_kernel,
   4296                               pattern=pattern,
   4297                               diagonal_batch=diagonal_batch,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
   4212 
   4213   def kernel_fn_kernel(kernel, **kwargs):
-> 4214     out_kernel = kernel_fn(kernel, **kwargs)
   4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
   4216 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
    191               pass
    192 
--> 193       return kernel_fn(k, **kwargs)
    194 
    195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
    325     # inside kernel functions here and parallel below.
    326     for f in kernel_fns:
--> 327       k = f(k, **kwargs)
    328     return k
    329 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
    208                                                           len(args)])
    209 
--> 210       fn_out = fn(*canonicalized_args, **kwargs)
    211 
    212       @nt_tree_fn()

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
   4293     """
   4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295       return kernel_fn_kernel(x1_or_kernel,
   4296                               pattern=pattern,
   4297                               diagonal_batch=diagonal_batch,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
   4212 
   4213   def kernel_fn_kernel(kernel, **kwargs):
-> 4214     out_kernel = kernel_fn(kernel, **kwargs)
   4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
   4216 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_with_masking(k, **user_reqs)
    277         mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
    278 
--> 279         k = kernel_fn(k, **user_reqs)  # type: Kernel
    280 
    281         if remask_kernel:

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
    191               pass
    192 
--> 193       return kernel_fn(k, **kwargs)
    194 
    195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
   1506       return out
   1507 
-> 1508     cov1 = conv(cov1, 1 if k.diagonal_batch else 2)
   1509     cov2 = conv(cov2, 1 if k.diagonal_batch else 2)
   1510 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv(lhs, batch_ndim)
   1502 
   1503     def conv(lhs, batch_ndim):
-> 1504       out = conv_unscaled(lhs, batch_ndim)
   1505       out = affine(out, W_std**2, b_std**2, batch_ndim)
   1506       return out

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv_unscaled(lhs, batch_ndim)
   1477 
   1478     def conv_unscaled(lhs, batch_ndim):
-> 1479       lhs = conv_kernel(lhs,
   1480                         filter_shape_kernel,
   1481                         strides_kernel,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_shared(lhs, filter_shape, strides, padding, batch_ndim)
   4759     return n_channels
   4760 
-> 4761   out = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding,
   4762                                        lax_conv, get_n_channels)
   4763   return out

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, lax_conv, get_n_channels)
   4912     spatial_i = (i - batch_ndim) // 2
   4913 
-> 4914     lhs = np.moveaxis(lhs, (i - 1, i), (-2, -1))
   4915     preshape = lhs.shape[:-2]
   4916     n_channels = get_n_channels(utils.size_at(preshape))

~/jax/jax/_src/numpy/lax_numpy.py in moveaxis(a, source, destination)
   1535     destination_axes = tuple(cast(Sequence[int], destination))
   1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
   1538                            for i in destination_axes)
   1539   if len(source_axes) != len(destination_axes):

~/jax/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
   1535     destination_axes = tuple(cast(Sequence[int], destination))
   1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
   1538                            for i in destination_axes)
   1539   if len(source_axes) != len(destination_axes):

~/jax/jax/_src/util.py in canonicalize_axis(axis, num_dims)
    275   axis = operator.index(axis)
    276   if not -num_dims <= axis < num_dims:
--> 277     raise ValueError(
    278         "axis {} is out of bounds for array of dimension {}".format(
    279             axis, num_dims))

ValueError: axis -2 is out of bounds for array of dimension 1

Am I misunderstanding the phase diagram? (Or is the CNTK phase diagram fundamentally un-drawable?) Also, I've found that the phase diagram doesn't change at all when I simply make an FC network deeper, e.g.

def DenseGroup(n, neurons, W_std, b_std):
    blocks = []
    for _ in range(n):
        blocks += [stax.Dense(neurons, W_std, b_std), stax.Erf()]
    return stax.serial(*blocks)

for layer in range(1,11):
    def c_map(W_var, b_var):
        ...
        kernel_fn = stax.serial(DenseGroup(layer, 1024, W_std, b_std))[2]
        ...
    c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
    chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
    chi_1 = partial(chi, 1.)
    
    c_star = jit(vectorize_over_sw_sb(c_star))
    chi_1 = jit(vectorize_over_sw_sb(chi_1))

    plt.contourf(W_var, b_var, c_star(W_var, b_var))

Does this mean that the depth of the NN doesn't affect the initialization?

hhorace · Sep 02 '22 18:09

@SiuMath and @sschoenholz may answer better, but I can give some brief comments:

  1. Re changing the depth, your observation is correct. The $C^*$ diagram shows the fixed-point correlation, i.e. the limiting correlation value $c^*$ in the infinite-depth limit, so it shouldn't matter whether you repeat 1 or 3 identical layers infinitely many times. The $\chi$ plot will change, but note that by definition and the chain rule, $\chi$ for $n$ identical layers equals $\chi$ for one layer raised to the power $n$ (see the short derivation after this list), so the phase boundary where $\chi = 1$ remains the same.

  2. I imagine the code could be generalized to CNNs, but it would need to support vector-valued variances $q$ and covariances $c$ (one entry per spatial location), so it may need some work; a rough sketch of one possible starting point follows this list. Note that per https://arxiv.org/abs/1806.05393, for the standard/ntk parameterization with CIRCULAR padding it should yield the same phase diagram as the fully-connected network. In Figure 11 of https://arxiv.org/abs/1810.05148 we ran some experiments with SAME padding and obtained reasonable agreement too.
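To spell out the chain-rule argument in point 1 (notation mine, not from the notebook): write the single-layer correlation map as $\mathcal{C}$ and its $n$-fold composition as $\mathcal{C}^{(n)}$. At a fixed point $c^*$ of $\mathcal{C}$ (and $c = 1$ is such a fixed point once the variance has converged to $q^*$),

$$\chi_n = \frac{d\mathcal{C}^{(n)}}{dc}\bigg|_{c = c^*} = \prod_{i=0}^{n-1} \mathcal{C}'\big(\mathcal{C}^{(i)}(c^*)\big) = \big(\mathcal{C}'(c^*)\big)^n = \chi_1^n,$$

so $\chi_1^n = 1$ exactly when $\chi_1 = 1$, and the phase boundary $\chi = 1$ is depth-independent.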
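For point 2, here is a rough sketch (my own construction, not what the notebook does) of a scalar $q$-map for a single Conv + Erf layer with CIRCULAR padding. With circular padding every pixel is statistically identical, so instead of hand-building a scalar Kernel (which is what triggers the "axis -2 is out of bounds" error above, since the Conv layer expects spatial axes in the kernel), one can evaluate the kernel on a constant image whose per-pixel second moment is $q$. The helper name and image size are arbitrary choices for illustration:

import jax.numpy as np
from neural_tangents import stax

def cnn_q_map(W_std, b_std, size=8):
  # One Conv + Erf layer; Flatten so kernel_fn returns a plain (1, 1) array.
  kernel_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(3, 3), padding='CIRCULAR',
                W_std=W_std, b_std=b_std),
      stax.Erf(),
      stax.Flatten())[2]

  def q_map_fn(q):
    # Constant image whose per-pixel second moment equals q.
    x = np.sqrt(q) * np.ones((1, size, size, 1))
    return kernel_fn(x, None, 'nngp')[0, 0]

  return q_map_fn

For this constant-input, CIRCULAR-padding case the map should reduce to the fully-connected one, consistent with the 1806.05393 result; SAME padding (or non-constant inputs) would need the vector-valued per-pixel variances mentioned above.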

One other comment: this notebook relies on being able to determine a fixed-point variance $q^*$, which does not always exist for the ReLU nonlinearity you used in your example (for weight variance above 2 no stable non-zero variance exists, and the variance explodes in the infinite-depth limit), so ReLU won't work in the notebook at the moment, even for FCNs; the short iteration sketch below illustrates this. You can, however, find the ReLU phase diagram in Figure 4 (b) of https://arxiv.org/abs/1711.00165.
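To see the ReLU variance explosion concretely, here is a minimal sketch (not from the notebook) of the ReLU $q$-map $q_{l+1} = \tfrac{\sigma_w^2}{2} q_l + \sigma_b^2$, which follows from $\mathbb{E}[\mathrm{relu}(z)^2] = q/2$ for $z \sim \mathcal{N}(0, q)$:

def relu_q_map(q, W_var, b_var):
    # E[relu(z)^2] = q / 2 for z ~ N(0, q); the affine layer scales by W_var and adds b_var.
    return W_var * q / 2 + b_var

def iterate_q(q0, W_var, b_var, depth=50):
    q = q0
    for _ in range(depth):
        q = relu_q_map(q, W_var, b_var)
    return q

print(iterate_q(1.0, W_var=1.5, b_var=0.1))  # converges to b_var / (1 - W_var / 2) = 0.4
print(iterate_q(1.0, W_var=2.5, b_var=0.1))  # blows up: no stable non-zero fixed point

For $\sigma_w^2 < 2$ the stable fixed point is $q^* = \sigma_b^2 / (1 - \sigma_w^2 / 2)$; for $\sigma_w^2 \geq 2$ (with $\sigma_b^2 > 0$) the variance grows without bound with depth, which is why the notebook's $q^*$ search fails for ReLU.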

Finally, note that these diagrams study the forward propagation of the signal at initialization, so they only involve the NNGP / CNN-GP kernel (and not the CNTK); the parameterization argument should therefore have no impact on them (a quick numerical check is sketched below).
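As a quick numerical sanity check of that last point (a sketch with arbitrary inputs and hyperparameters, not from the notebook), the NNGP kernel returned by stax should coincide for both parameterizations:

import jax.numpy as np
from neural_tangents import stax

x = np.array([[1., 2., 3., 4.], [0.5, -1., 2., 0.]])

# Same architecture, only the parameterization differs.
kernel_ntk = stax.serial(
    stax.Dense(1, W_std=1.5, b_std=0.3, parameterization='ntk'), stax.Erf())[2]
kernel_std = stax.serial(
    stax.Dense(1, W_std=1.5, b_std=0.3, parameterization='standard'), stax.Erf())[2]

# Both calls should print the same 2x2 NNGP matrix.
print(kernel_ntk(x, None, 'nngp'))
print(kernel_std(x, None, 'nngp'))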

Hope this helps!

romanngg · Sep 04 '22 20:09