CPU memory leak
When training DINOv3 on my own dataset on a server with 500 GB of CPU memory, I noticed unreasonable linear growth in memory usage, which eventually crashed the server.
To investigate, I used Python's built-in tracemalloc module to trace memory usage inside the ExtendedVisionDataset class:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

from typing import Any, Tuple
import tracemalloc

from torchvision.datasets import VisionDataset

import torch

from .decoders import Decoder, ImageDataDecoder, TargetDecoder


class ExtendedVisionDataset(VisionDataset):
    def __init__(
        self,
        image_decoder: Decoder = ImageDataDecoder,
        target_decoder: Decoder = TargetDecoder,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)  # type: ignore
        self.image_decoder = image_decoder
        self.target_decoder = target_decoder
        self._last_snapshot = None
        tracemalloc.start()

    def get_image_data(self, index: int) -> bytes:
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        print("BEFORE===================================BEFORE")
        if self._last_snapshot is None:
            self._last_snapshot = tracemalloc.take_snapshot()
        else:
            current_snapshot = tracemalloc.take_snapshot()
            top_stats = current_snapshot.compare_to(self._last_snapshot, "traceback")
            for stat in top_stats[:5]:
                print(stat)
                for line in stat.traceback.format():
                    print(line)
            self._last_snapshot = current_snapshot

        try:
            image_data = self.get_image_data(index)
            image = self.image_decoder(image_data).decode()
        except Exception as e:
            raise RuntimeError(f"can not read image for sample {index}") from e
        target = self.get_target(index)
        target = self.target_decoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        print("AFTER===================================AFTER")
        if self._last_snapshot is None:
            self._last_snapshot = tracemalloc.take_snapshot()
        else:
            current_snapshot = tracemalloc.take_snapshot()
            top_stats = current_snapshot.compare_to(self._last_snapshot, "traceback")
            for stat in top_stats[:5]:
                print(stat)
                for line in stat.traceback.format():
                    print(line)
            self._last_snapshot = current_snapshot

        return image, target

    def __len__(self) -> int:
        raise NotImplementedError
The results show that memory allocated by torch.utils._pytree.tree_flatten() is not released properly:
BEFORE===================================BEFORE
...
AFTER===================================AFTER
...
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py:1279: size=308 KiB (+26.3 KiB), count=3083 (+289), average=102 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1279
def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
...
This function is called from torchvision/transforms/v2/_transform.py, line 41, to flatten the various input formats into a flat list.
I switched from the v2 API to v1 in augmentations.py as follows, and it works well.
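For reference, here is a minimal standalone sketch (my own, not part of the DINOv3 code; the pipeline below is an assumption meant to mimic the normalization transform) that exercises the same v2 transform path under tracemalloc:

# Hypothetical repro sketch: repeatedly apply a v2 pipeline to a single image
# and attribute any allocation growth with tracemalloc.
import tracemalloc

import torch
from PIL import Image
from torchvision.transforms import v2

pipeline = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
img = Image.new("RGB", (224, 224))

tracemalloc.start()
baseline = tracemalloc.take_snapshot()
for _ in range(1000):
    pipeline(img)
top_stats = tracemalloc.take_snapshot().compare_to(baseline, "traceback")
for stat in top_stats[:5]:
    print(stat)  # growth attributed to torch/utils/_pytree.py should show up here if the leak reproduces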
# from torchvision.transforms import v2
import torchvision.transforms as v2

...

# normalization
self.normalize = v2.Compose(
    [
        # v2.ToImage(),
        # v2.ToDtype(torch.float32, scale=True),
        v2.ToTensor(),
        make_normalize_transform(mean=mean, std=std),
    ]
)
...
However, is there any other way to prevent this memory leak?
Hmm... I also tried tracing the memory usage in the training loop:
...
consecutive_nan_count = 0

last_snapshot = None
tracemalloc.start()
for data in metric_logger.log_every(
    data_loader,
    print_freq=10,
    header="Training",
    n_iterations=max_iter,
    start_iteration=start_iter,
):
    if last_snapshot is None:
        last_snapshot = tracemalloc.take_snapshot()
    else:
        current_snapshot = tracemalloc.take_snapshot()
        top_stats = current_snapshot.compare_to(last_snapshot, "traceback")
        for stat in top_stats[:5]:
            print(stat)
            for line in stat.traceback.format():
                print(line)
        print("===================================")
        last_snapshot = current_snapshot

    it = iteration
    data["global_batch_size"] = global_batch_size
...
And it seems the memory leak shows up there as well:
===================================
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py:1279: size=464 KiB (+63.2 KiB), count=5408 (+739), average=88 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1279
def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/utils.py:221: size=115 KiB (+18.7 KiB), count=997 (+157), average=118 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/utils.py", line 221
def apply(x):
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/autograd/function.py:576: size=120 KiB (+17.3 KiB), count=852 (+126), average=144 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/autograd/function.py", line 576
return super().apply(*args, **kwargs) # type: ignore[misc]
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py:240: size=58.9 KiB (+10.1 KiB), count=932 (+162), average=65 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 240
cast_fn = functools.partial(
/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:100: size=65.7 KiB (+9600 B), count=337 (+48), average=200 B
File "/data/gyf/micromamba/envs/dinov3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100
return f(*args)
Update: the problem is not in pytree itself, since gc.collect() reclaims the excess allocations.
I will try reducing num_workers.
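A quick way to check this (a sketch of the check, assuming tracemalloc.start() has already been called as above) is to diff snapshots around an explicit gc.collect() inside the training loop:

import gc
import tracemalloc

before = tracemalloc.take_snapshot()
gc.collect()  # force a full collection of cyclic garbage
after = tracemalloc.take_snapshot()
for stat in after.compare_to(before, "traceback")[:5]:
    # if the growth is uncollected reference cycles, the _pytree entries
    # should appear here with negative size deltas
    print(stat)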
Update: the memory leak does indeed come from torchvision.transforms, as this issue describes, but I haven't figured out the reason yet.
Replacing the v2 API with v1 in both dinov3/data/augmentations.py and dinov3/data/transforms.py relieves the memory leak and keeps memory usage affordable on my server.
I tried:
- replacing the v2 API with v1
- deleting the gc.disable() call in dinov3/train/train.py, line 435

The memory leak seems to be resolved, but I don't know why it works.
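My current guess is that with gc.disable() in place, the reference cycles created on the transforms/pytree path are never reclaimed by the cyclic collector, so they accumulate across iterations; re-enabling automatic GC (or collecting explicitly) lets them be freed. If gc.disable() was added to avoid GC pauses during the training step, a possible compromise is to keep it but collect every N iterations. A minimal sketch (not the DINOv3 code; the interval and loop are illustrative):

import gc

GC_EVERY_N_ITERS = 100  # illustrative value, not taken from the DINOv3 config

gc.disable()  # keep automatic collection off to avoid mid-iteration GC pauses
for iteration in range(1000):
    # ... training step goes here ...
    if iteration % GC_EVERY_N_ITERS == 0:
        gc.collect()  # periodically reclaim accumulated reference cycles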