neural-tangents
Draw Phase Diagram for CNTK
I'm curious about the initialization for the CNTK, so I replaced the kernel_fn in the c_map(W_var, b_var) function in the Colab notebook with:
```python
# Create a single layer of a network as an affine transformation composed
# with an Erf nonlinearity.
# kernel_fn = stax.serial(stax.Dense(1024, W_std, b_std), stax.Erf())[2]
kernel_fn = stax.serial(
    stax.Conv(out_chan=1024, filter_shape=(3, 3), strides=None,
              padding='SAME', W_std=W_std, b_std=b_std),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10, W_std=W_std, b_std=b_std, parameterization='ntk')
)[2]
```
However, when I tried to plot, an error seems to be raised from a lower layer; the error message is as follows:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_54123/684423119.py in <module>
----> 1 plt.contourf(W_var, b_var, c_star(W_var, b_var))
2 plt.colorbar()
3 plt.title('$C^*$ as a function of weight and bias variance', fontsize=14)
4
5 format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
[... skipping hidden 18 frame]
/tmp/ipykernel_54123/2333898834.py in <lambda>(W_var, b_var)
51 return c_map_fn
52
---> 53 c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
54 chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
55 chi_1 = partial(chi, 1.)
/tmp/ipykernel_54123/2333898834.py in c_map(W_var, b_var)
42 return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
43
---> 44 qstar = fixed_point(q_map_fn, 1.0, 1e-7)
45
46 def c_map_fn(c):
/tmp/ipykernel_54123/3420146269.py in fixed_point(f, initial_value, threshold)
38 return x - g(x) / dg(x), x
39
---> 40 return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]
[... skipping hidden 12 frame]
/tmp/ipykernel_54123/3420146269.py in body_fn(x)
36 def body_fn(x):
37 x, _ = x
---> 38 return x - g(x) / dg(x), x
39
40 return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]
/tmp/ipykernel_54123/3420146269.py in <lambda>(x)
27 def fixed_point(f, initial_value, threshold):
28 """Find fixed-points of a function f:R->R using Newton's method."""
---> 29 g = lambda x: f(x) - x
30 dg = grad(g)
31
/tmp/ipykernel_54123/2333898834.py in q_map_fn(q)
40 def q_map_fn(q):
41 print(q)
---> 42 return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
43
44 qstar = fixed_point(q_map_fn, 1.0, 1e-7)
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
208 len(args)])
209
--> 210 fn_out = fn(*canonicalized_args, **kwargs)
211
212 @nt_tree_fn()
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
4293 """
4294 if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295 return kernel_fn_kernel(x1_or_kernel,
4296 pattern=pattern,
4297 diagonal_batch=diagonal_batch,
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
4212
4213 def kernel_fn_kernel(kernel, **kwargs):
-> 4214 out_kernel = kernel_fn(kernel, **kwargs)
4215 return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
4216
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
191 pass
192
--> 193 return kernel_fn(k, **kwargs)
194
195 setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
325 # inside kernel functions here and parallel below.
326 for f in kernel_fns:
--> 327 k = f(k, **kwargs)
328 return k
329
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
208 len(args)])
209
--> 210 fn_out = fn(*canonicalized_args, **kwargs)
211
212 @nt_tree_fn()
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
4293 """
4294 if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295 return kernel_fn_kernel(x1_or_kernel,
4296 pattern=pattern,
4297 diagonal_batch=diagonal_batch,
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
4212
4213 def kernel_fn_kernel(kernel, **kwargs):
-> 4214 out_kernel = kernel_fn(kernel, **kwargs)
4215 return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
4216
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_with_masking(k, **user_reqs)
277 mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
278
--> 279 k = kernel_fn(k, **user_reqs) # type: Kernel
280
281 if remask_kernel:
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
191 pass
192
--> 193 return kernel_fn(k, **kwargs)
194
195 setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
1506 return out
1507
-> 1508 cov1 = conv(cov1, 1 if k.diagonal_batch else 2)
1509 cov2 = conv(cov2, 1 if k.diagonal_batch else 2)
1510
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv(lhs, batch_ndim)
1502
1503 def conv(lhs, batch_ndim):
-> 1504 out = conv_unscaled(lhs, batch_ndim)
1505 out = affine(out, W_std**2, b_std**2, batch_ndim)
1506 return out
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv_unscaled(lhs, batch_ndim)
1477
1478 def conv_unscaled(lhs, batch_ndim):
-> 1479 lhs = conv_kernel(lhs,
1480 filter_shape_kernel,
1481 strides_kernel,
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_shared(lhs, filter_shape, strides, padding, batch_ndim)
4759 return n_channels
4760
-> 4761 out = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding,
4762 lax_conv, get_n_channels)
4763 return out
~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, lax_conv, get_n_channels)
4912 spatial_i = (i - batch_ndim) // 2
4913
-> 4914 lhs = np.moveaxis(lhs, (i - 1, i), (-2, -1))
4915 preshape = lhs.shape[:-2]
4916 n_channels = get_n_channels(utils.size_at(preshape))
~/jax/jax/_src/numpy/lax_numpy.py in moveaxis(a, source, destination)
1535 destination_axes = tuple(cast(Sequence[int], destination))
1536 source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537 destination_axes = tuple(_canonicalize_axis(i, ndim(a))
1538 for i in destination_axes)
1539 if len(source_axes) != len(destination_axes):
~/jax/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
1535 destination_axes = tuple(cast(Sequence[int], destination))
1536 source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537 destination_axes = tuple(_canonicalize_axis(i, ndim(a))
1538 for i in destination_axes)
1539 if len(source_axes) != len(destination_axes):
~/jax/jax/_src/util.py in canonicalize_axis(axis, num_dims)
275 axis = operator.index(axis)
276 if not -num_dims <= axis < num_dims:
--> 277 raise ValueError(
278 "axis {} is out of bounds for array of dimension {}".format(
279 axis, num_dims))
ValueError: axis -2 is out of bounds for array of dimension 1
```
Am I misunderstanding the phase diagram? (Is the CNTK phase diagram fundamentally impossible to draw?) Also, I've found that there's no difference at all in the phase diagram when I simply make an FC network deeper, e.g.
```python
def DenseGroup(n, neurons, W_std, b_std):
  blocks = []
  for _ in range(n):
    blocks += [stax.Dense(neurons, W_std, b_std), stax.Erf()]
  return stax.serial(*blocks)

for layer in range(1, 11):
  def c_map(W_var, b_var):
    ...
    kernel_fn = stax.serial(DenseGroup(layer, 1024, W_std, b_std))[2]
    ...
  c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
  chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
  chi_1 = partial(chi, 1.)
  c_star = jit(vectorize_over_sw_sb(c_star))
  chi_1 = jit(vectorize_over_sw_sb(chi_1))
  plt.contourf(W_var, b_var, c_star(W_var, b_var))
```
Does this mean that the depth of an NN doesn't affect the initialization?
@SiuMath and @sschoenholz may answer better, but I can give some brief comments:
- Re changing the depth, your observation is correct. The $C^*$ diagram shows the fixed-point correlation, i.e. the limiting correlation value $c^*$ in the infinite-depth limit, so it shouldn't matter whether you repeat 1 or 3 identical layers infinitely many times. The $\chi$ plot will change, but note that, by definition and the chain rule, $\chi$ for $n$ identical layers equals $\chi$ for one layer raised to the power $n$, so the phase boundary where $\chi = 1$ will remain the same (see the short numeric check after this list).
- I imagine the code could be generalized to CNNs, but it would need to support vector-valued variances $q$ and covariances $c$ (one per spatial location), so it may need some work. Note that per https://arxiv.org/abs/1806.05393, for the standard/ntk parameterization and `CIRCULAR` padding, it should yield the same phase diagram as the fully-connected network. In Figure 11 of https://arxiv.org/abs/1810.05148 we ran some experiments with `SAME` padding and obtained reasonable agreement too.
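Here is a minimal numeric check of the chain-rule claim above. The map `f` is a hypothetical stand-in for a one-layer c-map (an affine map with a fixed point at 0.5), not a real kernel recursion:

```python
from jax import grad
import jax.numpy as np

f = lambda c: 0.5 * c + 0.25  # stand-in one-layer c-map; fixed point c* = 0.5
f3 = lambda c: f(f(f(c)))     # three identical "layers" composed

c_star = 0.5
chi_1 = grad(f)(c_star)   # slope of one layer at the fixed point: 0.5
chi_3 = grad(f3)(c_star)  # 0.125 == 0.5 ** 3, by the chain rule at a fixed point

assert np.isclose(chi_3, chi_1 ** 3)
```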
One other comment: this notebook relies on being able to determine a fixed-point variance $q^*$, which does not always exist for the ReLU that you used in your example (for weight variance above 2, no stable non-zero variance exists, and the variance explodes in the infinite-depth limit), so the ReLU nonlinearity won't work in the notebook at the moment, even for FCNs. But you can find the ReLU phase diagram in Figure 4 (b) of https://arxiv.org/abs/1711.00165.
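To see why, here is a small sketch of the ReLU variance map, using the known closed form $\mathbb{E}[\mathrm{relu}(z)^2] = q/2$ for $z \sim \mathcal{N}(0, q)$ (the helper name is my own):

```python
def relu_q_map(q, W_var, b_var):
  # One-layer NNGP variance recursion for ReLU: an affine map with slope W_var / 2.
  return W_var * q / 2 + b_var

# The fixed point q* = b_var / (1 - W_var / 2) exists and is stable only for
# W_var < 2; for W_var > 2 the variance grows without bound with depth.
```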
Finally, note that these diagrams study the forward propagation of the signal, so they only work with the CNN-GP kernel (and not the CNTK); hence the `parameterization` argument should have no impact on them.
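Concretely (a sketch with placeholder inputs `x1`, `x2`), the notebook only ever queries the NNGP entry of the kernel:

```python
nngp = kernel_fn(x1, x2, get='nngp')  # forward-signal (GP) kernel, used by the diagrams
ntk = kernel_fn(x1, x2, get='ntk')    # NTK entry, which the diagrams never use
```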
Hope this helps!