BiRefNet
inference speed extremely slow
Hello,
The inference speed is extremely slow. I am running inference on a GPU, in the same setup I use with u2net, and u2net is about 12x faster.
Is there anything I can do to speed things up?
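For instance, would mixed-precision inference be a reasonable option here, or does BiRefNet lose quality in FP16? A minimal sketch of what I have in mind, assuming `model` is the loaded BiRefNet and `image_tensor` is a preprocessed 1x3x1024x1024 batch already on the GPU (both placeholders):

```python
import torch

# Hypothetical speed-up: run the forward pass under autocast so that most
# ops execute in FP16 on the GPU. `model` and `image_tensor` stand in for
# the loaded BiRefNet and a preprocessed input batch.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    preds = model(image_tensor)
```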
I have also tried to export the model to ONNX, but I get an error:
```python
import torch
import torch.onnx
from models.birefnet import BiRefNet
from utils import check_state_dict
from torch.onnx import register_custom_op_symbolic

# Register custom symbolic function for deform_conv2d
def deform_conv2d_symbolic(g, input, weight, offset, bias, stride, padding,
                           dilation, groups, deformable_groups,
                           use_mask=False, mask=None):
    return g.op("DeformConv2d", input, weight, offset, bias,
                stride_i=stride, padding_i=padding, dilation_i=dilation,
                groups_i=groups, deformable_groups_i=deformable_groups)

register_custom_op_symbolic('torchvision::deform_conv2d', deform_conv2d_symbolic, 11)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiRefNet(bb_pretrained=False).to(device)
state_dict = torch.load("/root/BiRefNet-massive-epoch_240.pth", map_location=device)
state_dict = check_state_dict(state_dict)
model.load_state_dict(state_dict)
model.eval()

# Dummy input to trace the model
dummy_input = torch.randn(1, 3, 1024, 1024).to(device)

# Ensure tensor-to-Python type conversions are handled in the model.
# Example modification: replace
#     if W % self.patch_size[1] != 0:
# with
#     if (W % self.patch_size[1]).item() != 0:

# Export the model
onnx_model_path = "/root/BiRefNet.onnx"
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input (or a tuple for multiple inputs)
    onnx_model_path,           # where to save the model (file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=11,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}}  # variable-length axes
)

print(f"Model has been converted to ONNX and saved at {onnx_model_path}")
```