
Questions about image resizing and mask generation

Open Code-dogcreatior opened this issue 9 months ago • 5 comments

BiRefNet's default resize input should be 1024x1024. In theory, raising it, e.g. to 1440x1440, should make the matting more precise. The problem is: when I upload a 3000x4000 image, the mask returned by the backend is roughly 860x1000. Is that normal or not? To save myself some effort I wrote this service that receives the data directly; the frontend does no image compression at all.

```python
import base64
import io

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from torchvision import transforms
from waitress import serve

from models.birefnet import BiRefNet
from utils import check_state_dict

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Global model variables
device = 'cuda' if torch.cuda.is_available() else 'cpu'
birefnet_general = None
birefnet_matting = None

# Image transformations - separate for general and matting
transform_image_general = transforms.Compose([
    transforms.Resize((1440, 1440)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_image_matting = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def load_models():
    global birefnet_general, birefnet_matting

    # Load models
    birefnet_general = BiRefNet(bb_pretrained=False)
    state_dict_general = torch.load('BiRefNet-general-epoch_244.pth',
                                    map_location=device, weights_only=True)
    state_dict_general = check_state_dict(state_dict_general)
    birefnet_general.load_state_dict(state_dict_general)

    birefnet_matting = BiRefNet(bb_pretrained=False)
    state_dict_matting = torch.load('BiRefNet-matting-epoch_100.pth',
                                    map_location=device, weights_only=True)
    state_dict_matting = check_state_dict(state_dict_matting)
    birefnet_matting.load_state_dict(state_dict_matting)

    # Move models to device and set to evaluation mode
    birefnet_general.to(device).eval()
    birefnet_matting.to(device).eval()

# Image processing function
def process_image_and_get_mask(image, model, transform):
    # Convert to RGB if not already in RGB
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Preprocess the image
    input_image = transform(image).unsqueeze(0).to(device)

    # Predict mask
    with torch.no_grad():
        preds = model(input_image)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Convert the prediction to a PIL image and resize back to the original size
    pred_pil = transforms.ToPILImage()(pred)
    pred_pil = pred_pil.resize(image.size)

    # Convert to grayscale to ensure a single channel
    pred_pil = pred_pil.convert('L')

    # Convert to base64
    buffered = io.BytesIO()
    pred_pil.save(buffered, format="PNG", optimize=True)
    img_str = base64.b64encode(buffered.getvalue()).decode()

    return {
        'mask': img_str,
        'width': image.size[0],
        'height': image.size[1]
    }

# Flask routes
@app.route('/process_high_precision', methods=['POST'])
def process_high_precision():
    if 'image' not in request.files:
        return 'No image uploaded', 400

    image_file = request.files['image']
    image = Image.open(image_file.stream)

    return jsonify(process_image_and_get_mask(image, birefnet_general, transform_image_general))

@app.route('/process_transparent_matting', methods=['POST'])
def process_transparent_matting():
    if 'image' not in request.files:
        return 'No image uploaded', 400

    image_file = request.files['image']
    image = Image.open(image_file.stream)

    return jsonify(process_image_and_get_mask(image, birefnet_matting, transform_image_matting))

# Run the application
if __name__ == '__main__':
    # Load models when the app starts
    load_models()

    # Run the app with Waitress
    serve(app, host='0.0.0.0', port=6004, threads=4)
```
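A minimal client-side sketch to verify what size the server actually returns, assuming the service above is running locally on port 6004 (`test.jpg` is a placeholder file name):

```python
import base64
import io

import requests
from PIL import Image

# Post an image to the high-precision endpoint defined above
with open('test.jpg', 'rb') as f:
    resp = requests.post('http://localhost:6004/process_high_precision',
                         files={'image': f})
data = resp.json()

# Decode the base64 PNG mask and compare the two reported sizes
mask = Image.open(io.BytesIO(base64.b64decode(data['mask'])))
print('mask size:', mask.size)
print('reported size:', (data['width'], data['height']))  # should match mask.size
```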

Code-dogcreatior · Mar 06 '25 01:03

A larger size does not necessarily give better performance. To put it more precisely: unless the results are strongly limited by low resolution to the point that details are unclear, I don't recommend using a size different from the one the model was trained with, because there is a domain-shift problem between them.

As for the 860x1000: that should be an issue in your code, or some library compressed the image, because I don't resize like that anywhere; at most there is a post-processing step that resizes the prediction back to the original image size... You can check the input and output shapes of `preds = model(input_image)[-1].sigmoid().cpu()`.
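If it helps, a minimal debugging sketch for that check, reusing the names from the script above (the expected shapes assume the 1440x1440 general transform):

```python
# Inside process_image_and_get_mask, around the prediction call:
input_image = transform(image).unsqueeze(0).to(device)
print('input shape:', input_image.shape)    # expected: torch.Size([1, 3, 1440, 1440])

with torch.no_grad():
    preds = model(input_image)[-1].sigmoid().cpu()
print('output shape:', preds.shape)         # H and W should match the input size

print('original image size:', image.size)   # (W, H) that the mask is resized back to
```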

Feel free to keep the discussion going if you have more questions.

ZhengPeng7 · Mar 06 '25 03:03

So theoretically, for the originally released general model the best resize setting is 1024, and for HR it is 2048; can I understand it that way? I'll debug the 1000px issue again. One more question: during your testing, did you find that some masks can have fairly pronounced jagged edges? I've run into this in some of my tests at both low and high resolution; it's not common, but it comes up from time to time.

Code-dogcreatior · Mar 06 '25 05:03

Yes, that's right. I've actually noted this in the demos and similar places.

Could the jaggedness come from the size being too small or too large? Resizing the original image a lot for inference, or resizing the inference result a lot back to the original size, can both directly introduce aliasing. If it's too severe, and the results at the original image size are also acceptable, then I'd recommend just using the original image size.
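A minimal sketch of inference near the original image size, under the assumption (not stated in this thread) that the backbone expects height and width divisible by 32:

```python
from torchvision import transforms

def transform_near_original(image, multiple=32):
    # Round each side to the nearest multiple of 32 so the backbone's
    # downsampling stages divide evenly (an assumption, not from this thread)
    w, h = image.size
    w_r = max(multiple, round(w / multiple) * multiple)
    h_r = max(multiple, round(h / multiple) * multiple)
    return transforms.Compose([
        transforms.Resize((h_r, w_r)),  # torchvision Resize expects (H, W)
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])(image)
```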

ZhengPeng7 · Mar 06 '25 05:03

First, the 1000px scaling issue really was my fault. Please forgive me: because of my background, most of my questions lean toward practical use rather than development. I ran some tests and found that with a 1024x1024 resize the matting quality is slightly lower than at 1440x1440 (my test case was a badminton racket's string grid), but the processing time rises roughly 3x for inputs in the 1000-4000 resolution range, and roughly 2x above 4000, so I'm guessing this comes from PyTorch's compute efficiency scaling nonlinearly. Also, at high resolution, do you recommend the earlier lite2k model or the HR model? And if HR, should the size be set to 1024 or 2048? I'm also attaching an image of the problem I mentioned; switching to 1440 noticeably improves the edges. Can this kind of problem be improved by preprocessing the image with OpenCV, or by adjusting parameters?

[attached image showing the jagged mask edges]

Code-dogcreatior · Mar 06 '25 14:03

No worries, you're welcome; just keep the questions coming. For high resolution, BiRefNet_HR definitely still gives the best results at the moment. There is probably no fixed answer on the size; if you're not doing automated batch processing, you can indeed tune the resolution yourself and compare.
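On the OpenCV question above: one generic post-processing option (not specific to BiRefNet, and not a substitute for choosing a good inference size) is to upscale the soft mask with a high-quality filter and apply a light blur; a minimal sketch:

```python
import cv2
import numpy as np

def smooth_mask_edges(mask_lowres, target_size):
    # mask_lowres: float32 array in [0, 1] at the inference resolution
    # target_size: (width, height) of the original image
    mask = cv2.resize(mask_lowres, target_size,
                      interpolation=cv2.INTER_LANCZOS4)  # preserves soft edges
    mask = cv2.GaussianBlur(mask, (3, 3), 0)             # softens staircase artifacts
    return np.clip(mask, 0.0, 1.0)
```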

ZhengPeng7 · Mar 07 '25 06:03