keras

Memory leak/crash on torch backend

Open · mattdangerw opened this issue 9 months ago · 4 comments

I'm working on a mini-GPT example for @monicadsong. It loads a fairly large (a few GB) tf.data.Dataset. Colab:

https://colab.research.google.com/gist/mattdangerw/4f871c46f3eb5af49f828e2aea3bef79/mini-gpt-from-scatch.ipynb

This works as written on the tf and jax backends, but on the torch backend the GPU runs out of memory (OOM) partway through the first epoch. It looks like a leak or some other nondeterministic behavior, because the crash happens a variable number of steps into training: a few hundred or a few thousand, depending on the run.

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in __call__(self, y_true, y_pred, sample_weight)
    689     def __call__(self, y_true, y_pred, sample_weight=None):
    690         with ops.name_scope(self.name):
--> 691             return self.call(y_true, y_pred, sample_weight)
    692 
    693     def call(self, y_true, y_pred, sample_weight=None):

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in call(self, y_true, y_pred, sample_weight)
    698             _, loss_fn, loss_weight, _ = self._flat_losses[0]
    699             loss_value = ops.cast(
--> 700                 loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
    701             )
    702             if loss_weight is not None:

/usr/local/lib/python3.11/dist-packages/keras/src/losses/loss.py in __call__(self, y_true, y_pred, sample_weight)
     65             )
     66 
---> 67             losses = self.call(y_true, y_pred)
     68             out_mask = backend.get_keras_mask(losses)
     69 

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in call(self, y_true, y_pred)
     31         y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
     32         y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
---> 33         return self.fn(y_true, y_pred, **self._fn_kwargs)
     34 
     35     def get_config(self):

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, ignore_class, axis)
   2244         )
   2245 
-> 2246     res = ops.sparse_categorical_crossentropy(
   2247         y_true,
   2248         y_pred,

/usr/local/lib/python3.11/dist-packages/keras/src/ops/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
   1961             from_logits=from_logits, axis=axis
   1962         ).symbolic_call(target, output)
-> 1963     return backend.nn.sparse_categorical_crossentropy(
   1964         target, output, from_logits=from_logits, axis=axis
   1965     )

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
    705         output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
    706         log_prob = torch.log(output)
--> 707     target = one_hot(target, output.shape[axis], axis=axis)
    708     return -torch.sum(target * log_prob, dim=axis)
    709 

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in one_hot(x, num_classes, axis, dtype, sparse)
    629     # `where` afterwards.
    630     output = tnn.one_hot(maximum(x, 0), num_classes)
--> 631     output = where(expand_dims(x, axis=-1) >= 0, output, zero)
    632     output = convert_to_tensor(output, dtype=dtype)
    633     dims = output.dim()

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/numpy.py in where(condition, x1, x2)
   1529         x1 = convert_to_tensor(x1)
   1530         x2 = convert_to_tensor(x2)
-> 1531         return torch.where(condition, x1, x2)
   1532     else:
   1533         return torch.where(condition)

OutOfMemoryError: CUDA out of memory. Tried to allocate 7.81 GiB. GPU 0 has a total capacity of 39.56 GiB of which 7.34 GiB is free. Process 30879 has 32.21 GiB memory in use. Of the allocated memory 25.56 GiB is allocated by PyTorch, and 6.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
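The error message suggests an allocator setting. A minimal sketch of applying it in a notebook follows. It assumes the backend is selected through the KERAS_BACKEND environment variable in the same process. This setting only targets fragmentation (the "reserved but unallocated" memory above) and would not fix a genuine leak.

```python
import os

# Set both variables before importing keras or torch. Keras picks its
# backend at import time, and PyTorch's CUDA caching allocator reads
# PYTORCH_CUDA_ALLOC_CONF when it is initialized.
os.environ["KERAS_BACKEND"] = "torch"  # assumption: backend chosen via env var
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# import keras  # only import after the environment is configured
```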

mattdangerw · Mar 15 '25 00:03

It's unclear to me whether this is a problem in the loss computation, where the stack trace comes from, or simply a leak in the data iterator when converting from tf.data to torch. In the second case, this line would crash not because it is the source of the leak, but because it is a line that requires a lot of memory.
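A back-of-the-envelope sketch helps explain why this line needs a lot of memory. The one-hot target built inside the loss has shape (batch, sequence length, vocabulary), so its size grows with all three dimensions. The shapes below are hypothetical round numbers, not the Colab's actual settings.

```python
def dense_tensor_gib(batch_size, seq_len, vocab_size, bytes_per_element=4):
    """Return the size in GiB of one dense (batch, seq_len, vocab) tensor."""
    n_elements = batch_size * seq_len * vocab_size
    return n_elements * bytes_per_element / 2**30


# Hypothetical shapes chosen for round numbers; NOT taken from the Colab.
batch_size, seq_len, vocab_size = 32, 1024, 32_768

# A float32 one-hot target (4 bytes per element).
print(dense_tensor_gib(batch_size, seq_len, vocab_size))  # 4.0
# torch.nn.functional.one_hot returns int64 (8 bytes) before any cast.
print(dense_tensor_gib(batch_size, seq_len, vocab_size, 8))  # 8.0
# Halving the batch size halves every tensor of this shape.
print(dense_tensor_gib(batch_size // 2, seq_len, vocab_size))  # 2.0
```

Several tensors of this shape can be alive at the same time in this code path: the one-hot output, the result of where(...), and the product with log_prob. As a result, the peak is a multiple of this figure. That is consistent with the OOM surfacing here even if the extra memory is being held elsewhere.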

mattdangerw · Mar 15 '25 00:03

This might not be related to the memory leak, but TensorFlow can claim the default GPU device when it is loaded. If you need to run TensorFlow and torch at the same time, you might want to hide the GPU from TF so that it doesn't take most of the HBM (JAX does something similar).

e.g. via:

import tensorflow as tf
tf.config.set_visible_devices([], "GPU")

qlzh727 · Mar 17 '25 20:03

Thanks Scott!

I added that line to the beginning of the Colab:

import keras
import pathlib
import tensorflow as tf
tf.config.set_visible_devices([], "GPU")

I still got the same error.

monicadsong · Mar 17 '25 21:03

Hi, the error occurs specifically in the sparse_categorical_crossentropy function on the PyTorch backend. You need to reduce the size of the tensor being allocated at the moment of the crash.
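To see where that large tensor comes from, consider the NumPy sketch below. It mirrors the probability path shown in the traceback (clip, log, one-hot, multiply, sum). It then computes the same loss by indexing the log-probabilities directly with np.take_along_axis, which skips the dense one-hot target and the product. This is an illustration on tiny made-up shapes, not the Keras implementation or a proposed patch. It assumes every label lies in [0, vocab), so it ignores the negative-label masking that the backend's where(...) handles. The gather version still materializes log_prob at full size.

```python
import numpy as np

EPSILON = 1e-7  # Keras' default backend epsilon


def sparse_ce_one_hot(target, output):
    """Mirrors the traceback's path: builds a dense one-hot target."""
    output = np.clip(output, EPSILON, 1.0 - EPSILON)
    log_prob = np.log(output)
    vocab = output.shape[-1]
    # Dense (batch, seq, vocab) tensor: this is the large allocation.
    one_hot = (target[..., None] == np.arange(vocab)).astype(output.dtype)
    return -np.sum(one_hot * log_prob, axis=-1)


def sparse_ce_gather(target, output):
    """Same loss, reading only one log-probability per position."""
    output = np.clip(output, EPSILON, 1.0 - EPSILON)
    log_prob = np.log(output)
    picked = np.take_along_axis(log_prob, target[..., None], axis=-1)
    return -picked[..., 0]


rng = np.random.default_rng(0)
batch, seq_len, vocab = 2, 3, 5  # tiny made-up shapes
logits = rng.normal(size=(batch, seq_len, vocab))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
labels = rng.integers(0, vocab, size=(batch, seq_len))

print(np.allclose(sparse_ce_one_hot(labels, probs), sparse_ce_gather(labels, probs)))  # True
```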

maaddyhere · Dec 06 '25 22:12