
CPU memory leak

Open Extord1108 opened this issue 2 months ago • 4 comments

When training DINOv3 on a server with 500 GB of CPU memory using my own dataset, I noticed that CPU memory usage grows roughly linearly over time, which eventually caused the server to crash.

To investigate the issue, I used the tracemalloc library to trace memory usage within the ExtendedVisionDataset class:

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

from typing import Any, Tuple
import tracemalloc

from torchvision.datasets import VisionDataset
import torch

from .decoders import Decoder, ImageDataDecoder, TargetDecoder


class ExtendedVisionDataset(VisionDataset):
    def __init__(
        self,
        image_decoder: Decoder = ImageDataDecoder,
        target_decoder: Decoder = TargetDecoder,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)  # type: ignore
        self.image_decoder = image_decoder
        self.target_decoder = target_decoder

        # Keep the previous heap snapshot so successive __getitem__ calls
        # can be diffed against each other.
        self._last_snapshot = None
        tracemalloc.start()

    def get_image_data(self, index: int) -> bytes:
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        print("BEFORE===================================BEFORE")
        if self._last_snapshot is None:
            self._last_snapshot = tracemalloc.take_snapshot()
        else:
            current_snapshot = tracemalloc.take_snapshot()
            top_stats = current_snapshot.compare_to(self._last_snapshot, "traceback")
            for stat in top_stats[:5]:
                print(stat)
                for line in stat.traceback.format():
                    print(line)
            self._last_snapshot = current_snapshot
        try:
            image_data = self.get_image_data(index)
            image = self.image_decoder(image_data).decode()
        except Exception as e:
            raise RuntimeError(f"can not read image for sample {index}") from e
        target = self.get_target(index)
        target = self.target_decoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        print("AFTER===================================AFTER")
        if self._last_snapshot is None:
            self._last_snapshot = tracemalloc.take_snapshot()
        else:
            current_snapshot = tracemalloc.take_snapshot()
            top_stats = current_snapshot.compare_to(self._last_snapshot, "traceback")
            for stat in top_stats[:5]:
                print(stat)
                for line in stat.traceback.format():
                    print(line)
            self._last_snapshot = current_snapshot

        return image, target

    def __len__(self) -> int:
        raise NotImplementedError

The results show that memory allocated by torch.utils._pytree.tree_flatten() is not released properly:

BEFORE===================================BEFORE

...

AFTER===================================AFTER
...

/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py:1279: size=308 KiB (+26.3 KiB), count=3083 (+289), average=102 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1279
    def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:

...

This function is called from torchvision.transforms.v2._transform.py, line 41, to flatten the various input formats into a flat list of leaves.
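For context, the v2 transforms roughly follow a flatten/apply/unflatten pattern. Below is a simplified sketch of that pattern using the public pytree helpers, not the actual torchvision code; scale_leaf is a made-up stand-in for a real transform:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def scale_leaf(x):
    # Hypothetical per-leaf operation standing in for a real transform.
    return x * 2 if isinstance(x, torch.Tensor) else x


def apply_like_v2(sample):
    # Flatten an arbitrary (image, target) structure into a list of leaves,
    # transform each leaf, then rebuild the original structure.
    leaves, spec = tree_flatten(sample)
    new_leaves = [scale_leaf(leaf) for leaf in leaves]
    return tree_unflatten(new_leaves, spec)


image, target = apply_like_v2((torch.ones(3, 4, 4), {"label": 1}))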

I switched from the v2 API to v1 in augmentations.py as follows, and it works well.

#from torchvision.transforms import v2
import torchvision.transforms as v2

...

        # normalization: the v1 ToTensor replaces the v2 ToImage + ToDtype pair
        self.normalize = v2.Compose(
            [
                #v2.ToImage(),
                #v2.ToDtype(torch.float32, scale=True),
                v2.ToTensor(),
                make_normalize_transform(mean=mean, std=std),
            ]
        )

...

However, is there any other way to prevent this memory leak?

Extord1108 (Oct 14 '25, 22:10)

Hmm... I also tried tracing memory usage in the training loop:

...

    consecutive_nan_count = 0
    last_snapshot = None
    tracemalloc.start()
    for data in metric_logger.log_every(
        data_loader,
        print_freq=10,
        header="Training",
        n_iterations=max_iter,
        start_iteration=start_iter,
    ):
        # Print the top allocation deltas since the previous iteration.
        if last_snapshot is None:
            last_snapshot = tracemalloc.take_snapshot()
        else:
            current_snapshot = tracemalloc.take_snapshot()
            top_stats = current_snapshot.compare_to(last_snapshot, "traceback")
            for stat in top_stats[:5]:
                print(stat)
                for line in stat.traceback.format():
                    print(line)
            print("===================================")
            last_snapshot = current_snapshot
        it = iteration
        data["global_batch_size"] = global_batch_size

...

And it seems that the memory leak also exists:

===================================
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py:1279: size=464 KiB (+63.2 KiB), count=5408 (+739), average=88 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1279
    def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/utils.py:221: size=115 KiB (+18.7 KiB), count=997 (+157), average=118 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/utils.py", line 221
    def apply(x):
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/autograd/function.py:576: size=120 KiB (+17.3 KiB), count=852 (+126), average=144 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/autograd/function.py", line 576
    return super().apply(*args, **kwargs)  # type: ignore[misc]
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py:240: size=58.9 KiB (+10.1 KiB), count=932 (+162), average=65 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 240
    cast_fn = functools.partial(
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:100: size=65.7 KiB (+9600 B), count=337 (+48), average=200 B
  File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100
    return f(*args)

Extord1108 (Oct 14 '25, 22:10)

Update: the problem is not in pytree itself, since gc.collect() clears the excess data.
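A minimal way to check this (a sketch reusing the same tracemalloc setup as above, just with a forced collection before each snapshot; once gc.collect() runs, the pytree allocations no longer show up as growth):

import gc
import tracemalloc

tracemalloc.start()
last_snapshot = None


def report_growth():
    # Collect cyclic garbage first, so only memory that survives a full GC
    # pass is counted as growth between snapshots.
    global last_snapshot
    gc.collect()
    snapshot = tracemalloc.take_snapshot()
    if last_snapshot is not None:
        for stat in snapshot.compare_to(last_snapshot, "traceback")[:5]:
            print(stat)
    last_snapshot = snapshot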

I will also try reducing num_workers.

Extord1108 (Oct 15 '25, 01:10)

Update: the memory leak does come from torchvision.transforms as this issue describes, but I haven't figured out the root cause. Switching from the v2 API to v1 in both dinov3/data/augmentations.py and dinov3/data/transforms.py relieves the memory leak and makes it manageable on my server.
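For reference, the substitution in dinov3/data/transforms.py follows the same mapping as in augmentations.py. The snippet below is only an illustration of that mapping (the mean/std values are the standard ImageNet ones, and the actual transform lists in the file may differ):

import torchvision.transforms as T

# v1 ToTensor replaces the v2 ToImage() + ToDtype(torch.float32, scale=True) pair.
normalize_v1 = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)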

Extord1108 (Oct 17 '25, 01:10)

I tried:

  1. replacing the v2 API with v1
  2. deleting the gc.disable() call in dinov3/train/train.py, line 435

The memory leak seems to be resolved, but I don't know why it works.
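A guess, not verified: with gc.disable() active, objects that sit in reference cycles (such as the structures created by the pytree flatten/unflatten calls) are never reclaimed, because reference counting alone cannot free cycles, so they pile up over iterations. If the disable call is wanted to avoid GC pauses, a middle-ground sketch is to keep it and collect manually every few iterations; GC_EVERY and the stand-in data_loader below are placeholders:

import gc

GC_EVERY = 100            # hypothetical interval; tune for your setup
data_loader = range(500)  # stand-in for the real DataLoader

gc.disable()  # keep the original call so automatic GC pauses stay off
for iteration, data in enumerate(data_loader):
    # ... forward / backward / optimizer step would go here ...
    if iteration % GC_EVERY == 0:
        # Reclaim cyclic garbage (e.g. objects created by pytree
        # flatten/unflatten) that reference counting alone cannot free.
        gc.collect()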

weihaolan6-oss (Dec 12 '25, 08:12)