table-transformer
How to convert this model to ONNX?
Is there any guidance on converting this model to ONNX? Other DETR models have had issues with nested tensors during conversion.
Hope this helps.
def load_args(json_path):
    """Load a JSON config file and return its parsed contents.

    Args:
        json_path: Path to the JSON file (e.g. the detection config).

    Returns:
        The decoded JSON value (a dict for a typical config object).
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # default encoding.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
def get_model(args, device):
    """Build the model and optionally warm-start it from a checkpoint.

    Only checkpoint entries whose name and tensor shape match the freshly
    built model are copied in. Any others are reported and skipped.

    Args:
        args: Config object with attribute access. It is passed to
            ``build_model`` and read for an optional ``model_load_path``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``(model, criterion, postprocessors)`` tuple from ``build_model``.
    """
    model, criterion, postprocessors = build_model(args)
    model.to(device)
    # getattr so configs without this key work instead of raising AttributeError.
    model_load_path = getattr(args, "model_load_path", None)
    if model_load_path:
        print(f"loading model from checkpoint: {model_load_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        loaded_state_dict = torch.load(model_load_path, map_location=device)
        model_state_dict = model.state_dict()
        pretrained_dict = {}
        skipped = []
        for k, v in loaded_state_dict.items():
            if k in model_state_dict and model_state_dict[k].shape == v.shape:
                pretrained_dict[k] = v
            else:
                skipped.append(k)
        if skipped:
            # Surface mismatches instead of dropping them silently; a wrapped
            # checkpoint would otherwise load nothing without any warning.
            preview = ", ".join(skipped[:10])
            more = " ..." if len(skipped) > 10 else ""
            print(
                f"skipped {len(skipped)} checkpoint entries not matching the "
                f"model: {preview}{more}"
            )
        model_state_dict.update(pretrained_dict)
        # strict=True cannot fail here: every key comes from the model's own
        # state_dict, so unmatched layers keep their initial weights.
        model.load_state_dict(model_state_dict, strict=True)
    return model, criterion, postprocessors
# Export the pretrained table-detection model to ONNX.
path_to_config = "./src/detection_config.json"
path_to_weight = "./pretrained_models/pubtables1m_detection_detr_r18.pth"
args = load_args(path_to_config)
args["model_load_path"] = path_to_weight
# Wrap the dict so build_model/get_model can read settings as args.<name>.
args = type("Args", (object,), args)
device = "cpu"
model, _, postprocessors = get_model(args, device)
model.eval()
# Dummy batch used for tracing; presumably (batch, channels, height, width).
# Confirm against the preprocessing used at inference time.
dummy_input = torch.randn(1, 3, 1854, 1341)
# Use the default operator export type so the graph contains only standard
# ONNX ops. ONNX_ATEN_FALLBACK emits ATen nodes for unsupported ops, which
# ONNX Runtime refuses to load ("ATen is not a registered function/op").
# If an op is genuinely unsupported, export now fails here instead of
# producing a model that cannot be run.
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        "detection.onnx",
        export_params=True,
        opset_version=12,
        input_names=["image"],
        # Names the first two flattened outputs of the model's output dict.
        output_names=["pred_logits", "pred_boxes"],
    )
Hi, I'm getting this error when I try to run inference on the ONNX model: `Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from structure.onnx failed:Fatal error: ATen is not a registered function/op`. Has anyone found out how to fix it? @LxYuan0420
@saffie91, did you fix the issue?
Any update?