
The refine_foreground function takes a long time to execute; how can it be optimized?

Open • xiao-keeplearning opened this issue 7 months ago • 17 comments

When executing the demo, I found that refine_foreground takes more than 4 seconds. Is there any way to speed up this process?

xiao-keeplearning, Jun 03 '25 03:06

Really? That step takes me almost no time. The algorithm I used there is fast-foreground-estimation, which should be fast. Could you check your device?

ZhengPeng7, Jun 03 '25 03:06

It really is that slow. My machine's CPU is an Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz. I noticed that CPU usage stays low while this code runs.

xiao-keeplearning, Jun 03 '25 11:06

Aha, that's a pity... You could try it on Colab, which gives us the same environment and makes problems easier to debug.

ZhengPeng7, Jun 03 '25 11:06

I ran BiRefNet_inference.ipynb. The refine_foreground timings are as follows:

Processing ../images_todo/Helicopter-HR.jpg ... refine_time= 2.82s
Processing ../images_todo/1.jpg ... refine_time= 3.14s

This is too slow. Compared with the 10 ms model inference, refine_foreground is the performance bottleneck of the entire pipeline.

xiao-keeplearning, Jun 04 '25 03:06

Thanks for the feedback! I'll figure it out on my local machine and come back to you today.

ZhengPeng7, Jun 04 '25 03:06

Yes, I also tested it myself. It's very slow, several times longer than the model inference itself... I'll try to find a way to accelerate it (not guaranteed :)).

ZhengPeng7, Jun 04 '25 04:06

Thanks for the prompt reply. If I make any progress, I'll share it here as well.

xiao-keeplearning, Jun 04 '25 10:06

Hi, I tried several approaches, but most of them failed. Some improvements have still been pushed: https://github.com/ZhengPeng7/BiRefNet/blob/81c95f5390ec7e2535628972bcec3d2de88fa7ed/image_proc.py#L10 Manually setting the dtype of the ndarrays to np.float32 speeds it up by ~20% in my tests without degrading the results. It's still not enough... If you have better methods, please give me a clue :)

ZhengPeng7, Jun 05 '25 10:06
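A plausible reason for that gain: dividing a uint8 array by `255.0` produces float64, so every subsequent `cv2.blur` call processes 8-byte floats. A minimal sketch of the cast (the helper name `to_unit_float32` is hypothetical; the linked commit may do it differently):

```python
import numpy as np

def to_unit_float32(image_u8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map a uint8 image to float32 values in [0, 1]."""
    # Cast first so the division stays in float32; `image_u8 / 255.0` would give float64.
    return image_u8.astype(np.float32) / np.float32(255.0)

rgb = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)
print((rgb / 255.0).dtype)         # float64
print(to_unit_float32(rgb).dtype)  # float32
```

Halving the bytes per element roughly halves the memory traffic of each blur, which is consistent with a speed-up of that order, though the exact figure will depend on the machine.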

@xiao-keeplearning You can try using native torch functions instead of OpenCV functions. Together with GPU acceleration, refine_foreground then takes maybe ~100 ms, depending on your GPU. My code is below:

```python
import torch


def mean_blur(x, kernel_size):
    """
    Mean (box) blur, approximately equivalent to cv2.blur.
    Note: padding is 'replicate', while cv2.blur defaults to BORDER_REFLECT_101,
    so pixels near the image border can differ slightly.
    x:  [B, C, H, W]
    """
    if kernel_size % 2 == 0:
        pad_l = kernel_size // 2 - 1
        pad_r = kernel_size // 2
        pad_t = kernel_size // 2 - 1
        pad_b = kernel_size // 2
    else:
        pad_l = pad_r = pad_t = pad_b = kernel_size // 2

    x_padded = torch.nn.functional.pad(x, (pad_l, pad_r, pad_t, pad_b), mode='replicate')

    return torch.nn.functional.avg_pool2d(
        x_padded,
        kernel_size=(kernel_size, kernel_size),
        stride=1,
        count_include_pad=False
    )


def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):

    blurred_alpha = mean_blur(alpha, kernel_size=r)

    blurred_FA = mean_blur(F * alpha, kernel_size=r)
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)

    F_output = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F_output = torch.clamp(F_output, 0, 1)

    return F_output, blurred_B

def refine_foreground(image, mask, r=90):
    """image: [B, 3, H, W], mask: [B, 1, H, W], both in range [0, 1]"""
    F, B = FB_blur_fusion_foreground_estimator(image, image, image, mask, r=r)
    estimated_foreground, _ = FB_blur_fusion_foreground_estimator(image, F, B, mask, r=6)
    return estimated_foreground
```

lucasgblu, Jun 18 '25 07:06
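For reference, each call of `FB_blur_fusion_foreground_estimator` above computes the following. The symbols are introduced here for readability and are not from the original comment: $b_r$ is the $r \times r$ mean blur (`mean_blur`), $I$ the image, $\alpha$ the mask and $\epsilon = 10^{-5}$.

$$
\begin{aligned}
\bar{\alpha} &= b_r(\alpha), \qquad
\hat{F} = \frac{b_r(F\alpha)}{\bar{\alpha} + \epsilon}, \qquad
\hat{B} = \frac{b_r\big(B(1-\alpha)\big)}{1 - \bar{\alpha} + \epsilon}, \\
F_{\mathrm{out}} &= \operatorname{clip}_{[0,1]}\Big(\hat{F} + \alpha\big(I - \alpha\hat{F} - (1-\alpha)\hat{B}\big)\Big).
\end{aligned}
$$

`refine_foreground` runs this twice: first with $F = B = I$ and $r = 90$, then with $F = F_{\mathrm{out}}$ and $B = \hat{B}$ from the first pass and $r = 6$. Because the mean blur is linear, $1 - \bar{\alpha} = b_r(1-\alpha)$, which is why the code can reuse `blurred_alpha` for the background denominator.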

Oh, thanks for the code! I'll check it and integrate it if it works.

ZhengPeng7, Jun 18 '25 13:06

@ZhengPeng7 Hi, I wanted to follow up on the previous code. My apologies: I discovered a minor precision issue in certain edge cases. With lower-precision inputs such as torch.bfloat16, the divisions inside FB_blur_fusion_foreground_estimator can become numerically unstable (bfloat16 keeps float32's exponent range but has only about 3 significant decimal digits, so the small denominators lose precision), resulting in black artifacts in the fusion output (as shown below).

[Image: black artifacts in the fusion output with bfloat16 inputs]

The fixed version stabilizes the calculation by explicitly casting the inputs to float32 and casting the outputs back to the input dtype. Revised code:

```python
import torch


def mean_blur(x, kernel_size):
    """
    Mean (box) blur, approximately equivalent to cv2.blur.
    Note: padding is 'replicate', while cv2.blur defaults to BORDER_REFLECT_101,
    so pixels near the image border can differ slightly.
    x:  [B, C, H, W]
    """
    if kernel_size % 2 == 0:
        pad_l = kernel_size // 2 - 1
        pad_r = kernel_size // 2
        pad_t = kernel_size // 2 - 1
        pad_b = kernel_size // 2
    else:
        pad_l = pad_r = pad_t = pad_b = kernel_size // 2

    x_padded = torch.nn.functional.pad(x, (pad_l, pad_r, pad_t, pad_b), mode='replicate')

    return torch.nn.functional.avg_pool2d(
        x_padded,
        kernel_size=(kernel_size, kernel_size),
        stride=1,
        count_include_pad=False
    )

def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    as_dtype = lambda x, dtype: x.to(dtype) if x.dtype != dtype else x

    input_dtype = image.dtype
    # compute in float32: low-precision dtypes (e.g. bfloat16) produce artifacts in the divisions below
    image = as_dtype(image, torch.float32)
    F = as_dtype(F, torch.float32)
    B = as_dtype(B, torch.float32)
    alpha = as_dtype(alpha, torch.float32)

    blurred_alpha = mean_blur(alpha, kernel_size=r)

    blurred_FA = mean_blur(F * alpha, kernel_size=r)
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)

    F_output = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F_output = torch.clamp(F_output, 0, 1)

    return as_dtype(F_output, input_dtype), as_dtype(blurred_B, input_dtype)

def refine_foreground(image, mask, r=90):
    """image: [B, 3, H, W], mask: [B, 1, H, W], both in range [0, 1]"""
    F, B = FB_blur_fusion_foreground_estimator(image, image, image, mask, r=r)
    estimated_foreground, _ = FB_blur_fusion_foreground_estimator(image, F, B, mask, r=6)
    return estimated_foreground
```

Hope this helps improve robustness! Happy to discuss if you encounter any related issues.

lucasgblu, Jun 20 '25 05:06
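A small illustration of why the float32 cast helps, assuming PyTorch is available (the value 0.999 is chosen for illustration, not taken from the artifact image). bfloat16 has only 8 bits of significand, so a blurred alpha of 0.999 rounds to exactly 1.0, and the background denominator `(1 - blurred_alpha) + 1e-5` collapses to about `1e-5`:

```python
import torch

def background_denominator(blurred_alpha: float, dtype: torch.dtype) -> float:
    """Denominator used for blurred_B, evaluated entirely in the given dtype."""
    a = torch.tensor(blurred_alpha, dtype=dtype)
    return ((1 - a) + 1e-5).float().item()

# In bfloat16 the neighbours of 0.999 are 0.99609375 and 1.0, so it rounds to 1.0.
print(torch.tensor(0.999, dtype=torch.bfloat16).item())  # 1.0
d32 = background_denominator(0.999, torch.float32)   # about 1.01e-3
d16 = background_denominator(0.999, torch.bfloat16)  # about 1e-5
print(d32 / d16)  # roughly 100: blurred_B comes out ~100x too large in bfloat16
```

An inflated `blurred_B` then dominates `image - alpha * blurred_F - (1 - alpha) * blurred_B` wherever `1 - alpha` is small but nonzero, and the clamp turns those pixels black, which matches the artifacts described above.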

Wow, that's great! Thanks a lot. I've been too busy to finish checking the rest, but I'll definitely do it in the next few days. Yes, that's useful: I also validated bf16 training of BiRefNet in recent versions, and it worked very well, so this improvement is very valuable.

ZhengPeng7, Jun 20 '25 08:06

Hi @lucasgblu, I tried to integrate your implementation with GPU acceleration, but it seems very slow, and the refinement is not as good (it has some effect, but weaker than the original version).

I made a Colab notebook for you to check. As you can see, the results and timings are not what you described above. Maybe I made a mistake? Could you help me check it? That would be a great help. Thanks a lot!

Here are the results without refinement / with the GPU-version refinement / with the original refinement (they are hard to inspect on a white background; right-click the PNGs to open them in new tabs with a black background):

[Image: no refinement]

[Image: GPU-version refinement]

[Image: original refinement]

ZhengPeng7, Jun 23 '25 09:06

@ZhengPeng7 No problem. Maybe this weekend, since I have to work during the week. I'll check your code.

lucasgblu, Jun 24 '25 02:06

@ZhengPeng7 Sorry for the late reply. After testing your code, I did find some issues. I'll briefly list them first and then explain each in detail.

1. Why is your implementation slow?

In refine_foreground, when device='gpu', you didn't move the image and mask tensors to the GPU, so your speed test actually measured the torch version running on the CPU. You need to call image = image.cuda() (and likewise for the mask) to run on the GPU.

Image

2. Why is the result not as good as the original?

In your Jupyter implementation, at the end of refine_foreground you add a green background for both the GPU and CPU versions, but not for the original version, so both refined results show a green blur. From my perspective, I can hardly tell the torch and OpenCV results apart. I'll show my results.
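For a like-for-like visual comparison, every variant (raw, OpenCV-refined, torch-refined) should go through the same compositing step. A minimal NumPy sketch, where `composite_over` is a hypothetical helper that applies the standard alpha-compositing rule `out = alpha * F + (1 - alpha) * bg` over a solid colour:

```python
import numpy as np

def composite_over(fg: np.ndarray, alpha: np.ndarray, bg_rgb=(0, 255, 0)) -> np.ndarray:
    """Composite an HxWx3 foreground (floats in [0, 1]) over a solid colour using an HxW alpha in [0, 1]."""
    bg = np.asarray(bg_rgb, dtype=np.float32) / 255.0  # background colour scaled to [0, 1]
    a = alpha[..., None].astype(np.float32)             # HxW -> HxWx1 so it broadcasts over RGB
    out = a * fg.astype(np.float32) + (1.0 - a) * bg
    return (np.clip(out, 0.0, 1.0) * 255.0).round().astype(np.uint8)

fg = np.ones((2, 2, 3), dtype=np.float32)  # a white foreground
alpha = np.array([[1.0, 0.0], [0.5, 0.0]], dtype=np.float32)
res = composite_over(fg, alpha)
print(res[0, 1])  # [  0 255   0]: pure background where alpha == 0
```

Calling the same function for all three outputs (or for none of them) removes the background as a variable, so any remaining difference comes from the refinement itself.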

Implementation

I'm not used to Jupyter Notebook, and when I ran your code, the timeit function sometimes got stuck. Therefore, I took the mask predicted by the BiRefNet model in your code and re-tested it in my own code. My code and results are at the end.

I conducted the comparison tests on an A100 GPU. Since refine_foreground includes moving tensors to the GPU, for a fair comparison I timed only the two functions FB_blur_fusion_foreground_estimator_gpu_2 and FB_blur_fusion_foreground_estimator_cpu_2. I used Python's timeit module, running each function 100 times per test and repeating the test 5 times. The results are as follows:

  • torch GPU version: 65.07 ms ± 6.12 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)
  • torch CPU version: 2290.18 ms ± 4.27 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)
  • opencv CPU version: 955.26 ms ± 0.93 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)

As we can see, the GPU version's output quality is close to the CPU version's, while it runs about 15 times faster than the OpenCV CPU version (65.07 ms vs. 955.26 ms). I hope these results help.
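One caveat when reproducing GPU numbers like these (an addition, not part of the original measurement): CUDA kernels launch asynchronously, so a host-side timer such as `timeit` can stop before the queued GPU work has finished. A sketch of a synchronizing timer, assuming PyTorch; `timed_ms` and `box_blur9` are hypothetical helpers, and the code falls back to the CPU when no GPU is present:

```python
import time
import torch
import torch.nn.functional as F_nn

def timed_ms(fn, *args, warmup: int = 3, iters: int = 20) -> float:
    """Average wall-clock time of fn(*args) in milliseconds, synchronizing CUDA around the timed loop."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):       # exclude one-off costs (CUDA context, allocator growth)
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for every queued kernel before stopping the clock
    return (time.perf_counter() - start) * 1e3 / iters

def box_blur9(t: torch.Tensor) -> torch.Tensor:
    """9x9 mean blur with replicate padding, a stand-in workload for the estimator."""
    return F_nn.avg_pool2d(F_nn.pad(t, (4, 4, 4, 4), mode="replicate"), kernel_size=9, stride=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 3, 512, 512, device=device)
print(f"{timed_ms(box_blur9, x):.2f} ms per call")
```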

Here's the code and results

OpenCV CPU result: [Image]

PyTorch GPU result: [Image]

Pillow original result (no refinement): [Image]

Comparison: [Image]

CODE

```python
import cv2
import math
import torch
import time
import timeit
import torchvision
import numpy as np
from PIL import Image
from functools import partial
from typing import List

def convert_pil_to_tensor_0_to_1(
    image: Image.Image
) -> torch.Tensor:
    return torchvision.transforms.functional.to_tensor(image).float()

def convert_pil_to_array_0_to_1(
    image: Image.Image
) -> np.ndarray:
    return np.array(image, dtype=np.float32) / 255.0

def convert_array_to_image(
    array: np.ndarray,
    mode: str = 'RGB'
) -> Image.Image:
    """
    assume the color format of the array is RGB
    For BGR of opencv array, use convert_from_cv2_to_image
    """
    return Image.fromarray(array.astype(np.uint8)).convert(mode)

def ttimeit(func, number=100, repeat=5):
    """
    run func `number` times, and repeat this running for `repeat` times,
    e.g. number = 100, repeat = 5 means run 100 times func for 5 repeat.
    """
    times = timeit.repeat(func, number=number, repeat=repeat)
    times = np.array(times) / number  # average time per call
    mean = times.mean()
    std = times.std()

    # pick the display unit automatically
    if mean >= 1:
        unit = "s"
        mean_show = mean
        std_show = std
    elif mean >= 1e-3:
        unit = "ms"
        mean_show = mean * 1e3
        std_show = std * 1e3
    else:
        unit = "μs"
        mean_show = mean * 1e6
        std_show = std * 1e6

    print(f"{mean_show:.2f} {unit} ± {std_show:.2f} {unit} per loop (mean ± std. dev. of {repeat} runs, {number} loops each)")


## CPU version refinement
def FB_blur_fusion_foreground_estimator_cpu(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B

def FB_blur_fusion_foreground_estimator_cpu_2(image, alpha, r=90):
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator_cpu(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator_cpu(image, F, blur_B, alpha, r=6)[0]


## GPU version refinement
def mean_blur(x, kernel_size):
    """
    Mean (box) blur, approximately equivalent to cv2.blur.
    Note: padding is 'replicate', while cv2.blur defaults to BORDER_REFLECT_101,
    so pixels near the image border can differ slightly.
    x:  [B, C, H, W]
    """
    if kernel_size % 2 == 0:
        pad_l = kernel_size // 2 - 1
        pad_r = kernel_size // 2
        pad_t = kernel_size // 2 - 1
        pad_b = kernel_size // 2
    else:
        pad_l = pad_r = pad_t = pad_b = kernel_size // 2

    x_padded = torch.nn.functional.pad(x, (pad_l, pad_r, pad_t, pad_b), mode='replicate')

    return torch.nn.functional.avg_pool2d(x_padded, kernel_size=(kernel_size, kernel_size), stride=1, count_include_pad=False)

def FB_blur_fusion_foreground_estimator_gpu(image, F, B, alpha, r=90):
    as_dtype = lambda x, dtype: x.to(dtype) if x.dtype != dtype else x

    input_dtype = image.dtype
    # compute in float32: low-precision dtypes (e.g. bfloat16) produce artifacts in the divisions below
    image = as_dtype(image, torch.float32)
    F = as_dtype(F, torch.float32)
    B = as_dtype(B, torch.float32)
    alpha = as_dtype(alpha, torch.float32)

    blurred_alpha = mean_blur(alpha, kernel_size=r)

    blurred_FA = mean_blur(F * alpha, kernel_size=r)
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)

    F_output = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F_output = torch.clamp(F_output, 0, 1)

    return as_dtype(F_output, input_dtype), as_dtype(blurred_B, input_dtype)

def FB_blur_fusion_foreground_estimator_gpu_2(image, alpha, r=90):
    # Thanks to the source: https://github.com/ZhengPeng7/BiRefNet/issues/226#issuecomment-2989825094
    F, blur_B = FB_blur_fusion_foreground_estimator_gpu(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator_gpu(image, F, blur_B, alpha, r=6)[0]

def refine_foreground(image, mask, r=90, device='gpu', timeinfo=False):
    """image and mask are PIL images; they are converted to floats in [0, 1] internally"""
    if mask.size != image.size:
        mask = mask.resize(image.size)

    if device == 'gpu':
        image = convert_pil_to_tensor_0_to_1(image).cuda()  # move to GPU
        mask = convert_pil_to_tensor_0_to_1(mask).cuda()
        image = image.unsqueeze(0)  # BCHW
        mask = mask.unsqueeze(0)    # BCHW

        estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)

        if timeinfo:
            func = partial(FB_blur_fusion_foreground_estimator_gpu_2, image, mask, r=r)
            ttimeit(func)

        estimated_foreground = estimated_foreground.squeeze().permute(1, 2, 0).cpu().numpy()    # HWC
        mask = mask.squeeze().cpu().numpy() # HW
    else:
        image = convert_pil_to_array_0_to_1(image)
        mask = convert_pil_to_array_0_to_1(mask)
        estimated_foreground = FB_blur_fusion_foreground_estimator_cpu_2(image, mask, r=r)

        if timeinfo:
            func = partial(FB_blur_fusion_foreground_estimator_cpu_2, image, mask, r=r)
            ttimeit(func)

    estimated_foreground = (estimated_foreground * 255.0).clip(0, 255).astype(np.uint8)

    estimated_foreground = Image.fromarray(estimated_foreground)

    return estimated_foreground


if __name__ == '__main__':
    pil_image = Image.open('./data/onnx_test-2-image.png')
    pil_mask = Image.open('./data/onnx_test-2-mask.png')    # predicted by BiRefNet

    # GPU version
    image_masked = refine_foreground(pil_image, pil_mask, device='gpu', timeinfo=True)
    image_masked.putalpha(pil_mask.resize(pil_image.size))
    image_masked.save('./data/onnx_test-2-result-GPU.png')

    # CPU version
    image_masked = refine_foreground(pil_image, pil_mask, device='cpu', timeinfo=True)
    image_masked.putalpha(pil_mask.resize(pil_image.size))
    image_masked.save('./data/onnx_test-2-result-CPU.png')

    # Original No Refine
    image_masked = pil_image.copy()
    image_masked.putalpha(pil_mask.resize(pil_image.size))
    image_masked.save('./data/onnx_test-2-result-RAW.png')
```

lucasgblu, Jun 29 '25 08:06

Wow, thanks a lot @lucasgblu! I've been looking into your latest code and found that the main time cost has shifted to the transfers between CPU and CUDA, which still kept the whole pipeline slow (~0.16 s on a 5090 + 16 cores of an AMD EPYC 9354). I've integrated your code and upgraded some other pre- and post-processing steps, and it now runs within 0.085 s.

Anyway, thanks a lot for this contribution. I'm going to make it the default setting; it will definitely be useful to others.

Hi @xiao-keeplearning, you can check the latest repo in about 10 minutes to test the accelerated refinement.

ZhengPeng7, Jun 30 '25 08:06
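The thread does not show what the pre- and post-processing upgrades were. One common way to shrink CPU-CUDA transfer cost, offered only as a hedged sketch with hypothetical helper names (not the repo's code), is to copy compact uint8 data across and do the float conversion and quantization on the device, which moves 4x fewer bytes than float32 in each direction:

```python
import numpy as np
import torch
from PIL import Image

def pil_to_device_float(img: Image.Image, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: PIL RGB -> [1, 3, H, W] float32 in [0, 1], converted on `device`."""
    arr = torch.from_numpy(np.asarray(img.convert("RGB")).copy())  # HxWx3 uint8, 1 byte per value
    if device.type == "cuda":
        arr = arr.pin_memory()                                      # page-locked memory allows async copies
    arr = arr.to(device, non_blocking=True)                         # transfer uint8, not float32
    return arr.permute(2, 0, 1).unsqueeze(0).float().div_(255.0)

def device_to_pil(t: torch.Tensor) -> Image.Image:
    """Hypothetical helper: [1, 3, H, W] float in [0, 1] -> PIL RGB, quantized before the copy back."""
    u8 = t.squeeze(0).clamp(0, 1).mul(255).round().to(torch.uint8)  # quantize on the device
    return Image.fromarray(u8.permute(1, 2, 0).contiguous().cpu().numpy())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = pil_to_device_float(Image.new("RGB", (8, 4), (10, 20, 30)), device)
print(x.shape, x.device)
print(device_to_pil(x).getpixel((0, 0)))  # (10, 20, 30)
```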

@ZhengPeng7 glad to help 😎

lucasgblu, Jun 30 '25 11:06