Avoid fp32 cast for Torch div operator
The div Torch op was always casting both operands to fp32, even if both operands are of type fp16. This cast should get removed by the `common::add_fp16_cast` optimization pass. However, it causes issues during the PyTorch conversion. For example, say we have a forward method like this:
```python
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x, y):  # both fp16 tensors, shape [1, 16]
        r = x / y              # r becomes fp32 in the converted program
        return self.proj(r)    # Problem
```
Now if we have moved the model (and its parameters) to fp16 with e.g. `m = Foo().to(torch.float16)`, we get an error at conversion time:
In op, of type linear, named linear_0, the named input `bias` must have the same data type as the named input `x`. However, bias has dtype fp16 whereas x has dtype fp32.
This is because the result of the div operation stays fp32, and this doesn't match the resulting type of the PyTorch expression.
Please add a unit test to `test_torch_ops.py` which fails without your fix but passes with your fix.
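Roughly, such a test could look like the standalone sketch below (it does not follow the existing harness conventions in `test_torch_ops.py`; the model, names, and conversion options are only illustrative):

```python
import numpy as np
import torch

import coremltools as ct


class _DivModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x, y):
        return self.proj(x / y)


def test_fp16_div_keeps_fp16_dtype():
    model = _DivModel().to(torch.float16).eval()
    x = torch.randn(1, 16, dtype=torch.float16)
    y = torch.rand(1, 16, dtype=torch.float16) + 0.5  # keep divisor away from zero

    with torch.no_grad():
        traced = torch.jit.trace(model, (x, y))
        # Without the fix, conversion raises the ValueError quoted above
        # ("bias must have the same data type as the named input x").
        mlmodel = ct.convert(
            traced,
            inputs=[
                ct.TensorType(name="x", shape=x.shape, dtype=np.float16),
                ct.TensorType(name="y", shape=y.shape, dtype=np.float16),
            ],
            outputs=[ct.TensorType(name="output")],
            convert_to="mlprogram",
            compute_precision=ct.precision.FLOAT16,
            minimum_deployment_target=ct.target.iOS17,
        )
    assert mlmodel is not None
```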
I stumbled across the same issue and managed to debug it. It turns out the root cause isn't the div op. This happens because the torch converter casts inputs to fp32 here.
Here's a minimal repro that does not use the div op and still fails with the same error:
```python
import coremltools as ct
import numpy as np
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.proj(x)


x = torch.randn(1, 16, dtype=torch.float16)
with torch.no_grad():
    mlmodel = ct.convert(
        torch.jit.trace(Net().half().eval(), x),
        inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)],
        outputs=[ct.TensorType(name="output")],
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS17,
    )
```
This fails with the same exception as the snippet above with the div op:
ValueError: In op, of type linear, named linear_0, the named input `bias` must have the same data type as the named input `weight`. However, bias has dtype fp16 whereas weight has dtype fp32.
I've got a fix for this in #2274.
Hi @HennerM, `inputs=[ct.TensorType(dtype=np.float16)]` and `compute_precision=ct.precision.FLOAT16` are enough to obtain an fp16-input, fp16-computation Core ML model. There is no need to make the PyTorch model itself fp16; see the sketch after the list below.
Concretely, we internally translate the torch model in fp32. Then:

- If given `compute_precision=ct.precision.FLOAT16`, we will insert fp16 casts to make the computation (i.e. weights & activations) fp16.
- If given `inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)]`, we will change the input signature for `x` to fp16.
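For example, a minimal version of that recommended setup could look like the sketch below (keeping the PyTorch model in fp32; the output dtype and deployment target are illustrative choices, not taken from this thread):

```python
import coremltools as ct
import numpy as np
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.proj(x)


x = torch.randn(1, 16)  # trace with a regular fp32 model and fp32 input
with torch.no_grad():
    traced = torch.jit.trace(Net().eval(), x)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)],
    outputs=[ct.TensorType(name="output", dtype=np.float16)],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,  # fp16 weights & activations
    minimum_deployment_target=ct.target.iOS16,  # fp16 I/O needs iOS16+
)
```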