
Excessive memory consumption for deep networks

Open jglaser opened this issue 2 years ago • 2 comments

The LLVM compiler pass uses excessive amounts of memory for deep networks that are constructed like this:

stax.serial(*[my_layer] * depth)

In fact, the compilation may eventually OOM.

The reason is that the serial combinator internally relies on a Python for loop (with carry) to support sequences of heterogeneous layers.

It would be nice to have a specialization for the case in which the same layer is repeated n times, which could then use jax.lax.scan() to save compilation time and memory by avoiding loop unrolling, as illustrated in the sketch below.
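
A minimal standalone illustration of the difference (plain JAX only, with a toy step function standing in for a layer — not the actual stax internals):

import jax
import jax.numpy as jnp
from jax import lax

def step(x):
  return jnp.tanh(x)  # stand-in for one layer's apply_fn

def unrolled(x, depth):
  for _ in range(depth):  # the traced program contains `depth` copies of `step`
    x = step(x)
  return x

def scanned(x, depth):
  # the loop body is traced and compiled once, regardless of `depth`
  x, _ = lax.scan(lambda carry, _: (step(carry), None), x, None, length=depth)
  return x

x = jnp.ones((4,))
assert jnp.allclose(jax.jit(unrolled, static_argnums=1)(x, 100),
                    jax.jit(scanned, static_argnums=1)(x, 100))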

Suggestion:

import jax.example_libraries.stax as ostax
from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
from neural_tangents._src.stax.requirements import get_req, requires, layer
from neural_tangents._src.utils.kernel import Kernel
from jax.lax import scan
import jax.numpy as np

@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
  """Combinator for repeating the same layers `n` times.

  Based on :obj:`jax.example_libraries.stax.serial`.

  Args:
    layer:
      a single layer, an `(init_fn, apply_fn, kernel_fn)` triple.

    n:
      the number of times to repeat the layer.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the composition of `n` layers.
  """
  init_fn, apply_fn, kernel_fn = layer

  init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))
  @requires(**get_req(kernel_fn))
  def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
    # inside kernel functions here and parallel below.
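    # Iterate `kernel_fn` `n` times with `lax.scan`; the carry is the Kernel object.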
    k, _ = scan(lambda carry, x: (kernel_fn(carry, **kwargs), None), k, np.arange(n))
    return k

  return init_fn, apply_fn, kernel_fn_scan

Use it like this:

repeat(my_layer, depth)
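
For context, a hypothetical end-to-end sketch combining the repeat above with public neural_tangents.stax layers (widths and depth are arbitrary; it assumes the block's kernel metadata stays fixed across iterations, cf. the caveat in the reply below):

from jax import random
from neural_tangents import stax

block = stax.serial(stax.Dense(512), stax.Relu())  # the block to be repeated

depth = 64
init_fn, apply_fn, kernel_fn = stax.serial(
    repeat(block, depth),  # scanned composition of `depth` copies of `block`
    stax.Dense(1),
)

x1 = random.normal(random.PRNGKey(1), (3, 8))
x2 = random.normal(random.PRNGKey(2), (4, 8))
ntk = kernel_fn(x1, x2, 'ntk')  # (3, 4) infinite-width NTK matrix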

jglaser commented Nov 06 '22 22:11

Thanks for the suggestion, please check out the newly added neural_tangents.stax.repeat: https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.repeat.html#neural_tangents.stax.repeat

One caveat that makes this less elegant than we'd like is that kernel_fn sometimes makes non-jittable changes to the metadata of the Kernel object, and when this happens lax.scan fails (see especially the second warning in the linked docs), so unfortunately for now it's less flexible than stax.serial.
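
For intuition, that restriction is a general property of lax.scan rather than anything neural-tangents-specific: the carry must keep an identical pytree structure, shape and dtype on every iteration, and the Kernel metadata is part of that carry. A toy illustration (not neural-tangents code):

import jax.numpy as jnp
from jax import lax

def good_step(carry, _):
  return carry * 2.0, None  # same shape/dtype every iteration

ok, _ = lax.scan(good_step, jnp.ones(3), None, length=4)  # works

def bad_step(carry, _):
  return jnp.concatenate([carry, carry]), None  # carry shape changes

# lax.scan(bad_step, jnp.ones(3), None, length=4)  # fails at trace time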

romanngg commented Dec 06 '22 05:12

awesome, thanks!

jglaser commented Dec 06 '22 16:12