Memory leak when calling loss.backward()
Hi.
How can I solve a memory leak that occurs on loss.backward()? My code is fairly complex, which makes it hard to post in full.
The main memory burden comes from the snippet below, right after loss.backward() executes. Memory grows with every iteration, which eventually leads to an OOM error.
loss = bce_extended(logits, y).sum()
with backpack(BatchGrad()):
    if real_sample:
        loss.backward(inputs=list(model.parameters()))
I also tried with disable():, which prevents the memory leak. However, it cannot be combined with with backpack(BatchGrad()): when I want to get the per-sample gradients.
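For reference, the disable() variant I tried looks roughly like this (a sketch with the same names as above; only the context manager changes):

from backpack import disable

loss = bce_extended(logits, y).sum()
with disable():  # BackPACK bookkeeping is turned off here, and the leak disappears
    loss.backward(inputs=list(model.parameters()))
# ...but then weights.grad_batch is never populated, so I lose the per-sample gradients.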
Hi,
thanks for reporting. I understand your full code might be complex to share, but could you provide a minimal example that reproduces the leak? This would be extremely helpful for debugging.
The above code looks fine to me, so it would be interesting to have more details about when the issue occurs.
Best, Felix
I tried hard to construct a minimal example. Sorry that the code is still lengthy and that the learning framework is not easy to follow.
My task is dataset condensation, which turns random noise into informative inputs by updating the noise directly via its gradient. I am constructing a learning objective for a given image by matching (a) the loss gradients between the training dataset and the synthetic input and (b) the loss gradient variance between them, which requires access to the per-sample gradients for the variance computation.
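For context, the variance term is meant to consume BackPACK's per-sample gradients roughly as follows (a sketch, not my actual code; model is the extended network from the script below):

# Sketch: gradient variance from BackPACK's per-sample gradients.
# After loss.backward() inside `with backpack(BatchGrad()):`, each parameter p
# carries p.grad_batch of shape [batch_size, *p.shape].
grad_vars = []
for p in model.parameters():
    g = p.grad_batch.view(p.grad_batch.size(0), -1)   # [N, layer_params]
    grad_vars.append(g.var(dim=0, unbiased=False))    # variance over the batch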
I also found that the memory leak arises at the loss.backward() call inside the get_grads function. Please do not worry too much about the semantics of the code; some parts are pseudo-code. I am really grateful for your help, thanks.
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms

# define the download path
download_root = './MNIST_DATASET'

# Normalize data with mean=0.5, std=1.0
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])
train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)

# define the options
batch_size = 512
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
device = torch.device("cuda:0")


def compute_distance_grads_var(dict_grad_1, dict_grad_2):
    penalty = 0
    penalty += l2_between_lists(dict_grad_1, dict_grad_2)
    return penalty


def l2_between_lists(list_1, list_2):
    assert len(list_1) == len(list_2)
    return (
        torch.cat(tuple([t.view(-1) for t in list_1])) -
        torch.cat(tuple([t.view(-1) for t in list_2]))
    ).pow(2).sum()


def dist(x, y, method='mse'):
    """Distance objectives"""
    if method == 'mse':
        dist_ = (x - y).pow(2).sum()
    elif method == 'l1':
        dist_ = (x - y).abs().sum()
    elif method == 'l1_mean':
        n_b = x.shape[0]
        dist_ = (x - y).abs().reshape(n_b, -1).mean(-1).sum()
    elif method == 'cos':
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)
        dist_ = torch.sum(1 - torch.sum(x * y, dim=-1) /
                          (torch.norm(x, dim=-1) * torch.norm(y, dim=-1) + 1e-6))
    elif method == 'l2_mean':
        dist_ = torch.norm(x - y, 2)
    return dist_


def get_grads(logits, y, model, bce_extended, real_sample):
    loss = bce_extended(logits, y).sum()
    with backpack(BatchGrad()):
        if real_sample:
            loss.backward()
        else:
            loss.backward(create_graph=True)
    grads_mean = []
    dict_grads_batch = []
    for name, weights in model.named_parameters():
        if real_sample:
            grads_mean.append(weights.grad.detach().clone())
            dict_grads_batch.append(weights.grad_batch.detach().clone().view(weights.grad_batch.size(0), -1))
        else:
            grads_mean.append(weights.grad.clone())
            dict_grads_batch.append(weights.grad_batch.clone().view(weights.grad_batch.size(0), -1))
    return grads_mean, dict_grads_batch


for i in range(100):
    model = Sequential(Flatten(), Linear(784, 128), Linear(128, 10))  # I added an additional layer here
    lossfunc = CrossEntropyLoss()
    model = extend(model)
    lossfunc = extend(lossfunc)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for batch_idx, (x, target) in enumerate(train_loader):
        x_syn = torch.rand(x.shape, requires_grad=True, device="cuda:0")
        y_syn = torch.ones_like(target)
        y_syn = y_syn.to(device)
        optimizer_alpha = torch.optim.Adam([x_syn], lr=1e-3)
        x = x.to(device)
        target = target.to(device)
        loss_model = lossfunc(model(x), target)
        grad, grad_batch = get_grads(model(x), target, model, lossfunc, real_sample=True)
        grad_syn, grad_batch_syn = get_grads(model(x_syn), y_syn, model, lossfunc, real_sample=False)
        loss = 0
        for i in range(len(grad)):
            loss += dist(grad[i], grad_syn[i], method='l2_mean')
        loss += compute_distance_grads_var(grad_batch, grad_batch_syn)
        optimizer.zero_grad()
        loss.backward()
        optimizer_alpha.step()
I think this implementation could provide a lot of utility for per-sample gradient computation. One well-known use is https://arxiv.org/abs/2109.02934, which updates the model parameters based on gradient-variance matching by leveraging BackPACK.
The main difference is that my code uses the gradient to learn the synthetic images, rather than using it to update the model parameters.
Adding a zero_grad on the input gradients seems to fix the issue. That is, changing the last few lines of the above script from
optimizer.zero_grad()
loss.backward()
optimizer_alpha.step()
to
optimizer.zero_grad()
loss.backward()
optimizer_alpha.step()
optimizer_alpha.zero_grad() # ---
It should not change the behavior, since optimizer_alpha is re-initialized at each iteration.
Not sure why it's not garbage collected, though.
Hi, I checked your solution and it slightly reduces the memory leak. However, the memory still increases in very small increments as the iterations go on. It seems like some parts are still not garbage collected!
Hi,
just wanted to bring this up because I saw there is a .backward(..., create_graph=True) in your code: there is a memory leak when using full_backward_hooks with create_graph=True in PyTorch (#82528). You could try installing PyTorch with the fix (#82788) to see whether that is what is causing the memory leak.
Hi,
first of all, thanks for your generous replies to my question; I really appreciate the effort you put in. As you suggested, the cause of the memory leak seems to be PyTorch itself.
So I upgraded PyTorch to the preview (nightly) version.
(For reference, the _grad_input_padding function currently used by BackPACK was deprecated in that PyTorch version, so some adjustments were needed to run the code.)
Afterwards, I got a new type of error from the code below:
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms

download_root = './MNIST_DATASET'

# Normalize data with mean=0.5, std=1.0
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])
train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)

# define the options
batch_size = 512
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
device = torch.device("cuda:0")


def compute_distance_grads_var(dict_grad_1, dict_grad_2):
    penalty = 0
    penalty += l2_between_lists(dict_grad_1, dict_grad_2)
    return penalty


def l2_between_lists(list_1, list_2):
    assert len(list_1) == len(list_2)
    return (
        torch.cat(tuple([t.view(-1) for t in list_1])) -
        torch.cat(tuple([t.view(-1) for t in list_2]))
    ).pow(2).sum()


def dist(x, y, method='mse'):
    """Distance objectives"""
    if method == 'mse':
        dist_ = (x - y).pow(2).sum()
    elif method == 'l1':
        dist_ = (x - y).abs().sum()
    elif method == 'l1_mean':
        n_b = x.shape[0]
        dist_ = (x - y).abs().reshape(n_b, -1).mean(-1).sum()
    elif method == 'cos':
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)
        dist_ = torch.sum(1 - torch.sum(x * y, dim=-1) /
                          (torch.norm(x, dim=-1) * torch.norm(y, dim=-1) + 1e-6))
    elif method == 'l2_mean':
        dist_ = torch.norm(x - y, 2)
    return dist_


def get_grads(logits, y, model, bce_extended, real_sample):
    loss = bce_extended(logits, y).sum()
    with backpack(BatchGrad(), debug=True):
        if real_sample:
            # loss.backward(inputs=list(model.parameters()))
            loss.backward(inputs=list(model.parameters()))
        else:
            loss.backward(inputs=list(model.parameters()), create_graph=True)
    grads_mean = []
    dict_grads_batch = []
    for name, weights in model.named_parameters():
        if real_sample:
            grads_mean.append(weights.grad.detach().clone())
            dict_grads_batch.append(weights.grad_batch.detach().clone().view(weights.grad_batch.size(0), -1))
        else:
            grads_mean.append(weights.grad)
            dict_grads_batch.append(weights.grad_batch.view(weights.grad_batch.size(0), -1))
    return grads_mean, dict_grads_batch


for i in range(100):
    model = Sequential(Flatten(), Linear(784, 128), Linear(128, 10))  # I added an additional layer here
    lossfunc = CrossEntropyLoss()
    model = extend(model)
    lossfunc = extend(lossfunc)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for batch_idx, (x, target) in enumerate(train_loader):
        x_syn = torch.rand(x.shape, requires_grad=True, device="cuda:0")
        y_syn = torch.ones_like(target)
        y_syn = y_syn.to(device)
        optimizer_alpha = torch.optim.Adam([x_syn], lr=1e-3)
        x = x.to(device)
        target = target.to(device)
        grad, grad_batch = get_grads(model(x), target, model, lossfunc, real_sample=True)
        grad_syn, grad_batch_syn = get_grads(model(x_syn), y_syn, model, lossfunc, real_sample=False)
        loss = 0
        for i in range(len(grad)):
            loss += dist(grad[i], grad_syn[i], method='l2_mean')
        loss += compute_distance_grads_var(grad_batch, grad_batch_syn)
        loss.backward()
        optimizer_alpha.step()
        optimizer_alpha.zero_grad()
Traceback (most recent call last):
File "condense_example.py", line 125, in
It seems like the error arises because the model backpropagates twice through the graph, which is essential behavior for my task...
The link below shows that wandb, which hooks into the backward function, caused the same error: https://discuss.pytorch.org/t/runtimeerror-module-backward-hook-for-grad-input-is-called-before-the-grad-output-one-this-happens-because-the-gradient-in-your-nn-m-odule-flows-to-the-modules-input-without-passing-through-the-modules-output/119763
Thanks.
Hi again, thanks for the report.
One more (rather miscellaneous) thing you might want to try: you should be able to use retain_graph=True rather than create_graph=True to get the gradients for the synthetic samples, right? Does this still lead to the above exception?
Hi. Thanks for your kind suggestion. Actually, I have already tried that. It leads to the following error:
File "condense_example.py", line 102, in <module>
loss.backward()
File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I conjecture that create_graph=True is required to be able to calculate the second-order derivative, which is crucial for computing the gradients for the synthetic samples...
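A small self-contained example (independent of my script) that shows why:

import torch

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
loss = (w * x).sum() ** 2

# create_graph=True keeps the gradient g itself in the graph so it can be
# differentiated again; with retain_graph=True alone, g has no grad_fn and a
# second backward through it raises "element 0 of tensors does not require grad
# and does not have a grad_fn".
(g,) = torch.autograd.grad(loss, w, create_graph=True)
second = g.pow(2).sum()
second.backward()  # second-order derivative w.r.t. x and w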
You're correct. I missed that the second loss.backward() call differentiates through loss, which contains gradient terms.
I think BackPACK is the only realistic way to calculate the gradient variance of each model parameter, thanks to its fast computation of per-sample gradients.
However, the problem in my code seems to be inherent to PyTorch rather than a problem with BackPACK. I will keep tracking the memory leak to find a solution, but it would be a great help if you could give me some advice.
Keep me posted about the memory leak problem.
You might also want to try out functorch (it should be possible to integrate it into your existing PyTorch code without too much effort), or jax. They can also compute individual gradients that you can use to get the variance.
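A minimal sketch of what the per-sample gradients could look like with functorch (untested, following the functorch per-sample-gradient example; the model, shapes, and names below are placeholders, not your setup):

import torch
import torch.nn.functional as F
from functorch import make_functional, vmap, grad

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
fmodel, params = make_functional(model)

def compute_loss(params, sample, target):
    # treat a single sample as a batch of one
    logits = fmodel(params, sample.unsqueeze(0))
    return F.cross_entropy(logits, target.unsqueeze(0))

x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

# vmap over the batch dimension of (sample, target); params are shared (None)
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, x, y)
# per_sample_grads[i] has shape [32, *params[i].shape]; the variance over dim 0
# gives the gradient variance per parameter.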
Thanks, I appreciate the suggestion. Did you do any kind of runtime comparison between BackPACK and functorch?
I have heard that computing per-sample gradients with functorch is still expensive and takes too much time...
Sadly, I don't have any data yet on how BackPACK compares to functorch in terms of runtime, e.g. for computing individual gradients.
But I think you will be able to port your existing code to functorch with relatively few changes to try whether it is fast enough for your needs.