
Questions about memory consumption of infinitely wide NTK

Open jasonli0707 opened this issue 1 year ago • 6 comments

I am working on a simple MNIST example. I found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:

import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from jax import random, jit
import jax.numpy as jnp

def FC(depth=1, num_classes=10, W_std=1.0, b_std=0.0):
    layers = [stax.Flatten()]
    for _ in range(depth):
        layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
    layers += [stax.Dense(num_classes, W_std, b_std)]
    return stax.serial(*layers)

x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', data_dir="./data", permute_train=True)

key = random.PRNGKey(0)
init_fn, apply_fn, kernel_fn = FC()
_, params = init_fn(key, (-1, 784))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

batched_kernel_fn = nt.batch(kernel_fn, 1000, store_on_device=False)

k_train_train = batched_kernel_fn(x_train, None, 'ntk')
k_test_train = batched_kernel_fn(x_test, x_train, 'ntk')
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)

I am running this on two RTX 3090s, each with 24Gb of memory. Is there something I'm doing wrong, or is it normal for the NTK to consume this much memory? Thank you!

jasonli0707 avatar Sep 07 '22 07:09 jasonli0707

Thanks for the report, your code looks correct!

This actually looks like two bugs on our side:

  1. The store_on_device argument isn't working, and the kernel is stored on the GPU (I'm assuming you have enough CPU RAM, so you're not running out of that).
  2. Even if store_on_device=True, 24Gb of GPU RAM should be enough for the 50k x 50k kernel, but somehow it's not. I suspect there might be a conflict between JAX and TensorFlow competing for GPU memory; could you try running this version of the code on your machine? https://colab.research.google.com/gist/romanngg/96421af87f4cc1e13a78454d8bfb4ee9/memory_repro.ipynb The part that hopefully helps is
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
import tensorflow_datasets as tfds

(and I'm using tfds instead of neural_tangents.examples)
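The loading in the notebook is roughly as follows (a sketch from memory, not the exact notebook code; the splits and preprocessing in the notebook may differ slightly):

# Rough sketch of the tfds-based loading; the linked notebook is authoritative.
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')  # keep TensorFlow off the GPU
import tensorflow_datasets as tfds
import numpy as np

ds_train, ds_test = tfds.as_numpy(
    tfds.load('mnist', split=['train', 'test'], batch_size=-1, data_dir='./data'))
x_train = ds_train['image'].reshape(-1, 784).astype(np.float32) / 255.
y_train = np.eye(10)[ds_train['label']]  # one-hot targets for the MSE predictor
x_test = ds_test['image'].reshape(-1, 784).astype(np.float32) / 255.
y_test = np.eye(10)[ds_test['label']]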

Another idea is to binary-search over smaller training set sizes to figure out whether we're really hitting the memory limit (e.g. it works for 40K but not 50K), or whether the GPU memory is just not available for some reason (e.g. it doesn't work even for 5K).
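Something like the following (just a sketch; compute_ntk_and_predict stands in for a hypothetical wrapper around the code from your original post, restricted to the first n training points):

# Sketch of the suggested binary search over the training set size,
# assuming ~5K works and 50K fails.
def fits_in_memory(n):
  try:
    compute_ntk_and_predict(x_train[:n], y_train[:n], x_test, y_test)
    return True
  except Exception:  # e.g. XlaRuntimeError: RESOURCE_EXHAUSTED
    return False

lo, hi = 5_000, 50_000
while hi - lo > 1_000:
  mid = (lo + hi) // 2
  if fits_in_memory(mid):
    lo = mid
  else:
    hi = mid
print(f'Works for {lo}, fails for {hi}.')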

Also, could you please post the whole error message?

romanngg avatar Sep 07 '22 19:09 romanngg

Thank you so much for the detailed reply!

I have tried your code but still face the same issue. Below is the complete error message for your reference:

2022-09-08 13:20:36.044808: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.31GiB (rounded to 10000000000) requested by op
2022-09-08 13:20:36.044942: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ***************************************************************************************************_
2022-09-08 13:20:36.045005: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 9.31GiB
    constant allocation: 0B
    maybe_live_out allocation: 9.31GiB
    preallocated temp allocation: 0B
    total allocation: 18.63GiB
    total fragmentation: 0B (0.00%)
Peak buffers:
    Buffer 1:
            Size: 9.31GiB
            Entry Parameter Subshape: s32[50000,50000]

    Buffer 2:
            Size: 9.31GiB
            Operator: op_name="jit(add)/jit(main)/add" source_file="/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py" source_line=1222
            XLA Label: fusion
            Shape: s32[50000,50000]
            

    Buffer 3:
            Size: 4B
            Entry Parameter Subshape: s32[]
            

Traceback (most recent call last):
  File "mnist.py", line 68, in <module>
    fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 270, in predict_fn
    return get_predict_fn_inf()(fx_train_0, fx_test_0, k_test_train)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 163, in get_predict_fn_inf
    solve = _get_cho_solve(k_train_train, diag_reg, diag_reg_absolute_scale)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 1232, in _get_cho_solve
    A = _add_diagonal_regularizer(A, diag_reg, diag_reg_absolute_scale)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 1222, in _add_diagonal_regularizer
    return A + diag_reg * np.eye(dimension)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2103, in eye
    return lax_internal._eye(_jnp_dtype(dtype), (N, M), k)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 1203, in _eye
    bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 444, in add
    return add_p.bind(x, y)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 325, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 328, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 686, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
    return compiled_fun(*args)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 198, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 9.31GiB
    constant allocation: 0B
    maybe_live_out allocation: 9.31GiB
    preallocated temp allocation: 0B
    total allocation: 18.63GiB
    total fragmentation: 0B (0.00%)
Peak buffers:
    Buffer 1:
            Size: 9.31GiB
            Entry Parameter Subshape: s32[50000,50000]

    Buffer 2:
            Size: 9.31GiB
            Operator: op_name="jit(add)/jit(main)/add" source_file="/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py" source_line=1222
            XLA Label: fusion
            Shape: s32[50000,50000]
            

    Buffer 3:
            Size: 4B
            Entry Parameter Subshape: s32[]
            

jasonli0707 avatar Sep 08 '22 05:09 jasonli0707

I have also searched for the maximum number of training samples that works before hitting the memory issue; it turned out to be 36000 in my case:

num_samples = 36000
x_train = x_train[:num_samples]
y_train = y_train[:num_samples]

jasonli0707 avatar Sep 08 '22 05:09 jasonli0707

Oh, thanks for the error message! I realized that what's actually failing is

fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)

and not the kernel computation. Indeed 24Gb is not enough to run the Cholesky solve on the 50k x 50k matrix (the matrix alone is ~10Gb in float32, and the solve needs several buffers of that size), so you'd need to be doing it on CPU.

To make it happen on CPU, I think the easiest way is to add predict_fn = jit(predict_fn, backend='cpu') after you define it (and it's good to jit this function anyway).
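Concretely, that would look something like this (a sketch using the variable names from your snippet):

# Sketch: jit the predictor on the CPU backend so the Cholesky solve of the
# 50k x 50k train-train kernel uses host RAM instead of GPU memory.
from jax import jit

predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
predict_fn = jit(predict_fn, backend='cpu')

fx_train_inf, fx_test_inf = predict_fn(
    fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)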

Alternatively (but hopefully this won't be necessary), you can pin the input tensors to CPU, to make sure that the function called on them is executed on CPU:

fx_train_0 = jax.device_put(fx_train_0, jax.devices('cpu')[0])
fx_test_0 = jax.device_put(fx_test_0, jax.devices('cpu')[0])
k_test_train = jax.device_put(k_test_train, jax.devices('cpu')[0])

and/or

k_train_train = jax.device_put(k_train_train, jax.devices('cpu')[0])
y_train = jax.device_put(y_train, jax.devices('cpu')[0])

before defining predict_fn. In general, you can print x.device_buffer.device() in various places to see which device each tensor x is stored on, and figure out what is running on CPU vs GPU (you want your last line to be executed on CPU).
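For example (a quick sketch; device_buffer is the attribute on the DeviceArray objects in the JAX version from your traceback):

# Sketch: print the device each array lives on, to check that the inputs to
# the final predict_fn call are on CPU.
for name, x in [('k_train_train', k_train_train),
                ('k_test_train', k_test_train),
                ('fx_train_0', fx_train_0),
                ('fx_test_0', fx_test_0)]:
  print(name, x.device_buffer.device())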

romanngg avatar Sep 08 '22 16:09 romanngg

Thank you so much for the detailed follow-up!

As you suggested, I moved everything to the CPU before defining predict_fn and verified that the tensors were indeed stored on the CPU. However, after a few minutes the program is killed by SIGSEGV (address boundary error). Does this mean I'm also out of CPU RAM? If so, is there anything I can do?

jasonli0707 avatar Sep 09 '22 07:09 jasonli0707

How much RAM do you have? Does it work (on CPU, after your modifications) if you use 36k points? I suspect you'd need at least ~64 Gb of RAM, but I only ever tried it on a machine with >128Gb, so I'm not sure what the exact requirement is.

To debug this further, you could run the piece of code from https://github.com/google/neural-tangents/issues/152#issuecomment-1121615513 first with numpy/scipy, and then with jax.numpy and jax.scipy, to get a smaller repro. Then you could post it to https://github.com/google/jax and ask what they think. I also occasionally get these low-level errors when doing level-3 linear algebra on large matrices, and don't know how to debug them myself... (e.g. https://github.com/google/jax/issues/10411, https://github.com/google/jax/issues/10420)
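Something along these lines should do as a starting point (my own sketch, not the exact code from that comment):

# Sketch of a smaller standalone repro: a Cholesky solve on a random SPD
# matrix of the failing size, first with scipy, then with jax.scipy.
# Run with JAX_PLATFORM_NAME=cpu so the JAX part also stays on CPU.
import numpy as np
import scipy.linalg

n = 36_000  # or whatever size fails for you
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n), dtype=np.float32)
A = (A + A.T) / 2
A[np.diag_indices(n)] += n  # diagonally dominant => symmetric positive definite
b = rng.standard_normal((n, 10), dtype=np.float32)

x_scipy = scipy.linalg.cho_solve(scipy.linalg.cho_factor(A), b)
print('scipy solve done')

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

x_jax = cho_solve(cho_factor(jnp.asarray(A)), jnp.asarray(b))
print('jax.scipy solve done')

If the scipy version runs but the jax.scipy one crashes, that narrows it down to JAX/XLA rather than the machine.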

romanngg avatar Sep 09 '22 17:09 romanngg