
TorchSharp memory issue

Open HCareLou opened this issue 1 year ago • 19 comments

m = torch.nn.Conv2d(3, 64, 7, 2, 3, bias=False).cuda()

for i in range(1000000):
    x = torch.randn(1, 3, 224, 224, dtype=torch.float).cuda()
    y = m.forward(x)


var m = TorchSharp.torch.nn.Conv2d(3, 64, 7, 2, 3, bias: false).cuda();

for (int i = 0; i < 1000000; i++)
{
    var x = torch.randn(1, 3, 224, 224).@float().cuda();
    var y = m.forward(x);
}

In PyTorch, GPU memory is released at the appropriate time during GPU inference. In TorchSharp, GPU inference leaks GPU memory, which has to be released manually.

HCareLou avatar Apr 02 '24 06:04 HCareLou

var m = torch.nn.Conv2d(3, 64, 7, 2, 3, bias: false).cuda();
for (int i = 0; i < 1000000; i++)
{
    using var x = torch.randn(1, 3, 224, 224).@float().cuda();
    using var y = m.forward(x);
}

However, this is the only way it can be written.

HCareLou avatar Apr 02 '24 08:04 HCareLou

Perhaps this page could help: https://github.com/dotnet/TorchSharp/wiki/Memory-Management
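The core pattern from that page, applied to the loop above, looks roughly like this (a minimal sketch using only the DisposeScope APIs; see the wiki for the full guidance):

```csharp
using TorchSharp;
using static TorchSharp.torch;

var m = nn.Conv2d(3, 64, 7, 2, 3, bias: false).cuda();
for (int i = 0; i < 1000000; i++)
{
    // Every tensor created inside the scope is disposed when the scope
    // ends, so the loop's temporaries do not accumulate on the GPU.
    using (NewDisposeScope())
    {
        var x = randn(1, 3, 224, 224).@float().cuda();
        var y = m.forward(x);
        // A tensor that must outlive the scope can be moved out explicitly:
        // y.MoveToOuterDisposeScope();
    }
}
```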

yueyinqiu avatar Apr 03 '24 01:04 yueyinqiu

torch.NewDisposeScope() is a relatively elegant solution, although not as elegant as PyTorch.

HCareLou avatar Apr 03 '24 02:04 HCareLou

I'm considering giving up, as unexpected exceptions occur when torch.NewDisposeScope is nested, especially when one function calls another and the called function also uses torch.NewDisposeScope. Objects that shouldn't be disposed are being released. It seems training AI with C# is not very realistic, which is quite frustrating.

HCareLou avatar Apr 03 '24 09:04 HCareLou

I'm sorry to hear that. But DisposeScope is designed to work in that case. Could you describe the issue more specifically, so that we can fix it?

Oh... Since you mentioned that 'objects that shouldn't be disposed of are being released', I guess MoveToOuterDisposeScope could work?

yueyinqiu avatar Apr 03 '24 11:04 yueyinqiu

It might be because f is not disposed. Hope this could help:

using TorchSharp;

for (long i = 0; i < 10_000_000_000; i++)
{
    using (torch.NewDisposeScope())
    {
        var f = torch.randn(1, 3, 224, 224).@float().cuda();
        using (torch.NewDisposeScope())
        {
            var f3 = torch.randn(1, 3, 224, 224).@float().cuda();
            f[..] = f3;
        }
    }
}
Console.ReadKey();

By the way, are you a Chinese user? I have just created a QQ group (957204993), so perhaps we could discuss this there with instant messages, which might be more convenient.

yueyinqiu avatar Apr 03 '24 15:04 yueyinqiu

public static class Ops
{
    public static Tensor clip_boxes(Tensor boxes, int[] shape)
    {
        using (torch.NewDisposeScope())
        {
            boxes[TensorIndex.Ellipsis, 0] = boxes[TensorIndex.Ellipsis, 0].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 1] = boxes[TensorIndex.Ellipsis, 1].clamp(0, shape[0]);
            boxes[TensorIndex.Ellipsis, 2] = boxes[TensorIndex.Ellipsis, 2].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 3] = boxes[TensorIndex.Ellipsis, 3].clamp(0, shape[0]);
            return boxes.MoveToOuterDisposeScope();
        }
    }

    public static Tensor scale_boxes(int[] img1_shape, Tensor boxes, int[] img0_shape, (int, int)[] ratio_pad = null!, bool padding = true, bool xywh = false)
    {
        using (torch.NewDisposeScope())
        {
            double gain;
            (double, double) pad;
            if (ratio_pad == null)
            {
                gain = Math.Min(img1_shape[0] * 1.0 / img0_shape[0], img1_shape[1] * 1.0 / img0_shape[1]);
                pad = (
                    Math.Round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
                    Math.Round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
                );
            }
            else
            {
                gain = ratio_pad[0].Item1;
                pad = ratio_pad[1];
            }

            if (padding)
            {
                boxes[TensorIndex.Ellipsis, 0] -= pad.Item1;
                boxes[TensorIndex.Ellipsis, 1] -= pad.Item2;
                if (!xywh)
                {
                    boxes[TensorIndex.Ellipsis, 2] -= pad.Item1;
                    boxes[TensorIndex.Ellipsis, 3] -= pad.Item2;
                }
            }

            boxes[TensorIndex.Ellipsis, ..4] /= gain;
            return clip_boxes(boxes, img0_shape).MoveToOuterDisposeScope();
        }
    }
}
public abstract class OutputData : IDisposable
{
    public abstract void Dispose();

    public abstract List<dynamic> ToList();
    public abstract OutputData MoveToOuterDisposeScope();
    ~OutputData()
    {
        Dispose();
    }
}
public class DetectPredictData : OutputData
{
    public Tensor Y { get; set; }
    public List<Tensor> X { get; set; }

    public override void Dispose()
    {
        Y?.Dispose();
        X?.ForEach(x => x.Dispose());
    }

    public override OutputData MoveToOuterDisposeScope()
    {
        Y?.MoveToOuterDisposeScope();
        X.ForEach(x => x.MoveToOuterDisposeScope());
        return this;
    }

    public override List<dynamic> ToList()
    {
        return [Y, X];
    }
}
for (int i = 0; i < 10000000; i++)
{
    using (torch.NewDisposeScope())
    {
        OutputData data = new DetectPredictData()
        {
            Y = torch.randn(3, 84, 8400).@float().cuda(),
            X = [
                torch.randn(3, 144, 80, 80).@float().cuda(),
                torch.randn(3, 288, 40, 40).@float().cuda(),
                torch.randn(3, 576, 20, 20).@float().cuda(),
            ]
        };
        var pre = data.ToList();
        using (torch.NewDisposeScope())
        {
            Tensor f = pre[0];
            int[]? f1 = [1080, 2];
            var boxes = Ops.scale_boxes([564, 640], f[.., ..4], f1);
        }
        data.Dispose();
    }
}

After a day's work, I've located the source of the memory leak. Now I'm trying to reproduce it with the simulated code above (with about an 80% reproduction rate; there are some unexplainable phenomena, so the remaining 20% could not be reproduced. This means the fix for the simulated leak does not apply to my project, and conversely the fix for my project's leak does not apply to the simulation. It's quite awkward!!!)

HCareLou avatar Apr 04 '24 03:04 HCareLou

I suppose that Ops.clip_boxes and Ops.scale_boxes should not invoke MoveToOuterDisposeScope().

That's because boxes (the f[.., ..4] slice) is created in the caller's dispose scope, not in the scope opened inside those functions. The first MoveToOuterDisposeScope moves it one scope further out, and after the second call it no longer belongs to any dispose scope at all. Then it leaks.

(Only the tensors/parameters that are created in a dispose scope are automatically attached to it, and in-place actions do not modify a tensor's dispose scope.)
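The two moves can be sketched in isolation like this (a hypothetical CPU-only example, not code from the project above):

```csharp
using TorchSharp;
using static TorchSharp.torch;

Tensor t;
using (NewDisposeScope())                // outer scope
{
    using (NewDisposeScope())            // inner scope: t is created and attached here
    {
        t = zeros(10);
        t.MoveToOuterDisposeScope();     // t now belongs to the outer scope
        t.MoveToOuterDisposeScope();     // t leaves the outer scope as well
    }
}                                        // neither scope disposes t any more
t.Dispose();                             // without this manual call, t leaks
```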

yueyinqiu avatar Apr 04 '24 04:04 yueyinqiu

Yes, removing MoveToOuterDisposeScope() resolves the issue in the simulated code, but for the code where the actual memory leak occurs, I cannot handle it that way. I have to modify the code as follows, which is very confusing to me.

    public static Tensor clip_boxes(Tensor boxes, int[] shape)
    {
        using (torch.NewDisposeScope())
        {
            boxes[TensorIndex.Ellipsis, 0] = boxes[TensorIndex.Ellipsis, 0].clone().clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 1] = boxes[TensorIndex.Ellipsis, 1].clone().clamp(0, shape[0]);
            boxes[TensorIndex.Ellipsis, 2] = boxes[TensorIndex.Ellipsis, 2].clone().clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 3] = boxes[TensorIndex.Ellipsis, 3].clone().clamp(0, shape[0]);
            return boxes.MoveToOuterDisposeScope();
        }
    }

Just adding .clone() fixes it, which is very bizarre.

HCareLou avatar Apr 04 '24 04:04 HCareLou

You could just remove that:

    public static Tensor clip_boxes(Tensor boxes, int[] shape)
    {
        using (torch.NewDisposeScope())
        {
            boxes[TensorIndex.Ellipsis, 0] = boxes[TensorIndex.Ellipsis, 0].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 1] = boxes[TensorIndex.Ellipsis, 1].clamp(0, shape[0]);
            boxes[TensorIndex.Ellipsis, 2] = boxes[TensorIndex.Ellipsis, 2].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 3] = boxes[TensorIndex.Ellipsis, 3].clamp(0, shape[0]);
            return boxes;
        }
    }

I suppose there is no problem with this. Is there anything else you're worried about?

yueyinqiu avatar Apr 04 '24 04:04 yueyinqiu

No no no, your code will cause a memory leak in my project, but adding .clone() fixes it. However, the simulated code still leaks memory even with .clone() added. Please trust me, there is still an issue with torch.NewDisposeScope().

HCareLou avatar Apr 04 '24 04:04 HCareLou

Actually your clone cannot solve this problem. You could use it on the return instead:

    public static Tensor clip_boxes(Tensor boxes, int[] shape)
    {
        using (torch.NewDisposeScope())
        {
            boxes[TensorIndex.Ellipsis, 0] = boxes[TensorIndex.Ellipsis, 0].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 1] = boxes[TensorIndex.Ellipsis, 1].clamp(0, shape[0]);
            boxes[TensorIndex.Ellipsis, 2] = boxes[TensorIndex.Ellipsis, 2].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 3] = boxes[TensorIndex.Ellipsis, 3].clamp(0, shape[0]);
            return boxes.clone().MoveToOuterDisposeScope();
        }
    }

Or:

    public static Tensor clip_boxes(Tensor boxes, int[] shape)
    {
        using (torch.NewDisposeScope())
        {
            boxes = boxes.clone();
            boxes[TensorIndex.Ellipsis, 0] = boxes[TensorIndex.Ellipsis, 0].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 1] = boxes[TensorIndex.Ellipsis, 1].clamp(0, shape[0]);
            boxes[TensorIndex.Ellipsis, 2] = boxes[TensorIndex.Ellipsis, 2].clamp(0, shape[1]);
            boxes[TensorIndex.Ellipsis, 3] = boxes[TensorIndex.Ellipsis, 3].clamp(0, shape[0]);
            return boxes.MoveToOuterDisposeScope();
        }
    }

But I suppose there is no reason to use clone and keep MoveToOuterDisposeScope. Could there be anything else in your project that causes a memory leak?

yueyinqiu avatar Apr 04 '24 05:04 yueyinqiu

By the way, you could track the tensor's dispose scope while debugging. Hope this could help.

yueyinqiu avatar Apr 04 '24 05:04 yueyinqiu

Please watch the video: https://github.com/dotnet/TorchSharp/assets/55724885/20f3a37a-1267-4ebd-a2e8-81365a122a8a

HCareLou avatar Apr 04 '24 05:04 HCareLou

Hmm I'm really not sure about that. Is it possible to share the whole project with me?

yueyinqiu avatar Apr 04 '24 05:04 yueyinqiu

Sorry about that, it’s not convenient at the moment.

HCareLou avatar Apr 04 '24 06:04 HCareLou

My only guess is that because of the higher memory usage (CPU memory, not GPU memory), the garbage collector is activated and thus the escaped tensors are released?
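That guess would also fit the DetectPredictData class posted above: its finalizer calls Dispose(), so when rising memory pressure triggers a garbage collection, any collected instance disposes its tensors at an unpredictable moment, even tensors that were moved to an outer scope. A reduced sketch of that pattern (the Holder type here is hypothetical, for illustration only):

```csharp
using TorchSharp;
using static TorchSharp.torch;

class Holder : IDisposable
{
    public Tensor T = zeros(10);

    public void Dispose() => T?.Dispose();

    // The finalizer runs whenever the GC collects an unreachable Holder,
    // so T can be disposed long after the code that extracted it expects,
    // or never, if the GC does not run at all.
    ~Holder() { Dispose(); }
}
```

The standard IDisposable pattern calls GC.SuppressFinalize(this) inside Dispose(), so an explicitly disposed object never runs its finalizer; that would at least make the release timing deterministic.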

yueyinqiu avatar Apr 04 '24 06:04 yueyinqiu

Not too sure, haven't found the exact cause yet, it's a bit odd.

HCareLou avatar Apr 04 '24 06:04 HCareLou

public static Tensor clip_boxes(Tensor boxes, int[] shape)
{
    using (torch.NewDisposeScope())
    {
        boxes[TensorIndex.Ellipsis, 0].clamp_(0, shape[1]);
        boxes[TensorIndex.Ellipsis, 1].clamp_(0, shape[0]);
        boxes[TensorIndex.Ellipsis, 2].clamp_(0, shape[1]);
        boxes[TensorIndex.Ellipsis, 3].clamp_(0, shape[0]);
        return boxes;
    }
}

This can also solve the memory leak problem.

HCareLou avatar Apr 04 '24 14:04 HCareLou