Support TFLite and TFJS conversion for CenterNet multi-class keypoints
Description
Rewrote the _postprocess_keypoints_multi_class() method in center_net_meta_arch.py to not use tf.tensor_scatter_nd_add ops, which seem to:
- Break conversion to TFJS (the error message explicitly says TensorScatterAdd is not supported).
- Not break conversion to TFLite, but cause inference to fail unless at least one object of every class is predicted in each output batch (see this issue raised in tensorflow/tensorflow).
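For readers unfamiliar with the op, the general idea of replacing a scatter-add with dense ops can be sketched as follows. This is an illustration only, not the code in this PR: it uses NumPy as a stand-in so it runs without TensorFlow, and the function names, shapes, and the one-hot-plus-matmul formulation are all assumptions chosen for the example.

```python
import numpy as np


def scatter_add_rows(base, row_indices, updates):
    """Reference behaviour: add updates[i] into base[row_indices[i]]."""
    out = base.copy()
    # np.add.at accumulates correctly even when an index is repeated.
    np.add.at(out, row_indices, updates)
    return out


def scatter_add_rows_dense(base, row_indices, updates):
    """Same result using only a comparison, a matmul, and an add (no scatter op)."""
    num_rows = base.shape[0]
    # one_hot[i, r] is 1 where update i targets row r; shape [num_updates, num_rows].
    one_hot = (row_indices[:, None] == np.arange(num_rows)[None, :]).astype(base.dtype)
    # Transposing routes each update to its target row; repeated rows are summed.
    return base + one_hot.T @ updates


# Hypothetical data: three 2-D updates written into a 4-row tensor, row 2 hit twice.
base = np.zeros((4, 2), dtype=np.float32)
row_indices = np.array([2, 0, 2], dtype=np.int64)
updates = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
assert np.array_equal(
    scatter_add_rows(base, row_indices, updates),
    scatter_add_rows_dense(base, row_indices, updates),
)

# With zero updates the dense form reduces to `base`, with no special-casing.
empty_idx = np.array([], dtype=np.int64)
empty_upd = np.zeros((0, 2), dtype=np.float32)
assert np.array_equal(scatter_add_rows_dense(base, empty_idx, empty_upd), base)
```

The dense form trades a little extra arithmetic for relying only on widely supported ops. Whether that trade is acceptable depends on the tensor sizes involved.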
This is a non-breaking change: the model's outputs in the exported SavedModel format are unchanged after this rewrite. No negative impact on latency was observed either; my tests showed a latency improvement of roughly 3%:
| | Before rewrite | After rewrite | Improvement |
|---|---|---|---|
| Latency (batch size 1) | 8.63 ms | 8.37 ms | 3.10% |
| Latency (batch size 32) | 192.29 ms | 185.54 ms | 3.63% |
For this reason, I replaced the code in the existing _postprocess_keypoints_multi_class() method. However, please let me know if I should leave that method as-is and add a new one instead.
Type of change
- [x] Bug fix (non-breaking change which fixes an issue)
Tests
Besides the unit test added in this PR, I ran tests to ensure that the TFLite and TFJS models run successfully without crashing, as well as to test outputs and latency of SavedModels. I have uploaded the models and Python scripts in the drive link below:
https://drive.google.com/drive/folders/1wCUvu_LQN6maRQFAKkAU8l1JnsOYm_-F?usp=sharing
Test Configuration: The scripts were run with Python 3.10.13 and TensorFlow 2.8.1/2.14.0.
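For anyone who wants to reproduce a latency comparison without the uploaded scripts, a minimal timing helper might look like the sketch below. It is not the script from the Drive folder: the function name, the warm-up and run counts, and the use of the median are assumptions, and `fn` stands for any zero-argument callable that runs one inference (for example a hypothetical `lambda: model(batch)`).

```python
import statistics
import time


def median_latency_ms(fn, warmup=10, runs=100):
    """Return the median wall-clock time of fn() in milliseconds."""
    # Warm-up calls absorb one-off costs such as tracing or cache fills.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive to occasional slow runs than the mean.
    return statistics.median(samples)


# Stand-in workload so the sketch runs on its own; replace with a model call.
latency = median_latency_ms(lambda: sum(range(10_000)), warmup=2, runs=20)
print(f"median latency: {latency:.3f} ms")
```

Running the same helper on the SavedModel before and after the rewrite, with identical inputs and batch size, gives two numbers that can be compared directly.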
Checklist
- [x] I have signed the Contributor License Agreement.
- [x] I have read guidelines for pull request.
- [x] My code follows the coding guidelines.
- [x] I have performed a self code review of my own code.
- [x] I have commented my code, particularly in hard-to-understand areas.
- [x] I have made corresponding changes to the documentation.
- [x] My changes generate no new warnings.
- [x] I have added tests that prove my fix is effective or that my feature works.
Hi @jch1, any updates on this? Is this repo still accepting contributions?