TinyNeuralNetwork
`nn.RMSNorm` does not seem to be supported
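Hi, converting a model that uses `nn.RMSNorm` with `PostQuantizer` fails with the tracer error below. The tracer appears to lose track of the tensor returned by `self.norm`. The following `transpose` then cannot find its producer node: note the `KeyError` on `tensor_pre_node_dict`, and that only `transpose_0_f` appears in the recorded forwards. Minimal reproduction: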
class RMSNormModel(nn.Module):
def __init__(self):
super().__init__()
self.norm = nn.RMSNorm(3, 0.1)
def forward(self, x):
x = x.transpose(1, 3) # [N, H, W, C]
x = self.norm(x)
return x.transpose(1, 3) # [N, C, H, W]
def _main():
    """Run post-training quantization on ``RMSNormModel`` to reproduce the issue.

    The tracer failure in this report is presumably raised while the model
    is traced during ``quantizer.quantize()`` — confirm against tinynn.

    Returns:
        Whatever ``quantizer.quantize()`` returns (if it succeeds).
    """
    dummy_input = torch.rand(1, 3, 224, 224)  # [N, C, H, W]
    model = RMSNormModel()

    # This config goes to PostQuantizer (PTQ), not a QAT quantizer.
    ptq_config = {"backend": "qnnpack"}

    # Pass the tensor directly: ``(dummy_input)`` in the original was not a
    # one-element tuple, just a parenthesized expression, so this is the
    # same value without suggesting a tuple.
    quantizer = PostQuantizer(
        model, dummy_input, work_dir="rms_model", config=ptq_config
    )
    ptq_model = quantizer.quantize()
    return ptq_model
Error output:
ERROR (tinynn.graph.tracer) Connection is lost when generating code for transpose_1_f of type torch.Tensor.transpose
Traceback (most recent call last):
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 3380, in trace
new_graph.init()
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 2041, in init
self.module(*actual_input)
File ".../lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "test_rms_norm.py", line 15, in forward
return x.transpose(1, 3) # [N, C, H, W]
^^^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 1089, in new_func
trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 646, in parse_args
arg_str = _parse_args(args)
^^^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 589, in _parse_args
self.tensor_names.append(_tensor_name(a))
^^^^^^^^^^^^^^^
File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 549, in _tensor_name
pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
KeyError: 130226902469760
ERROR (tinynn.graph.tracer) inputs: ['input_0_f']
ERROR (tinynn.graph.tracer) forwards: ['transpose_0_f']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []