Model with stack does not work with int8 target type
Converting this dummy model with quantize_target_type="int8" and per_tensor=True produces a model that fails in the TFLite interpreter:
```python
import torch.nn as nn
import torch

from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter


class StackModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.stack([-x, x], dim=-1)


def _main():
    dummy_input = torch.rand(1, 60, 60, 256).float()
    model = StackModel()

    qat_config = {
        "backend": "qnnpack",
        "per_tensor": True,
        "disable_requantization_for_cat": True,
    }
    quantizer = PostQuantizer(
        model, (dummy_input,), work_dir="stack_model", config=qat_config
    )

    ptq_coarse_matcher = quantizer.quantize()
    ptq_coarse_matcher(dummy_input)

    with torch.no_grad():
        ptq_coarse_matcher.eval()
        ptq_coarse_matcher.cpu()

        ptq_coarse_matcher = quantizer.convert(ptq_coarse_matcher)
        torch.backends.quantized.engine = quantizer.backend

        converter = TFLiteConverter(
            ptq_coarse_matcher,
            (dummy_input,),
            "stack_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8",
        )
        converter.convert()


if __name__ == '__main__':
    _main()
```
TFLite error:
```
    return self._interpreter.AllocateTensors()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /tensorflow/tensorflow/lite/kernels/concatenation.cc:184 t->params.zero_point != output->params.zero_point (-1 != 0)Node number 3 (CONCATENATION) failed to prepare.
```
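For reference, the mismatched per-tensor q-params can be inspected without allocating tensors; a minimal sketch using the TFLite Python interpreter (tensor names and indices depend on the exported model):

```python
# Sketch: dump scale/zero_point of every tensor in the converted model.
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="stack_model.tflite")
for detail in interpreter.get_tensor_details():
    scale, zero_point = detail["quantization"]
    print(detail["index"], detail["name"], scale, zero_point)
# The inputs to the CONCATENATION node show different zero_points (e.g. -1 vs 0),
# which is exactly what AllocateTensors() rejects here.
```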
Note that the model works fine if I remove the negative x and instead stack the same tensor twice, and it also works with uint8.
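That is, this variant (same script, only the forward changed) converts and allocates fine:

```python
class StackModel(nn.Module):
    def forward(self, x: torch.Tensor):
        # Same tensor stacked twice (no negation): no zero_point mismatch at the concat.
        return torch.stack([x, x], dim=-1)
```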
Well, we need to apply the same logic to stack.
Seems to also happen if I turn it into a cat op:
```python
class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.cat([-x.unsqueeze(-1), x.unsqueeze(-1)], dim=-1)
```
@spacycoder What about this?
```python
class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        z = x.unsqueeze(-1)
        return torch.cat([-z, z], dim=-1)
```
That also fails
Or this?
```python
class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.cat([-x, x], dim=-1).view(x.shape[:-1] + (-1, 2))
```
Nope, doesn't work either
Okay, will look into it tomorrow.
@spacycoder It seems that the problem is with mul_scalar. The q-params for this op are calculated on the fly.
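You can see the effect in eager PyTorch; a sketch using the quantized mul_scalar op directly (the exact q-params TinyNN's converter derives may differ):

```python
# Sketch: negation via a quantized mul-by-scalar derives new q-params from the
# input's, so the result no longer shares the input's zero_point.
import torch

x = torch.rand(1, 4)
qx = torch.quantize_per_tensor(x, scale=1 / 255, zero_point=0, dtype=torch.quint8)
neg = torch.ops.quantized.mul_scalar(qx, -1.0)

print(qx.q_scale(), qx.q_zero_point())    # 0.0039..., 0
print(neg.q_scale(), neg.q_zero_point())  # same scale, different zero_point
```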
@spacycoder Things should work with https://github.com/alibaba/TinyNeuralNetwork/pull/360
This also fails with the same concatenation error:
```python
import torch.nn as nn
import torch

from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int = 256):
        super().__init__()
        self.mlp0 = nn.Linear(d_model, d_model, bias=False)
        self.mlp1 = nn.Linear(d_model * 2, d_model, bias=False)

    def forward(self, x: torch.Tensor):
        x = x.permute(0, 2, 3, 1)
        m = self.mlp0(x)
        m = torch.cat([x, m], dim=-1)
        m = self.mlp1(m)
        return x + m


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder(x)
        y = self.encoder(y)
        return x, y


def _main():
    dummy_input0 = torch.rand(1, 256, 60, 60).float()
    dummy_input1 = torch.rand(1, 256, 60, 60).float()
    model = Dummy()

    ptq_config = {
        "backend": "qnnpack",
        "per_tensor": True,
        "disable_requantization_for_cat": True,
    }
    quantizer = PostQuantizer(
        model, (dummy_input0, dummy_input1), work_dir="cat_model", config=ptq_config
    )

    ptq_model = quantizer.quantize()
    ptq_model(dummy_input0, dummy_input1)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()

        ptq_model = quantizer.convert(ptq_model)
        torch.backends.quantized.engine = quantizer.backend

        converter = TFLiteConverter(
            ptq_model,
            (dummy_input0, dummy_input1),
            "cat_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8",
        )
        converter.convert()


if __name__ == '__main__':
    _main()
```
FYI, having two separate encoders works (but I need them to be the same):
```python
class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder0 = EncoderLayer(256)
        self.encoder1 = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder0(x)
        y = self.encoder1(y)
        return x, y
```
Okay, I guess it is because we refuse to traverse into the same nodes in the computation graph again. We need to refine the constraints a little bit.
Related code snippet: https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/quantization/quantizer.py#L1214C9-L1233
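For context, a simplified illustration (not the actual TinyNN code) of the constraint: the pass propagating shared q-params around cat skips nodes it has already visited, so when both inputs reuse the same EncoderLayer instance the second path stops early and its cat never picks up the shared q-params. Node, next_nodes and qparams below are hypothetical stand-ins for illustration only.

```python
class Node:
    def __init__(self, name):
        self.name = name
        self.next_nodes = []
        self.qparams = None


def propagate_qparams(node, qparams, visited):
    # "Refuse to traverse into the same nodes again": a node reached via a second
    # path (e.g. the shared encoder called with the second input) is skipped, so
    # its downstream cat keeps its own q-params and the CONCATENATION check fails.
    if node in visited:
        return
    visited.add(node)
    node.qparams = qparams
    for nxt in node.next_nodes:
        propagate_qparams(nxt, qparams, visited)
```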
This seems to be a decent workaround for the moment:
```python
class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x_cat = torch.cat([x, y], dim=0)
        x_cat = self.encoder(x_cat)
        x, y = torch.chunk(x_cat, 2, dim=0)
        return x, y
```
@spacycoder I'm glad it works and it looks cleaner.