Hard limit for instance count in multi-instance models inside of inference graph
It's sometimes desirable to cap the number of instances that a multi-instance model returns. Currently, we do this through tracking, but sometimes we want to do it without running the tracker.
Use cases:
- On-demand (e.g., in GUI or programmatically) instance count filtering
- During human-in-the-loop (HITL) training, when we don't track contiguous frames
- (This issue) In exported inference models, so the logic doesn't have to be reimplemented downstream
- For performance: extraneous detections in the centroid stage can substantially slow down top-down models, whose runtime scales roughly linearly with the number of instances
The problem is that there are several strategies for selecting among N instance detections.
#717 should solve this in a more general form by providing standalone filtering functions that operate on single LabeledFrames.
This issue proposes a smaller, less general solution that covers some of these use cases.
The idea is to implement it with TensorFlow graph-compatible ops inside the InferenceLayer/InferenceModel subclasses. It's less general than #717, but it works with exported models.
This could go here: https://github.com/talmolab/sleap/blob/6cac6519208dbc77a89a1e7fb019fed03d9514ac/sleap/nn/inference.py#L2034
Or even better, here during centroid detection/cropping: https://github.com/talmolab/sleap/blob/6cac6519208dbc77a89a1e7fb019fed03d9514ac/sleap/nn/inference.py#L1659-L1660
There, we could apply tf.math.top_k to the centroid peak values, like this:
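```python
from typing import Optional

import tensorflow as tf

max_instances: Optional[int] = None  # new attribute on the layer
# ...
# in call() method:
if self.max_instances is not None:
    # tf.math.top_k raises an error if k exceeds the number of detections, so clamp it.
    k = tf.minimum(self.max_instances, tf.shape(centroid_vals)[0])
    top_points = tf.math.top_k(centroid_vals, k=k)
    top_inds = top_points.indices
    # Keep only the highest-scoring centroids and their associated data.
    centroid_vals = tf.gather(centroid_vals, top_inds)
    centroid_points = tf.gather(centroid_points, top_inds)
    crop_sample_inds = tf.gather(crop_sample_inds, top_inds)
```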
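If centroid_vals holds the detections from every frame in a batch flattened together, with crop_sample_inds mapping each detection to its frame (as the variable name suggests), a single top_k call keeps the k best detections across the whole batch rather than per frame. Below is a minimal per-frame sketch under that assumed flat layout. The function name filter_top_k_per_sample is hypothetical and not part of SLEAP; it only uses graph-compatible ops, so it should also trace under tf.function.

```python
import tensorflow as tf


def filter_top_k_per_sample(centroid_vals, centroid_points, crop_sample_inds, max_instances):
    """Keep at most `max_instances` highest-scoring centroids per sample.

    Hypothetical helper. Assumes a flat layout: centroid_vals is (n,),
    centroid_points is (n, 2), and crop_sample_inds is an (n,) integer
    tensor giving the frame each detection belongs to.
    """
    # 1) Order all detections by score, highest first.
    order = tf.argsort(centroid_vals, direction="DESCENDING", stable=True)
    # 2) Stable sort by sample index, so each sample's detections stay in score order.
    order = tf.gather(order, tf.argsort(tf.gather(crop_sample_inds, order), stable=True))
    sorted_samples = tf.gather(crop_sample_inds, order)

    # 3) Rank within a sample = position minus the first position of that sample.
    positions = tf.cast(tf.range(tf.shape(sorted_samples)[0]), sorted_samples.dtype)
    starts = tf.math.segment_min(positions, sorted_samples)
    ranks = positions - tf.gather(starts, sorted_samples)

    # 4) Keep the top `max_instances` per sample, restoring the original detection order.
    keep = tf.sort(tf.boolean_mask(order, ranks < max_instances))
    return (
        tf.gather(centroid_vals, keep),
        tf.gather(centroid_points, keep),
        tf.gather(crop_sample_inds, keep),
    )
```

The two stable sorts group each frame's detections together in descending score order, so the rank of a detection within its frame is just its offset from the start of its group. Frames with fewer than max_instances detections keep all of them, which avoids the out-of-range k problem that a per-frame top_k would have.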
We still need to allow setting the maximum number of instances for bottom-up models.
This is now implemented.