Model Validator for operators that cannot be defined as a class module.
How does one ensure a model passes AIMET validation when it contains primitive operations that cannot be defined as a torch.nn.Module, such as:

- Add
- Divide
- Abs
- ...

For example:
```python
import torch


class ModelWithReusedNodes(torch.nn.Module):
    """ Model with a functional add in the forward pass. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithReusedNodes, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)
        self.linear = torch.nn.Linear(2592, 10)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu2(x)
        x = torch.add(x, x)  # functional op with no corresponding nn.Module
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
```
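For context, the validator was invoked roughly as follows (a minimal sketch; `ModelValidator.validate_model` matches AIMET's documented usage, and the dummy input shape is taken from the model's docstring):

```python
import torch
from aimet_torch.model_validator.model_validator import ModelValidator

model = ModelWithReusedNodes()
# Dummy input matching the shape the model expects
rand_input = torch.rand(1, 3, 32, 32)
ModelValidator.validate_model(model, model_input=rand_input)
```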
Running the model validator produces the following output:
```
2022-01-14 16:34:53,999 - root - INFO - AIMET
2022-01-14 16:34:54,030 - Utils - INFO - Running validator check <function validate_for_reused_modules at 0x7f5ae96bed90>
2022-01-14 16:34:54,032 - Utils - INFO - Running validator check <function validate_for_missing_modules at 0x7f5a1dd22840>
2022-01-14 16:34:54,046 - Utils - WARNING - Ops with missing modules: ['Add_4']
This can be due to several reasons:
1. There is no mapping for the op in ConnectedGraph.op_type_map. Add a mapping for ConnectedGraph to recognize and be able to map the op.
2. The op is defined as a functional in the forward function, instead of as a class module. Redefine the op as a class module if possible. Else, check 3.
3. This op is one that cannot be defined as a class module, but has not been added to ConnectedGraph.functional_ops. Add to continue.
2022-01-14 16:34:54,047 - Utils - INFO - The following validator checks failed:
2022-01-14 16:34:54,047 - Utils - INFO - <function validate_for_missing_modules at 0x7f5a1dd22840>
```
How is one supposed to add these kinds of operations to the connected graph? Can they be left as-is?
I found a solution that works.

One needs to define class modules for all the operations one wishes to use. That way, AIMET can wrap each module appropriately and have control over its inputs, outputs, and parameters. For the above example, converting the functional addition into such a module (and using it in the forward pass, as shown below) resolves the failing validation check.
```python
class Add(torch.nn.Module):
    def __init__(self):
        super(Add, self).__init__()

    def forward(self, tensor1, tensor2):
        return tensor1 + tensor2
```
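For completeness, here is roughly how the model changes (a sketch; the class name `ModelWithAddModule` is hypothetical, and only the add step differs from the original model):

```python
class ModelWithAddModule(torch.nn.Module):
    """ Same model as above, but with the functional add replaced by the Add module """

    def __init__(self):
        super(ModelWithAddModule, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)
        self.add = Add()  # module instance that AIMET can discover and wrap
        self.linear = torch.nn.Linear(2592, 10)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu2(x)
        x = self.add(x, x)  # was: torch.add(x, x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
```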
Please advise if there is a less laborious way of doing this.
Hi @sohils, thank you for your query. Yes, the solution you have suggested is the way to work around this. AIMET does support a subset of these definitions that one could use; please check https://github.com/quic/aimet/blob/develop/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py

With PyTorch 1.9, one could use the torch.fx API to do these updates automatically, as demonstrated in this example: https://github.com/quic/aimet/blob/develop/TrainingExtensions/torch/test/python/test_model_preparer.py#L179

Please do let me know if you have further questions.
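For illustration, a minimal sketch of both suggestions; it assumes `aimet_torch.elementwise_ops` exposes an `Add` module and that the torch.fx-based preparer is importable as `aimet_torch.model_preparer.prepare_model`, as in the files linked above:

```python
import torch
from aimet_torch import elementwise_ops
from aimet_torch.model_preparer import prepare_model


# Option 1: use AIMET's predefined elementwise module instead of a hand-written one
class ModelWithAimetAdd(torch.nn.Module):
    def __init__(self):
        super(ModelWithAimetAdd, self).__init__()
        self.add = elementwise_ops.Add()

    def forward(self, x):
        return self.add(x, x)


# Option 2 (PyTorch >= 1.9): let the torch.fx-based model preparer rewrite
# functional ops into modules automatically
model = ModelWithReusedNodes()         # original model with torch.add in forward
prepared_model = prepare_model(model)  # functionally equivalent model with
                                       # functional ops replaced by modules
```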