How to add NMS to a PyTorch model so that it gets converted into TFLite
First off, thanks for this amazing repo. I'm working on an SSD model in PyTorch and I want to add post-processing (NMS) into the TFLite model. How can I add it to my model so that it gets translated to TFLite's NMS op? Thanks.
@gj-raza This is currently unsupported. We will take a look later. It shouldn't be too difficult, I guess.
@gj-raza Is `torchvision.ops.nms` sufficient for your usage?
JFYI, I've added the TFLite custom post-processing op to a Keras object detection model following this method, but since PyTorch has no NMS layer, it's getting tricky here.
Also, I've noticed that TFLite's NMS v4 and v5 exist in your code here. Can you please explain the overall flow of the tinynn converter so that I can contribute this feature myself?
> @gj-raza Is `torchvision.ops.nms` sufficient for your usage?
I've tried this, but it doesn't seem to work. Have a look at the code below:
```python
import argparse

import torch
import torch.nn as nn
import torchvision

# Imports of the TinyNeuralNetwork helpers, as in the repo's quantization example
# (omitted in my original snippet); get_dataloader comes from the CIFAR-10 example.
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.util.train_util import DLContext, get_device


# Simple PyTorch model
class PyModel(nn.Module):
    def __init__(self):
        super(PyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, (3, 3), 1, 1)
        self.conv2 = nn.Conv2d(64, 64, (3, 3), 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = torch.reshape(out, (640000, 4))
        out1 = out[:, 0]
        out3 = torchvision.ops.nms(out, out1, 0.5)
        return out3


def main_worker(args):
    print("###### TinyNeuralNetwork quick start for beginner ######")
    torch.cuda.empty_cache()

    model = PyModel()
    device = get_device()
    model.to(device=device)

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 200, 200))

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 220, args.batch_size, args.workers)

    print("\n###### Start preparing the model for quantization ######")
    # We provide a QATQuantizer class that may rewrite the graph and perform model fusion for quantization
    # The model returned by the `quantize` function is ready for QAT training
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()

    print("\n###### Start converting the model to TFLite ######")
    with torch.no_grad():
        qat_model.eval()
        qat_model.cpu()

        # The step below converts the model to an actual quantized model, which uses the quantized kernels.
        qat_model = torch.quantization.convert(qat_model)
        print(type(qat_model))

        # When converting quantized models to TFLite, please ensure the quantization backend is QNNPACK.
        torch.backends.quantized.engine = 'qnnpack'

        # The code section below is used to convert the model to the TFLite format
        converter = TFLiteConverter(qat_model, dummy_input, tflite_path='out/qat_model_small.tflite')
        converter.convert()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-path', metavar='DIR', default="/data/datasets/cifar10", help='path to cifar10 dataset')
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--batch-size', type=int, default=256)

    args = parser.parse_args()
    main_worker(args)
```
It gives the following error:
```
###### TinyNeuralNetwork quick start for beginner ######

###### Start preparing the model for quantization ######
Traceback (most recent call last):
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 2061, in trace
    new_graph.init()
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 1317, in init
    self.module(*actual_input)
  File "/home/gj/anaconda3/envs/pyt-tinynn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 1110, in _model_tracer
    add_output_node(node, outputs)
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 982, in add_output_node
    node.prev_nodes.append(current_graph().nodes_map[current_graph().tensor_pre_node_dict[id(t)]])
KeyError: 140563279179208
ERROR (tinynn.graph.tracer) inputs: ['input_1']
ERROR (tinynn.graph.tracer) forwards: ['conv1', 'conv2', 'reshape_1', 'getitem_1']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []
```
Hi, the `torchvision::ops::batched_nms` in PyTorch looks similar to the `tensorflow::ops::combined-non-max-suppression` (with `q` set to 1) in TensorFlow (I'm not sure about this numerical equivalence and should do some verification), and it's much faster than `torchvision.ops.nms`. Is it possible to implement this transformation as well?
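For context, here is a tiny runnable example of the class-aware behaviour being discussed; in `torchvision.ops.batched_nms`, boxes of different classes never suppress each other:

```python
import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],    # overlaps box 0 (IoU ~ 0.68)
                      [0., 0., 10., 10.]])   # identical to box 0
scores = torch.tensor([0.9, 0.8, 0.7])
idxs = torch.tensor([0, 0, 1])  # class ids: box 2 belongs to another class

keep = torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): box 1 is suppressed, box 2 survives
```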
> Also, I've noticed that TFLite's NMS v4 and v5 exist in your code here. Can you please explain the overall flow of the tinynn converter so that I can contribute this feature myself?
It is a long story. Currently, you have to make changes to multiple components to make it work. Let me list things to do here.
- [x] [tracer] Track `torchvision.ops.x`
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/gen_funcs_yml.sh
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/torch_func_override_1_10.yml
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/torch_func_override_1_9.yml
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/torch_func_override_1_8.yml
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/torch_func_override_1_7.yml
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/graph/configs/torch_func_override_1_6.yml
- [x] [converter] Generate the torchvision schema
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/aten_schema.py
- [ ] [converter] Implement the `torchvision::nms` -> `tfl.NonMaxSuppression` translation
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/aten.py
- [ ] [converter] Register `torchvision::nms`
  - https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/__init__.py
> The `torchvision::ops::batched_nms` in PyTorch looks similar to the `tensorflow::ops::combined-non-max-suppression` (with `q` set to 1) in TensorFlow (I'm not sure about this numerical equivalence and should do some verification), and it's much faster than `torchvision.ops.nms`. Is it possible to implement this transformation as well?
Yeah, you are right. However, according to this post, `combined-non-max-suppression` translates to a Flex op in TFLite. Currently, supporting Flex ops is low-priority work for us. Patches are welcome.
@gj-raza I've done the first two tasks. If you are interested, you may take a look at the latter two, which should be fairly easy.
> @gj-raza I've done the first two tasks. If you are interested, you may take a look at the latter two, which should be fairly easy.
@peterjc123 Sure, but I have not previously worked with PyTorch/TFLite internals or schemas, so it might take me some time to figure it all out on my own. If there is any documentation, links, tutorials, etc. that you think will be helpful, please share. I'll get on it ASAP.
@gj-raza As for the PyTorch side, the schema is quite clear:
https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/torchvision_schema.py#L40
https://pytorch.org/vision/stable/ops.html#torchvision.ops.nms
With regard to the TFLite side, you may refer to the following docs:
https://www.tensorflow.org/mlir/tfl_ops#tflnon_max_suppression_v4_mlirtflnonmaxsuppressionv4op
https://www.tensorflow.org/api_docs/python/tf/image/non_max_suppression_padded
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/non_max_suppression.cc#L29-L60
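As a quick illustration of the TF-side semantics in those docs: `tf.image.non_max_suppression_padded` returns both the (optionally padded) selected indices and the number of valid entries, which is the pair of outputs `NonMaxSuppressionV4` exposes:

```python
import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],   # [y1, x1, y2, x2]
                     [0.0, 0.1, 1.0, 1.1],   # heavily overlaps the first box
                     [0.0, 2.0, 1.0, 3.0]])  # disjoint from both
scores = tf.constant([0.9, 0.8, 0.7])

selected, valid = tf.image.non_max_suppression_padded(
    boxes, scores, max_output_size=3, iou_threshold=0.5,
    pad_to_max_output_size=True)
print(selected.numpy(), valid.numpy())  # indices padded to length 3; valid == 2
```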
To implement NMS translation, you may have to do the following things:
- Create a new file `torchvision.py` under https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/ and import some libraries just as in `aten.py`.
- Create a skeleton class `TorchVisionNmsOperator`:

```python
class TorchVisionNmsOperator(TorchVisionNmsSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)
        self.run(node)
```

- In the `parse` function, you may get the input/output arguments or tensors through `self.input_tensor` and `self.output_tensor` respectively. We also provide `self.find_or_create_input` and `self.to_tfl_tensors` for converting those tensors to the TFL format. If you need TFL tensors other than the I/O tensors, there are also `self.create_attr_tensor` and `self.create_transform_tensor` for creating tensors that serve as constants and variables.
- As for op creation on the TFLite side, you just need one line, in which `inputs` and `outputs` are lists of `tfl.Tensor`s:

```python
graph_converter.add_operator(tfl.NonMaxSuppressionV4Operator(inputs, outputs))
```

- Write the translation logic (`self.input_tensor`, `self.output_tensor` -> `inputs`, `outputs`).
- Register the operator translator here.
The tricky parts include:
- `NMS` returns a tensor of dynamic size, so you need to pad the PT tensors to a maximum size, since TFLite doesn't support dynamic sizes.
- `tfl.NonMaxSuppressionV4Operator` takes an additional argument `score_threshold`, which you may need to set to the default value `float(-inf)`.
- The bounding box format differs between the two backends (TF: `[y1, x1, y2, x2]`, PT: `(x1, y1, x2, y2)`), so you may need to reorder the coordinates, e.g. using `tfl.GatherND`. A small sketch of the padding and reordering follows this list.
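To make the first and third points concrete, here is a small runnable sketch in plain PyTorch (outside the converter) of the padding and coordinate reordering described above; `to_tf_box_order` and `pad_selected_indices` are illustrative helper names, not TinyNeuralNetwork APIs:

```python
import torch
import torchvision

def to_tf_box_order(boxes):
    # PT boxes are (x1, y1, x2, y2); TF expects [y1, x1, y2, x2],
    # i.e. a fixed permutation of the last dimension.
    return boxes[:, [1, 0, 3, 2]]

def pad_selected_indices(selected, max_output_size, pad_value=0):
    # TFLite requires statically sized outputs, so pad the variable-length
    # index tensor returned by NMS up to a fixed maximum size.
    pad = max_output_size - selected.numel()
    return torch.nn.functional.pad(selected, (0, pad), value=pad_value)

boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)  # tensor([0])
print(pad_selected_indices(keep, max_output_size=2))          # tensor([0, 0])
print(to_tf_box_order(boxes))
```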
Plus:
There's another op, `TFLite_Detection_PostProcess`, which is a class-aware alternative:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
It also has the benefit of including a u8 implementation. But the problem is that it accepts a different box format, `(cx, cy, w, h)` (which could be obtained using `torchvision.ops.box_convert`, which uses `torch.stack` as the underlying call). Also, its schema is not included in TinyNeuralNetwork yet, so it needs some modification in the following files:
https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/tflite/base.py#L15
https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/tflite/custom.py
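Since `torchvision.ops.box_convert` is mentioned above, a minimal runnable example of producing the `(cx, cy, w, h)` format that `TFLite_Detection_PostProcess` accepts:

```python
import torch
import torchvision

boxes_xyxy = torch.tensor([[10., 20., 50., 80.]])  # (x1, y1, x2, y2)
boxes_cxcywh = torchvision.ops.box_convert(boxes_xyxy, in_fmt='xyxy', out_fmt='cxcywh')
print(boxes_cxcywh)  # tensor([[30., 50., 40., 60.]])
```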
> Hi, the `torchvision::ops::batched_nms` in PyTorch looks similar to the `tensorflow::ops::combined-non-max-suppression` (with `q` set to 1) in TensorFlow (I'm not sure about this numerical equivalence and should do some verification), and it's much faster than `torchvision.ops.nms`. Is it possible to implement this transformation as well?
@zhiqwang To my surprise, `torchvision.ops.batched_nms` is implemented by doing some transforms before calling `torchvision.ops.nms` here. So actually, it could also be supported.
> To my surprise, `torchvision.ops.batched_nms` is implemented by doing some transforms before calling `torchvision.ops.nms` here. So actually, it could also be supported.
Yep @peterjc123, and the name `_batched_nms_coordinate_trick` tells us everything about the trick relating class-agnostic NMS and batched NMS.
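For readers unfamiliar with the trick, a minimal sketch of the offset idea behind `_batched_nms_coordinate_trick` (the function name below is illustrative, not torchvision's exact code):

```python
import torch
import torchvision

def batched_nms_via_offsets(boxes, scores, idxs, iou_threshold):
    # Shift each class's boxes by a per-class offset larger than any
    # coordinate, so boxes from different classes can never overlap;
    # one plain class-agnostic NMS call then acts like per-class NMS.
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    return torchvision.ops.nms(boxes + offsets[:, None], scores, iou_threshold)
```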
FYI @peterjc123, I guess the following figure can explain the secret here: (figure omitted; see the linked discussion)
Copyright of this figure: https://github.com/ultralytics/yolov5/discussions/5825#discussioncomment-1717311.
@gj-raza The schema of the `TFLITE_DETECTION_POSTPROCESS` op is added here in case you may need that.
> @gj-raza The schema of the `TFLITE_DETECTION_POSTPROCESS` op is added here in case you may need that.
Thanks @peterjc123. So now I'll have to map this operator to a newly created PyTorch operator, as in step 6 mentioned above?
@gj-raza Yes, you'll need to do steps 1-6 as I described.
We are not going to implement it on our side because it is rarely used in our scenarios. However, if you have any questions while implementing this feature, feel free to ask. @gj-raza