hifi-gan

Tracing to torchscript

Open ctlaltdefeat opened this issue 4 years ago • 14 comments

Has anyone been able to successfully convert the generator model to torchscript?

I receive a bizarre error. Tracing itself works:

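import json

import torch
# env and models are env.py and models.py from the hifi-gan checkout

# Dummy mel-spectrogram input: batch 1, 80 mel bins, 10 frames,
# filled with roughly log(1e-5), i.e. silence
zero = torch.full((1, 80, 10), -11.52).cuda()
with open("hifi-gan/config.json") as f:
    data = f.read()
h = env.AttrDict(json.loads(data))
vocoder = models.Generator(h).cuda()
vocoder.load_state_dict(
    torch.load("hifi-gan/pretrained_universal/g_02500000")["generator"]
)
vocoder.remove_weight_norm()  # fold weight normalization into plain weights before export
vocoder.eval()
with torch.no_grad():
    traced_vocoder = torch.jit.trace(vocoder, zero)
    torch.jit.save(traced_vocoder, "vocoder.pth")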

Loading the saved model then fails with a strange error:

traced_vocoder = torch.jit.load("vocoder.pth")
/opt/conda/lib/python3.8/site-packages/torch/jit/_serialization.py in load(f, map_location, _extra_files)
    159     cu = torch._C.CompilationUnit()
    160     if isinstance(f, str) or isinstance(f, pathlib.Path):
--> 161         cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
    162     else:
    163         cpp_module = torch._C.import_ir_module_from_buffer(

RuntimeError: Found character '45' in string, strings must be qualified Python identifiers

ctlaltdefeat, Jan 25 '21

I was able to convert it to TorchScript via jit.script after some slight modifications; I can share the repo tonight.

ErenBalatkan, Jan 26 '21

@ErenBalatkan would appreciate it, thanks!

epochsimate, Jan 26 '21

@ErenBalatkan have you noticed any speedup with the scripted model? Thanks!

evrrn, Jan 26 '21

You can find my modified version here

I have also included a simple benchmark comparing the scripted version with the regular PyTorch model.

Regarding the speedup: I observed around 10% on my work laptop (CPU) and around 5% on my desktop, on both CPU and GPU.

ErenBalatkan, Jan 26 '21
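The benchmark itself is only linked, not reproduced in this thread. As a rough idea of how such a comparison can be set up, here is a minimal timing sketch using only the standard library; `benchmark` and `relative_speedup` are hypothetical helpers, not the code from the linked repo. It takes the median of repeated calls after a few untimed warm-up runs.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Return the median wall-clock time of fn() in seconds (hypothetical helper)."""
    # Warm-up calls are not timed, so one-time costs (JIT compilation,
    # CUDA context creation, caching) do not skew the result.
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to occasional slow outliers than the mean.
    return statistics.median(timings)


def relative_speedup(baseline_s: float, candidate_s: float) -> float:
    """Fraction of the baseline time saved: 0.10 means the candidate takes 10% less time."""
    return (baseline_s - candidate_s) / baseline_s
```

To compare the two models, one would time a call such as `vocoder(mel)` against `scripted_vocoder(mel)` (hypothetical names) inside `torch.no_grad()`. On GPU, the timed function should also call `torch.cuda.synchronize()`, because CUDA kernels run asynchronously and the timer would otherwise stop before the work is done.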

Thank you for the modified version; it compiles as a scripted module. However, I still receive the same error when calling torch.jit.load. The fact that it works for you is a bit puzzling, since I haven't done anything special to my installation and other scripted modules load fine.

ctlaltdefeat, Jan 28 '21

Hmm, it works fine on both my home and work computers. I suggest trying the script in NVIDIA's PyTorch Docker container; it may help with your problem.

https://ngc.nvidia.com/catalog/containers/nvidia:pytorch

ErenBalatkan, Jan 28 '21

It's definitely an environment issue and/or a PyTorch bug, since I confirmed that it works on a different setup. In any case, I'll leave this issue open so that your modifications can be merged if the authors wish.

ctlaltdefeat, Jan 28 '21

I have been working with TorchScript on another project and stumbled across this issue with the exact same error message.

For me, the issue was that I was importing the modules I traced from paths containing dashes (character 45 is '-'). Maybe the dash in hifi-gan is the problem for you too? I don't know why this information ends up in the TorchScript binary file, but replacing the dashes in the path with underscores fixed the error when loading in C++ for me.

Axelwickm, Feb 02 '21
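The error message can be decoded directly: `chr(45)` is `'-'`. The sketch below assumes, as this reply suggests, that the Python module name of each traced class (its `__module__`) ends up in the saved archive, where TorchScript prefixes it with `__torch__.`. It shows that loading a file through `importlib` under a dashed module name yields class names that are not valid identifiers. The directory layout and the spec name `"hifi-gan.models"` are hypothetical; the thread does not show the original importlib call.

```python
import importlib.util
import pathlib
import tempfile
from typing import List


def non_identifier_parts(dotted_name: str) -> List[str]:
    """Return the dot-separated components that are not valid Python identifiers."""
    return [part for part in dotted_name.split(".") if not part.isidentifier()]


print(chr(45))  # -

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical stand-in for a checkout such as hifi-gan/models.py
    repo = pathlib.Path(tmp, "hifi-gan")
    repo.mkdir()
    (repo / "models.py").write_text("class Generator:\n    pass\n")

    # Loading a file by path: the module name is whatever string we pass,
    # here one derived from the dashed directory name.
    spec = importlib.util.spec_from_file_location("hifi-gan.models", str(repo / "models.py"))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Classes record the module name they were defined under.
    qualified = f"{module.Generator.__module__}.{module.Generator.__qualname__}"

print(qualified)  # hifi-gan.models.Generator
print(non_identifier_parts(qualified))  # ['hifi-gan']
```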

Thanks! I was indeed using dashed paths and loading the modules with importlib; when I added those paths to sys.path instead, the error went away. It does seem like a weird TorchScript bug.

ctlaltdefeat, Feb 03 '21
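Two ways to keep the dash out of the module name, sketched with a throwaway stand-in for the hifi-gan checkout: put the directory on `sys.path` and import normally (the approach described in this reply), or keep `importlib` but pass an identifier-safe module name. `safe_module_name` and `load_module` are hypothetical helpers; the `sys.modules` registration follows the "importing a source file directly" recipe in the `importlib` documentation.

```python
import importlib.util
import pathlib
import re
import sys
import tempfile
from types import ModuleType


def safe_module_name(name: str) -> str:
    """Replace every non-identifier character (including '-' and '.') with '_'."""
    cleaned = re.sub(r"\W", "_", name)
    # Identifiers cannot start with a digit.
    return "_" + cleaned if cleaned[:1].isdigit() else cleaned


def load_module(path: pathlib.Path, name: str) -> ModuleType:
    """Load a .py file via importlib under an identifier-safe module name (hypothetical helper)."""
    spec = importlib.util.spec_from_file_location(safe_module_name(name), str(path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module  # register before executing, as in the importlib docs recipe
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    # Throwaway stand-in for a checkout such as hifi-gan/models.py
    repo = pathlib.Path(tmp, "hifi-gan")
    repo.mkdir()
    (repo / "models.py").write_text("class Generator:\n    pass\n")

    # Option 1: put the dashed directory on sys.path and import normally.
    # The dash stays in the filesystem path, not in the module name.
    sys.path.insert(0, str(repo))
    import models
    sys.path.remove(str(repo))
    print(models.Generator.__module__)  # models

    # Option 2: keep importlib, but choose the module name explicitly.
    m = load_module(repo / "models.py", "hifi-gan.models")
    print(m.Generator.__module__)  # hifi_gan_models
```

The catch with option 1 is that a generic name such as `models` can shadow, or be shadowed by, another module of the same name; option 2 avoids that by picking a distinctive name.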

Hey there. I get the same error with a different model. Even after removing all uses of importlib from my code, I still see the behavior described in #51: the model traces without errors, but loading fails with the character 45 error. I've also tried removing all extra_files from the trace, to no avail.

Does anyone have another idea, or can anyone point me to the best way to debug this, given that the failure happens inside the C++ module (cpp_module)?

SarBH, Feb 15 '22
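One way to debug this without going through `torch.jit.load`: a file written by `torch.jit.save` is a zip archive, so Python's `zipfile` can open it even when the C++ loader fails. The hypothetical helper below scans every entry for `__torch__.`-prefixed type names and reports those with components that are not Python identifiers, which should point to the offending module. The entry names and contents in the demo are simplified assumptions about the archive layout, not a copy of a real file.

```python
import io
import re
import zipfile
from typing import BinaryIO, List, Set, Union

# TorchScript type names look like "__torch__.<module path>.<Class>". The pattern
# also accepts '-' on purpose, so invalid names are captured rather than skipped.
QUALIFIED = re.compile(rb"__torch__(?:\.[A-Za-z0-9_\-]+)+")


def find_bad_type_names(archive: Union[str, BinaryIO]) -> List[str]:
    """List __torch__ names in a zip archive that contain non-identifier parts (hypothetical helper)."""
    bad: Set[str] = set()
    with zipfile.ZipFile(archive) as zf:
        for entry in zf.namelist():
            # Reads each entry fully into memory; fine for a debugging session.
            for match in QUALIFIED.findall(zf.read(entry)):
                name = match.decode("ascii")
                if any(not part.isidentifier() for part in name.split(".")):
                    bad.add(name)
    return sorted(bad)


# Demo on a tiny in-memory archive; entry names and contents are simplified assumptions.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr(
        "vocoder/code/__torch__/hifi-gan/models.py",
        "class Generator(Module):\n  conv_pre : __torch__.torch.nn.modules.conv.Conv1d\n",
    )
    zf.writestr("vocoder/data.pkl", b"c__torch__.hifi-gan.models\nGenerator\n")

print(find_bad_type_names(buf))  # ['__torch__.hifi-gan.models']
```

On a real file, one would call `find_bad_type_names("vocoder.pth")` on the archive that fails to load; the reported name indicates which import path still carries a dash.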

@ErenBalatkan Could you upload it again?

zhangsanfeng86, Mar 26 '22

Can you share your repo again? Thank you very much!

exercise-book-yq, Aug 08 '22

Can you share your repo again?

exercise-book-yq, Aug 08 '22

@ErenBalatkan can you share your script again? Would be super helpful for my project. Appreciate it!

vionwinnie, Dec 01 '22