coremltools
Add support for Torch conv aliases
When trying to convert a Torch Conv1d layer I get:
the following model ops are MISSING:
conv1d
...
RuntimeError: PyTorch convert function for op 'conv1d' not implemented.
Although the op itself seems to be supported, the aliases aren't registered, hence the error. A similar situation happens with conv_transpose: it also seems to be supported but fails for the same reason. In this case, though, some code had to be added to match the order of inputs used by Torch.
I hope this helps to improve the code.
Note: I'm using a torch.jit.script model.
This should also close https://github.com/apple/coremltools/issues/1753
@alealv - thanks for the pull request. In order to merge this, we need unit tests. Please add unit tests for these new aliases.
@alealv does torch.jit.trace work for your model? That is what we generally recommend, instead of using torch.jit.script.
I'm also facing the same issue, but with conv_transpose2d. I'm using torch.jit.script because, while torch.jit.trace runs, it outputs the following: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: Tensor-likes are not close! Mismatched elements: 239964 / 259200 (92.6%)
@katelyn-chen - I don't know. That is a PyTorch issue. I suggest you ask in a PyTorch forum.
@alealv does torch.jit.trace work for your model? That is what we generally recommend, instead of using torch.jit.script.
I just tried tracing as you suggested. Here are the differences.
With torch.jit.script:
Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Converting PyTorch Frontend ==> MIL Ops: 52%|███████████████████████████████████████▉ | 95/183 [00:00<00:00, 6825.15 ops/s]
the following model ops are IMPLEMENTED:
add
clamp
complex
constant
constantchunk
cos
exp
gelu
layer_norm
linear
mul
sin
transpose
the following model ops are MISSING:
conv1d
istft
Traceback (most recent call last):
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 73, in <module>
main()
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 54, in main
model.to_coreml(
File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/vocos/__init__.py", line 502, in to_coreml
coreml_mdl = ct.convert(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/_converters_entry.py", line 553, in convert
mlmodel = mil_convert(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
proto, mil_program = mil_convert_to_proto(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
prog = frontend_converter(model, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
return load(*args, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
return _perform_torch_convert(converter, debug)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 122, in _perform_torch_convert
raise e
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
prog = converter.convert()
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
convert_nodes(self.context, self.graph)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
raise RuntimeError(
RuntimeError: PyTorch convert function for op 'conv1d' not implemented.
With torch.jit.trace:
Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Converting PyTorch Frontend ==> MIL Ops: 96%|████████████████████████████████████████████████████████████████████████▉ | 145/151 [00:00<00:00, 3665.71 ops/s]
the following model ops are IMPLEMENTED:
_convolution
add
complex
constant
constantchunk
cos
exp
gelu
layer_norm
linear
listconstruct
mul
sin
transpose
the following model ops are MISSING:
clip
istft
Traceback (most recent call last):
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 73, in <module>
main()
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/export.py", line 54, in main
model.to_coreml(
File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/aalvarez/Projects/main/apps/tts-train/vocoder/vocos/__init__.py", line 531, in to_coreml
coreml_mdl = ct.convert(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/_converters_entry.py", line 553, in convert
mlmodel = mil_convert(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
proto, mil_program = mil_convert_to_proto(
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
prog = frontend_converter(model, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
return load(*args, **kwargs)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
return _perform_torch_convert(converter, debug)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 122, in _perform_torch_convert
raise e
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
prog = converter.convert()
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
convert_nodes(self.context, self.graph)
File "/home/aalvarez/Projects/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
raise RuntimeError(
RuntimeError: PyTorch convert function for op 'clip' not implemented.
So Conv1d is mapped to _convolution when tracing, but not with torch.jit.script; that's what my MR adds. I don't know all the mechanics behind coremltools, but both paths should end up with the same representation, meaning we should also map conv1d to _convolution, shouldn't we?
@alealv - if you want to add conv1d support for torch.jit.script, that's fine. We'll just need unit tests for this functionality. FYI - coremltools support for torch.jit.script is only "experimental", so this isn't a priority for us.
Regarding your torch.jit.trace error - it looks like clip is just an alias for clamp, which we already support.
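To make the clip/clamp equivalence concrete (torch.clip is documented as an alias of torch.clamp), here is the shared semantics in plain Python:

```python
# Plain-Python clamp: restrict x to the interval [min_val, max_val].
# Either bound may be omitted, as with torch.clamp.
def clamp(x, min_val=None, max_val=None):
    if min_val is not None and x < min_val:
        return min_val
    if max_val is not None and x > max_val:
        return max_val
    return x

# "clip" is the same operation under a second name; supporting it in a
# converter is just a matter of registering the extra name.
clip = clamp
```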
I just updated the convolution tests so they also run with torch script. I'm trying to figure out how nn.Conv1d(...) gets converted with JIT. The test fails with:
E ValueError: Torch var bias not found in context
I only hit the problem when bias is False, and I don't fully understand what I should do. Here is the output graph:
graph(%self : __torch__.torch.nn.modules.conv.Conv1d,
%input.1 : Tensor):
%weight : Tensor = prim::GetAttr[name="weight"](%self)
%bias : Tensor? = prim::GetAttr[name="bias"](%self)
%4 : int = prim::Constant[value=1]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:306:45
%5 : int = prim::Constant[value=0]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:307:24
%6 : int = prim::Constant[value=3]() # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:307:38
%7 : int[] = prim::ListConstruct(%4)
%8 : int[] = prim::ListConstruct(%5)
%9 : int[] = prim::ListConstruct(%6)
%10 : Tensor = aten::conv1d(%input.1, %weight, %bias, %7, %8, %9, %4) # /root/coremltools/envs/coremltools-dev-py3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:306:15
return (%10)
So my understanding is that we have only one input, the weights, which makes sense given that bias is False. But aten::conv1d takes a bias argument, though it's marked optional, so the call works because bias simply isn't provided. How does coremltools handle this?
I see that:
Node: %bias : Tensor? = prim::GetAttr[name="bias"](%self)
Type: <class 'torch.Node'>
Is tensor: False
Is quantize tensor: False
prefix: bias
Module: None
And in the graph-lowering pass, the node is only handled when it's a tensor:
def _lower_graph_block(graph):
    for node in list(graph.nodes()):
        ...
        is_tensor = _check_is_tensor(node, module)
        is_quantized_tensor = _check_is_quantized_tensor(node, module)
        # Tensor nodes get lowered here; anything else falls through.
        if is_tensor or is_quantized_tensor:
            ...

def _check_is_tensor(node, module):
    if not isinstance(module, torch.Tensor):
        return False
    if str(node.output().type()) not in ("Tensor", "Optional[Tensor]"):
        raise TypeError(f'Type "{node.output().type()}" not supported')
    return True
Can anyone help me to understand this?
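One common converter pattern for an Optional[Tensor] input like this bias (a sketch with hypothetical names, not coremltools internals) is to look the name up in the conversion context and fall back to a zero bias when the attribute resolved to None:

```python
# Hypothetical helper: fetch an optional bias from the converter's
# context, substituting a zero bias (a no-op for the convolution)
# when the module was built with bias=False.
def get_bias(context, name, out_channels):
    value = context.get(name)           # None when bias=False
    if value is None:
        return [0.0] * out_channels     # zeros: adding them changes nothing
    return value

# Toy context mirroring the graph above: weight present, bias is None.
ctx = {"weight": [[1.0, 2.0]], "bias": None}
```

With a fallback like this, the lookup never fails with "Torch var bias not found in context"; whether real conversion code should synthesize zeros or instead omit the bias from the target op entirely is a design choice I can't confirm from the source here.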