pose-hg-3d
Trace the model for use in C++
With the release of PyTorch 1.0, it is now possible to load PyTorch models trained in Python from C++. To do so, however, the model first has to be converted to TorchScript, for example by tracing it with the torch.jit.trace function.
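For reference, a minimal sketch of the trace-and-save workflow on a toy model. `TinyNet`, the input shape and the file name `tiny_traced.pt` are hypothetical; the sketch only assumes a PyTorch version that provides `torch.jit.trace` (1.0 or later).

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in model: forward() takes a tensor and returns a
    tensor, which is the shape of interface torch.jit.trace accepts."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


model = TinyNet().eval()                  # trace in inference mode
example = torch.rand(1, 3, 64, 64)        # example input with the expected shape
traced = torch.jit.trace(model, example)  # records the ops executed on `example`
traced.save("tiny_traced.pt")             # serialized module that LibTorch can load from C++

# The traced module should reproduce the eager model's output on new inputs.
x = torch.rand(1, 3, 64, 64)
assert torch.allclose(traced(x), model(x), atol=1e-6)
```

Tracing only records the operations that actually run for the example input, so data-dependent control flow inside forward() is frozen to the path taken during tracing.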
So I tried it: I trained the model and then tried to trace it with the following code:
import torch
import torchvision
import cv2
from utils.debugger import Debugger
from utils.eval import getPreds
import numpy as np
debugger = Debugger()
# An instance of your model.
model = torch.load('/home/narvis/Dev/pytorch-pose-hg-3d/src/model_10.pth').cuda()
# An example input you would normally provide to your model's forward() method.
img = cv2.imread("/home/narvis/Dev/Datasets/mpii_human/images/000001163.jpg")
frame = cv2.resize(img, (256, 256))
input = torch.from_numpy(frame.transpose(2, 0, 1)).float() / 256.
input = input.view(1, input.size(0), input.size(1), input.size(2))
input_var = torch.autograd.Variable(input).float().cuda()
output = model(input_var)
pred = getPreds((output[-2].data).cpu().numpy())[0] * 4
debugger.addImg((input[0].numpy().transpose(1, 2, 0) * 256).astype(np.uint8))
debugger.addPoint2D(pred)
reg = (output[-1].data).cpu().numpy().reshape(pred.shape[0], 1)
debugger.addPoint3D(np.concatenate([pred, (reg + 1) / 2. * 256], axis = 1))
#debugger.showImg(pause = True)
#debugger.show3D()
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, input_var)
traced_script_module.save("model_og.pt")
print("end")
Unfortunately, I get an error in the torch.jit.trace call:
Traceback (most recent call last):
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1664, in <module>
main()
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "/snap/pycharm-professional/89/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/narvis/Dev/pytorch-pose-hg-3d/src/pythontracing.py", line 30, in <module>
traced_script_module = torch.jit.trace(model, input_var)
File "/home/narvis/miniconda3/envs/torch1/lib/python3.6/site-packages/torch/jit/__init__.py", line 565, in trace
module._create_method_from_trace('forward', func, example_inputs)
RuntimeError: Only tensors and (possibly nested) tuples of tensors are supported as inputs or outputs of traced functions (toIValue at /opt/conda/conda-bld/pytorch-nightly_1538559991109/work/torch/csrc/jit/pybind_utils.h:74)
Both debugger.showImg and debugger.show3D work, so there seems to be no problem with the input_var variable. I also inspected input_var, and it is definitely a tensor. I don't know what the issue with this code is.
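Note that the error message applies to the outputs of the traced function as well as its inputs. Since the script indexes `output[-2]` and `output[-1]`, forward() returns some kind of sequence, so it is worth checking whether that sequence is a list rather than a tuple. Below is a small helper, with the hypothetical name `find_untraceable`, that reports which parts of a return value fall outside the rule quoted in the RuntimeError. The example outputs are made up for illustration.

```python
import torch


def find_untraceable(value, path="output"):
    """Return the paths of values that break the rule quoted in the
    RuntimeError: only tensors and (possibly nested) tuples of tensors."""
    if isinstance(value, torch.Tensor):
        return []
    if isinstance(value, tuple):
        problems = []
        for i, item in enumerate(value):
            problems.extend(find_untraceable(item, f"{path}[{i}]"))
        return problems
    # Anything else (list, dict, numpy array, ...) is reported with its type.
    return [f"{path} ({type(value).__name__})"]


# Hypothetical outputs shaped like a multi-stage network's result.
as_list = [torch.zeros(1, 16, 64, 64), torch.zeros(1, 16)]
as_tuple = tuple(as_list)

print(find_untraceable(as_list))   # ['output (list)']
print(find_untraceable(as_tuple))  # []
```

In the script above, it would be called as `print(find_untraceable(output))` right after `output = model(input_var)`.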
@tobiascz Hi, I'm facing the same issue. Did you come up with a workaround? Thanks!
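One possible workaround, sketched under the assumption that the model's forward() returns a Python list (which the `output[-2]`/`output[-1]` indexing suggests but does not prove): wrap the model so that the traced forward() returns a tuple instead. `ListOutputNet` is a hypothetical stand-in for such a model and `TupleOutput` is a hypothetical wrapper name; neither is part of pose-hg-3d.

```python
import torch
import torch.nn as nn


class ListOutputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward() returns a list,
    e.g. one prediction per stacked stage."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Conv2d(3, 4, kernel_size=1)
        self.stage2 = nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, x):
        out = []
        y = self.stage1(x)
        out.append(y)
        out.append(self.stage2(y))
        return out  # a list: not allowed as the output of a traced function


class TupleOutput(nn.Module):
    """Wraps a module and converts its list output into a tuple of tensors."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return tuple(self.inner(x))


model = ListOutputNet().eval()
example = torch.rand(1, 3, 32, 32)
traced = torch.jit.trace(TupleOutput(model), example)

outputs = traced(example)
print(type(outputs).__name__, len(outputs))  # tuple 2
```

Because a tuple supports the same indexing, post-processing such as `output[-2]` and `output[-1]` keeps working on the traced module's result. Newer PyTorch releases also accept `strict=False` in `torch.jit.trace` to allow list outputs, but converting to a tuple works without relying on that flag.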