RobustVideoMatting

C++ sample code available?

Open ewayboy opened this issue 3 years ago • 8 comments

Does anyone have C++ code to run the demo?

ewayboy · Sep 17 '21

I don't have c++ sample code, but I will keep this issue open for others to answer.

PeterL1n · Sep 17 '21

Here is the Python inference code ported to C++ (LibTorch):

    #include <torch/script.h>

    auto device = torch::Device("cuda");
    auto precision = torch::kFloat16;
    auto downsampleRatio = 0.4;
    //! Recurrent states: empty on the first frame, then fed back on every frame.
    c10::optional<torch::Tensor> tensorRec0;
    c10::optional<torch::Tensor> tensorRec1;
    c10::optional<torch::Tensor> tensorRec2;
    c10::optional<torch::Tensor> tensorRec3;

    auto model = torch::jit::load("rvm_mobilenetv3_fp16.torchscript");
    model.eval();
    //! freeze error. (Note: torch::jit::freeze only accepts modules in eval mode.)
    //model = torch::jit::freeze(model);
    model.to(device);

    //! imgSrc: RGB image data, such as QImage (assumes Format_RGB888 with no row padding).
    auto tensorSrc = torch::from_blob(imgSrc.bits(), { imgSrc.height(), imgSrc.width(), 3 }, torch::kByte);
    tensorSrc = tensorSrc.to(device);                        // HWC uint8 on the GPU
    tensorSrc = tensorSrc.permute({ 2, 0, 1 }).contiguous(); // HWC -> CHW
    tensorSrc = tensorSrc.to(precision).div(255);            // [0, 255] -> [0, 1]
    tensorSrc.unsqueeze_(0);                                 // add batch dim: 1 x 3 x H x W

    //! Inference: the model returns [fgr, pha, rec0, rec1, rec2, rec3].
    auto outputs = model.forward({ tensorSrc, tensorRec0, tensorRec1, tensorRec2, tensorRec3, downsampleRatio }).toList();

    auto fgr = outputs.get(0).toTensor();
    auto pha = outputs.get(1).toTensor();
    tensorRec0 = outputs.get(2).toTensor();
    tensorRec1 = outputs.get(3).toTensor();
    tensorRec2 = outputs.get(4).toTensor();
    tensorRec3 = outputs.get(5).toTensor();

    //! Green target background (RGB order, values in [0, 1])
    auto tensorTargetBgr = torch::tensor({ 120.f / 255, 255.f / 255, 155.f / 255 }).toType(precision).to(device).view({ 1, 3, 1, 1 });
    //! Composite: pha * fgr + (1 - pha) * background
    auto res_tensor = pha * fgr + (1 - pha) * tensorTargetBgr;

    //! Back to an H x W x 3 uint8 image on the CPU
    res_tensor = res_tensor.mul(255).permute({ 0, 2, 3, 1 })[0].to(torch::kU8).contiguous().cpu();
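For anyone checking the tensor steps in the snippet above, the same per-pixel arithmetic can be sketched in plain C++ without LibTorch. This is an illustrative sketch, not RVM or LibTorch API: packToChw and compositeOnto are hypothetical helper names, and the bytesPerLine parameter reflects the assumption that a source image (for example a QImage) may pad each row, which torch::from_blob with shape { height, width, 3 } does not account for.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: copy interleaved RGB8 rows (possibly padded, so
// bytesPerLine >= width * 3) into a tightly packed planar CHW float buffer
// in [0, 1]. Mirrors from_blob -> permute({2, 0, 1}) -> div(255), but
// honours the row stride.
std::vector<float> packToChw(const std::uint8_t* bits, int width, int height,
                             int bytesPerLine) {
    const std::size_t plane = static_cast<std::size_t>(width) * height;
    std::vector<float> chw(3 * plane);
    for (int y = 0; y < height; ++y) {
        const std::uint8_t* row = bits + static_cast<std::size_t>(y) * bytesPerLine;
        for (int x = 0; x < width; ++x) {
            for (int c = 0; c < 3; ++c) {
                chw[c * plane + static_cast<std::size_t>(y) * width + x] =
                    row[x * 3 + c] / 255.0f;
            }
        }
    }
    return chw;
}

// Hypothetical helper: out = pha * fgr + (1 - pha) * bg per channel, then
// scaled to 8 bits (rounded to nearest here). fgr is planar CHW in [0, 1],
// pha is one H x W plane in [0, 1], bg holds three values in [0, 1].
// The result is interleaved HWC, like res_tensor after permute({0, 2, 3, 1}).
std::vector<std::uint8_t> compositeOnto(const std::vector<float>& fgr,
                                        const std::vector<float>& pha,
                                        const float bg[3], int width, int height) {
    const std::size_t plane = static_cast<std::size_t>(width) * height;
    std::vector<std::uint8_t> out(3 * plane);
    for (std::size_t i = 0; i < plane; ++i) {
        for (int c = 0; c < 3; ++c) {
            const float v = pha[i] * fgr[c * plane + i] + (1.0f - pha[i]) * bg[c];
            const long r = std::lround(v * 255.0f);
            out[i * 3 + c] = static_cast<std::uint8_t>(std::max(0L, std::min(255L, r)));
        }
    }
    return out;
}
```

With pha = 1 a pixel keeps its foreground colour; with pha = 0 it becomes the green background (120, 255, 155), which is exactly what the "Compound" line computes on the GPU.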

BrightenWu · Sep 18 '21


@ewayboy @PeterL1n

C++ Demos for RobustVideoMatting:

rvm2021

DefTruth · Sep 20 '21


Running @BrightenWu's code above on continuous video makes memory usage grow until the program crashes. Is there a better way to handle this? Thanks a lot.

semchan · Nov 17 '21

Do you need to release the previous res_tensor? It is copied back to host memory on every frame.

ewayboy · Nov 18 '21

But I found it is not caused by "res_tensor"; it may be caused by "tensorRec0", "tensorRec1", and so on. When tensorRec0 to tensorRec3 are kept as global variables and updated like this, memory usage explodes:

    tensorRec0 = outputs.get(2).toTensor();
    tensorRec1 = outputs.get(3).toTensor();
    tensorRec2 = outputs.get(4).toTensor();
    tensorRec3 = outputs.get(5).toTensor();

semchan · Nov 18 '21

@semchan I have the same problem. Did you solve it?

HZNUJeffreyRen · Jan 04 '22

Replying to @semchan's comment about the tensorRec states:

Detaching the tensors seems to have solved the problem:

    tensorRec0 = outputs.get(2).toTensor().detach();
    tensorRec1 = outputs.get(3).toTensor().detach();
    tensorRec2 = outputs.get(4).toTensor().detach();
    tensorRec3 = outputs.get(5).toTensor().detach();
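A likely explanation for why detaching helps, assuming the loaded module's parameters still have requires_grad set (the default for saved module parameters): each output then keeps a reference to the autograd graph that produced it, and that graph references the inputs, including the previous frame's recurrent state. Feeding tensorRec0 to tensorRec3 back in without detaching chains every frame's graph together, so nothing is ever freed. The toy below is a plain C++ analogy, not LibTorch code; ToyTensor, forwardStep, detachState and liveChain are made-up names. In LibTorch itself, besides .detach(), the usual option is to run the forward call under a torch::NoGradGuard (or c10::InferenceMode) guard, mirroring the torch.no_grad() block used in the project's Python inference example.

```cpp
#include <cstddef>
#include <memory>

// Toy stand-in for a tensor that remembers how it was produced.
// `parent` plays the role of the autograd edge back to an input.
struct ToyTensor {
    std::shared_ptr<const ToyTensor> parent;
};

using Handle = std::shared_ptr<const ToyTensor>;

// One "frame": the new recurrent state is computed from the previous one.
// When history is recorded, the output keeps its input alive via `parent`.
Handle forwardStep(const Handle& rec, bool recordHistory) {
    auto out = std::make_shared<ToyTensor>();
    if (recordHistory) out->parent = rec;
    return out;
}

// Analogue of .detach(): same data (omitted in this toy), no link back
// to whatever produced it.
Handle detachState(const Handle&) { return std::make_shared<ToyTensor>(); }

// Number of states still reachable from `h`, i.e. still held in memory.
std::size_t liveChain(const Handle& h) {
    std::size_t n = 0;
    for (const ToyTensor* p = h.get(); p != nullptr; p = p->parent.get()) ++n;
    return n;
}
```

In this analogy, storing forwardStep(rec, true) directly keeps all previous frames alive (the chain length equals the frame count), detaching keeps only the current state, and forwardStep(rec, false) corresponds to a no-grad guard: no history is recorded in the first place, so there is nothing to detach.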

ked19 · Nov 23 '22