table-transformer

How to convert this model to ONNX

Status: Open. HighPoint opened this issue 2 years ago • 4 comments

Is there any guidance on converting this model to ONNX? Other DETR models have had issues with nested tensors in the conversion.

HighPoint, Jun 26 '22 20:06

Hope this helps.

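import json
import sys

import torch

# Assumption: run from the table-transformer repo root, with the DETR code
# (which provides models.build_model) located under ./detr.
sys.path.append("./detr")
from models import build_model  # returns (model, criterion, postprocessors)


def load_args(json_path):
    """Load the model config (a flat JSON dict of hyperparameters)."""
    with open(json_path) as f:
        return json.load(f)


def get_model(args, device):
    model, criterion, postprocessors = build_model(args)
    model.to(device)
    if args.model_load_path:
        print(f"loading model from checkpoint: {args.model_load_path}")
        loaded_state_dict = torch.load(args.model_load_path, map_location=device)
        model_state_dict = model.state_dict()
        # Keep only checkpoint entries whose name and shape match the model;
        # anything else is silently skipped and keeps its initial value.
        pretrained_dict = {
            k: v
            for k, v in loaded_state_dict.items()
            if k in model_state_dict and model_state_dict[k].shape == v.shape
        }
        model_state_dict.update(pretrained_dict)
        model.load_state_dict(model_state_dict, strict=True)
    return model, criterion, postprocessors


path_to_config = "./src/detection_config.json"
path_to_weight = "./pretrained_models/pubtables1m_detection_detr_r18.pth"

args = load_args(path_to_config)
args["model_load_path"] = path_to_weight
# build_model reads settings as attributes (args.<name>), so turn the dict
# into a class whose attributes are the config keys.
args = type("Args", (object,), args)

device = "cpu"
model, _, postprocessors = get_model(args, device)
model.eval()

# Dummy page image: batch of 1, 3 channels, 1854 x 1341 pixels (H x W).
dummy_input = torch.randn(1, 3, 1854, 1341)
torch.onnx.export(
    model,
    dummy_input,
    "detection.onnx",
    export_params=True,
    opset_version=12,
    # ONNX_ATEN_FALLBACK: any op without an ONNX mapping is exported as an
    # ATen node, which standard ONNX Runtime builds cannot execute.
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)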
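One thing to watch in `get_model`: the filtered checkpoint is merged into the model's own `state_dict()` before `load_state_dict(..., strict=True)`, so strict mode can never complain. A checkpoint that doesn't fit the config (for example, one trained for a different number of classes) would load "successfully" with some layers left at their random initial values. The sketch below is a hypothetical helper, `filter_matching`, that does the same name-and-shape filtering but also reports what it dropped. It only relies on each value having a `.shape`, so it works on PyTorch tensors; the `fake_tensor` stand-ins and parameter names in the demo are made up for illustration.

```python
from types import SimpleNamespace


def filter_matching(loaded, current):
    """Split checkpoint entries into those that fit the model and those that don't.

    `loaded` and `current` map parameter names to tensors; only `.shape` is used.
    Returns (kept, skipped): `kept` maps name -> value for entries to load,
    `skipped` maps name -> human-readable reason the entry was dropped.
    """
    kept, skipped = {}, {}
    for name, value in loaded.items():
        if name not in current:
            skipped[name] = "not in model"
        elif tuple(current[name].shape) != tuple(value.shape):
            skipped[name] = f"shape {tuple(value.shape)} != {tuple(current[name].shape)}"
        else:
            kept[name] = value
    return kept, skipped


def fake_tensor(*shape):
    """Hypothetical stand-in for a tensor: an object with just a .shape."""
    return SimpleNamespace(shape=shape)


if __name__ == "__main__":
    # Hypothetical parameter names and shapes, for illustration only.
    checkpoint = {
        "head.weight": fake_tensor(3, 256),
        "backbone.weight": fake_tensor(64, 3),
        "extra.bias": fake_tensor(8),
    }
    model_sd = {"head.weight": fake_tensor(7, 256), "backbone.weight": fake_tensor(64, 3)}
    kept, skipped = filter_matching(checkpoint, model_sd)
    print("loaded:", sorted(kept))
    for name, reason in sorted(skipped.items()):
        print(f"skipped {name}: {reason}")
```

In `get_model`, the dict comprehension could be replaced with `pretrained_dict, skipped = filter_matching(loaded_state_dict, model_state_dict)`, followed by printing `skipped`, so that a mismatched checkpoint is visible before export.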

LxYuan0420, Jun 27 '22 06:06

Hi, I'm getting this error when I try to run inference on the ONNX model:

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from structure.onnx failed:Fatal error: ATen is not a registered function/op

Has anyone found out how to fix it? @LxYuan0420

saffie91, Jul 15 '22 15:07

@saffie91, did you manage to fix the issue?

yellowjs0304, Feb 28 '23 00:02

Any update?

nissansz, Oct 30 '23 11:10