neural-tangents
Excessive memory consumption for deep networks
The LLVM compiler pass uses excessive amounts of memory for deep networks constructed like this:

```python
stax.serial(*([my_layer] * depth))
```
In fact, the compilation may eventually OOM.
The reason is that the `stax.serial` combinator internally relies on a Python for loop (with a carry) to support sequences of mixed layers, so the loop is unrolled at trace time and the compiled program grows with `depth`.
It would be nice to have a specialization for the case where the same layer is repeated `n` times, which could use `jax.lax.scan()` to avoid loop unrolling and thereby save compilation time and memory.
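As a rough illustration of the unrolling argument (a plain-JAX sketch, not neural-tangents code; `toy_layer`, `unrolled`, `scanned` and `num_eqns` are hypothetical names), one can count the equations in the traced program. The Python loop emits one copy of the layer per iteration, while `lax.scan` traces the body once, so its top-level program size does not depend on `depth`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_layer(x):
  # Hypothetical stand-in for one network layer (not an NT layer).
  return jnp.tanh(x @ jnp.ones((4, 4)) * 0.5)


def unrolled(x, depth):
  # Python loop: traced `depth` times, so the program grows linearly.
  for _ in range(depth):
    x = toy_layer(x)
  return x


def scanned(x, depth):
  # lax.scan: the body is traced once; `length` sets the iteration count.
  x, _ = lax.scan(lambda carry, _: (toy_layer(carry), None), x,
                  xs=None, length=depth)
  return x


def num_eqns(f, depth):
  # Number of top-level equations in the traced program (jaxpr).
  x = jnp.ones((2, 4))
  return len(jax.make_jaxpr(lambda x: f(x, depth))(x).jaxpr.eqns)


for depth in (10, 50):
  print(depth, num_eqns(unrolled, depth), num_eqns(scanned, depth))
```

Both functions compute the same result; only the size of the program handed to the compiler differs.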
Suggestion:
```python
import jax.example_libraries.stax as ostax
from jax.lax import scan

from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
from neural_tangents._src.stax.requirements import get_req, requires, layer
from neural_tangents._src.utils.kernel import Kernel


@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
  """Combinator for repeating the same layer `n` times.

  Based on :obj:`jax.example_libraries.stax.serial`.

  Args:
    layer:
      a single layer, i.e. an `(init_fn, apply_fn, kernel_fn)` triple.
    n:
      the number of repetitions.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the composition of `n` copies of `layer`.
  """
  init_fn, apply_fn, kernel_fn = layer
  # `init_fn` and `apply_fn` are still composed with the (unrolled) serial.
  init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))

  @requires(**get_req(kernel_fn))
  def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
    # inside kernel functions here.
    # `scan` traces `kernel_fn` once and applies it `n` times, instead of
    # unrolling `n` copies into the compiled program.
    k, _ = scan(lambda carry, _: (kernel_fn(carry, **kwargs), None), k,
                xs=None, length=n)
    return k

  return init_fn, apply_fn, kernel_fn_scan
```
Use it like this:

```python
repeat(my_layer, depth)
```
Thanks for the suggestion! Please check out the newly added `stax.repeat`: https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.repeat.html#neural_tangents.stax.repeat
One caveat makes this less elegant than we'd like: `kernel_fn` sometimes makes non-jittable changes to the metadata of the `Kernel` object, and when this happens `lax.scan` fails (see especially the second warning in the docs). So, for now, `repeat` is unfortunately less flexible than `stax.serial`.
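One way this kind of failure shows up is a general `lax.scan` requirement: the carry must keep exactly the same pytree structure, including static metadata, on every iteration. The sketch below uses a hypothetical `ToyKernel` with a single static flag, loosely mimicking metadata stored on neural-tangents' `Kernel` (an assumption for illustration, not its actual definition). A body that preserves the flag scans fine, while one that flips it is rejected at trace time.

```python
import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node


class ToyKernel:
  """Hypothetical stand-in for a Kernel: array data plus static metadata."""

  def __init__(self, nngp, is_gaussian):
    self.nngp = nngp
    self.is_gaussian = is_gaussian  # static: stored in the treedef, not traced


register_pytree_node(
    ToyKernel,
    lambda k: ((k.nngp,), k.is_gaussian),               # (children, aux_data)
    lambda aux, children: ToyKernel(children[0], aux),  # rebuild from both
)


def keeps_metadata(k):
  # Updates the array but leaves the static flag untouched.
  return ToyKernel(k.nngp * 0.5 + 1.0, k.is_gaussian)


def flips_metadata(k):
  # Hypothetical layer that marks its output as non-Gaussian.
  return ToyKernel(k.nngp * 0.5 + 1.0, False)


def run_scan(kernel_fn, k, n):
  # Same pattern as `kernel_fn_scan`: iterate `kernel_fn` `n` times.
  k, _ = lax.scan(lambda carry, _: (kernel_fn(carry), None), k,
                  xs=None, length=n)
  return k


k0 = ToyKernel(jnp.ones((3, 3)), True)
out = run_scan(keeps_metadata, k0, 5)  # works: carry structure is unchanged
# run_scan(flips_metadata, k0, 5)      # fails: carry treedef changes
```

This is why a plain Python loop (as in `stax.serial`) tolerates such layers while a scan-based combinator cannot.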
awesome, thanks!