Bug with `torch.unbind` at fp32 precision
## 🐞 Describing the bug
A PyTorch model that uses the `torch.unbind` function returns incorrect results after being converted by coremltools when targeting fp32 precision.
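For context, `torch.unbind(dim)` removes the given dimension and returns a tuple of slices, so in eager PyTorch it is exactly equivalent to plain indexing. A quick standalone sanity check (not part of the repro below):

```python
import torch

# unbind along dim 0 vs. plain indexing: both should give identical tensors
x = torch.rand((3, 2, 4))
a, b, c = x.unbind(0)  # tuple of three (2, 4) views
assert torch.equal(a, x[0])
assert torch.equal(b, x[1])
assert torch.equal(c, x[2])
```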
## To Reproduce
```python
import itertools

import torch
import torch.nn as nn

import coremltools as ct


# reduced version from original:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class Attention(nn.Module):
    def __init__(self, dim, use_unbind=True):
        super(Attention, self).__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.use_unbind = use_unbind

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, 1, C).permute(2, 0, 3, 1, 4)
        if self.use_unbind:
            # offending line:
            q, k, v = qkv.unbind(0)
        else:
            # doing this instead will produce the correct result:
            q = qkv[0]
            k = qkv[1]
            v = qkv[2]
        attn = (q @ k.transpose(-2, -1))
        return attn


def convert_to_coreml(model, input, compute_precision=None):
    return ct.convert(
        torch.jit.trace(model, (input,)),
        convert_to='mlprogram',
        compute_precision=compute_precision,
        # computation is:
        #   correct on ComputeUnit.CPU_ONLY
        #   incorrect on ComputeUnit.CPU_AND_GPU & ComputeUnit.ALL
        compute_units=ct.ComputeUnit.ALL,
        inputs=[
            ct.TensorType(
                shape=ct.Shape(input.shape),
                name='i0')
        ],
    )


def run_coreml_model(model, input):
    model_out = list(model.predict({'i0': input.numpy()}).values())[0]
    return torch.from_numpy(model_out)


def main():
    print(f'coremltools: {ct.__version__}')
    print(f'torch: {torch.__version__}')

    b = 1
    n = 100
    attn_dim = 128
    input = torch.rand((b, n, attn_dim))

    model = Attention(dim=attn_dim, use_unbind=False).eval()
    model_unbind = Attention(dim=attn_dim, use_unbind=True).eval()

    # copy weights..
    model_unbind.qkv.weight[:] = model.qkv.weight
    model_unbind.qkv.bias[:] = model.qkv.bias

    out_orig = model(input)
    out_orig_unbind = model_unbind(input)
    assert torch.sum(torch.abs(out_orig - out_orig_unbind)).item() == 0.0

    diff_strs = []
    for model_, precision_ in itertools.product(
            (model_unbind, model),
            (ct.precision.FLOAT32, ct.precision.FLOAT16, None)):
        model_cml = convert_to_coreml(model_, input, compute_precision=precision_)
        out_cml = run_coreml_model(model_cml, input)
        diff = torch.mean(torch.abs(out_orig - out_cml)).item()
        diff_strs.append(
            f'diff (unbind:{int(model_.use_unbind)}, precision:{precision_})'.ljust(51)
            + f': {diff}')

    print('')
    for s in diff_strs:
        print(s)


if __name__ == '__main__':
    with torch.no_grad():
        main()
```
## Log output
```
coremltools: 5.2.0
torch: 1.10.2
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 30/31 [00:00<00:00, 2154.68 ops/s]
Running MIL Common passes:   0%|          | 0/34 [00:00<?, ? passes/s]/Users/carson/miniconda3/envs/ai-test/lib/python3.8/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:129: UserWarning: Output, '37', of the source model, has been renamed to 'var_37' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 1970.79 passes/s]
Running MIL Clean up passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 1483.66 passes/s]
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 30/31 [00:00<00:00, 2522.94 ops/s]
Running MIL Common passes:   0%|          | 0/34 [00:00<?, ? passes/s]/Users/carson/miniconda3/envs/ai-test/lib/python3.8/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:129: UserWarning: Output, '37', of the source model, has been renamed to 'var_37' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 2027.56 passes/s]
Running MIL FP16ComputePrecision pass: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 73.65 passes/s]
Running MIL Clean up passes: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 411.92 passes/s]
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 30/31 [00:00<00:00, 2763.17 ops/s]
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 2274.10 passes/s]
Running MIL FP16ComputePrecision pass: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 81.98 passes/s]
Running MIL Clean up passes: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 422.63 passes/s]
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 33/34 [00:00<00:00, 2610.17 ops/s]
Running MIL Common passes:   0%|          | 0/34 [00:00<?, ? passes/s]/Users/carson/miniconda3/envs/ai-test/lib/python3.8/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:129: UserWarning: Output, '38', of the source model, has been renamed to 'var_38' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 2103.68 passes/s]
Running MIL Clean up passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 1231.16 passes/s]
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 33/34 [00:00<00:00, 2661.67 ops/s]
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 2026.46 passes/s]
Running MIL FP16ComputePrecision pass: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 89.91 passes/s]
Running MIL Clean up passes: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 460.67 passes/s]
Converting Frontend ==> MIL Ops:  97%|████████████████████████████████████████████████████████████████████████████████████ | 33/34 [00:00<00:00, 2607.46 ops/s]
Running MIL Common passes: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 2118.90 passes/s]
Running MIL FP16ComputePrecision pass: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 95.29 passes/s]
Running MIL Clean up passes: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 480.03 passes/s]

diff (unbind:1, precision:ComputePrecision.FLOAT32): 5.272708892822266
diff (unbind:1, precision:ComputePrecision.FLOAT16): 0.0005657298606820405
diff (unbind:1, precision:None)                    : 0.0005657298606820405
diff (unbind:0, precision:ComputePrecision.FLOAT32): 0.0
diff (unbind:0, precision:ComputePrecision.FLOAT16): 0.0005657298606820405
diff (unbind:0, precision:None)                    : 0.0005657298606820405
```
## System environment
- Behavior is the same for coremltools==5.2.0 + torch==1.10.2 and coremltools==6.0b2 + torch==1.11.0
- OS: macOS Monterey 12.5.1, MacBook Pro (14-inch, 2021), Apple M1 Pro, 16 GB memory
## Additional context
The workaround is trivial, so this is low priority for me; I just wanted to report it :)
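For reference, the workaround is just to split with plain indexing instead of `unbind` (a standalone sketch using the same `qkv` layout as the repro above):

```python
import torch

qkv = torch.rand((3, 1, 1, 100, 128))  # same layout as qkv in the repro

# instead of the offending
#   q, k, v = qkv.unbind(0)
# split with plain indexing; numerically identical, and converts correctly:
q, k, v = qkv[0], qkv[1], qkv[2]
```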
I'm not able to execute your code. This line:

```python
model_unbind.qkv.weight[:] = model.qkv.weight
```

produces the following error:

```
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
```
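For what it's worth, one grad-safe way to do that copy is to go through `load_state_dict` rather than assigning into the parameter slices. A minimal sketch, with standalone `nn.Linear`s standing in for the repro's `qkv` layers:

```python
import torch.nn as nn

src = nn.Linear(128, 128 * 3, bias=True)  # stands in for model.qkv
dst = nn.Linear(128, 128 * 3, bias=True)  # stands in for model_unbind.qkv

# copies weight and bias without in-place ops on leaf tensors that require grad
dst.load_state_dict(src.state_dict())
```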
It would be helpful if I could reproduce the issue.

It would also help to have an example that is as short as possible and only shows the bug, not the other cases that are working.

A concise summary of the issue would also be helpful; I'm not sure I understand it. When using `ct.precision.FLOAT32` and not `ct.ComputeUnit.CPU_ONLY`, the output of `torch.unbind` doesn't match, is that right?
Summary: the output of the Core ML model generated via coremltools is wrong (different from the original torch implementation).

I wonder why the `torch.no_grad()` call didn't work for you. Anyway, you can try this instead:
```python
import torch
import torch.nn as nn

import coremltools as ct


class Attention(nn.Module):
    def __init__(self, dim):
        super(Attention, self).__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=True)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, 1, C).permute(2, 0, 3, 1, 4)
        # offending line:
        q, k, v = qkv.unbind(0)
        # doing this instead will produce the correct result:
        # q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1))
        return attn


def main():
    attn_dim = 128
    input = torch.rand((1, 100, attn_dim))
    model = Attention(dim=attn_dim).eval()
    out_orig = model(input)

    model_cml = ct.convert(
        torch.jit.trace(model, (input,)),
        convert_to='mlprogram',
        compute_precision=ct.precision.FLOAT32,
        inputs=[
            ct.TensorType(
                shape=ct.Shape(input.shape),
                name='i0')])

    out_cml = torch.from_numpy(list(model_cml.predict({'i0': input.numpy()}).values())[0])
    diff = torch.mean(torch.abs(out_orig - out_cml)).item()
    print(f'mean diff (should be close to 0): {diff}')


if __name__ == '__main__':
    with torch.no_grad():
        main()
```
> When using `ct.precision.FLOAT32` and not `ct.ComputeUnit.CPU_ONLY`, the output of `torch.unbind` doesn't match, is that right?
Yes, basically. But to be precise, it's the output of the model that is wrong when `unbind` is used, not necessarily the output of `unbind` itself. The operations before/after seem to matter too: if the model just returns the output of `unbind` and does not perform the matrix multiply, the bug doesn't appear.
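For example, a variant like the following still uses `unbind` but converts correctly (a sketch; the class name is mine, not from the repro):

```python
import torch
import torch.nn as nn


class AttentionReturnsQKV(nn.Module):
    """Same as the repro's Attention, but returns the unbind outputs directly."""

    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=True)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, 1, C).permute(2, 0, 3, 1, 4)
        # unbind alone does not trigger the bug; it only shows up once the
        # unbound tensors feed the q @ k^T matmul
        q, k, v = qkv.unbind(0)
        return q, k, v
```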
I cannot reproduce this issue. The mean diff I get is 3.2442756037198706e-07. This is on an M1 Pro with macOS 13.1, coremltools==6.0, and torch==1.11.0.
If there is an issue here, it is an issue with the Core ML framework. Please check if this is still an issue with the latest version of coremltools and macOS 13. If it is, please file a report here: https://developer.apple.com/bug-reporting/.