Have better error message when using single stage of top-down in non-evaluation mode
Quick error reporting enhancement:
Top-down models have a centroid stage and centered instance stage. You can load them individually (via sleap.load_model) and have the other stage just output the ground truth, allowing us to evaluate their accuracy independently.
This is done by using the CentroidCropGroundTruth or FindInstancePeaksGroundTruth layers. These layers expect the inputs received by the model to be a dictionary as generated by our LabelsReader-based pipelines.
If we provide raw data or read it through VideoReader, though, there is no ground truth in the inputs, so we get the error in the discussion post quoted below.
We should probably catch this edge case and raise a more informative exception telling the user to load both stages when predicting on non-ground truth data.
Here are the relevant lines: https://github.com/talmolab/sleap/blob/740f9fa7b1d5a4640cf8d96b6e4185273219346a/sleap/nn/inference.py#L1842-L1851
In the if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): block, we could check that example is a dict containing the necessary keys (e.g., "instances", the key missing in the traceback below). Or maybe we could check in FindInstancePeaksGroundTruth?
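As a rough illustration of the check suggested above, here is a minimal sketch in plain Python. The helper name check_ground_truth_example, its parameters, and the wording of the message are hypothetical and not part of SLEAP; the required key "instances" is assumed from the KeyError in the traceback below, and the real layers may need other keys as well.

```python
from collections.abc import Mapping


def check_ground_truth_example(
    example,
    required_keys=("instances",),  # assumed from the KeyError in the traceback
    layer_name="FindInstancePeaksGroundTruth",
):
    """Raise an informative error if `example` lacks ground truth data.

    Hypothetical helper for illustration only; not part of the SLEAP API.
    """
    hint = (
        "Load both the centroid and centered instance models, e.g. "
        "sleap.load_model([centroid_path, centered_instance_path]), "
        "when predicting on data without ground truth."
    )

    # Raw arrays or VideoReader output are not dicts, so there is no ground truth.
    if not isinstance(example, Mapping):
        raise ValueError(
            f"{layer_name} expects a dict of inputs with ground truth, "
            f"but received {type(example).__name__}. {hint}"
        )

    # A dict without the ground truth keys has the same problem.
    missing = [key for key in required_keys if key not in example]
    if missing:
        raise ValueError(
            f"{layer_name} is missing required ground truth keys {missing}. {hint}"
        )
```

Calling this at the top of the ground truth branch (or inside the layer's call method) would turn the opaque KeyError into a message that points the user at loading both stages.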
Getting stuck somehow here with this traceback:
predictor.predict(imgs[:16])
Predicting... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0% ETA: -:--:-- ?
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py", line 436, in predict
    self._make_labeled_frames_from_generator(generator, data)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py", line 2126, in _make_labeled_frames_from_generator
    for ex in generator:
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py", line 346, in _predict_generator
    ex = process_batch(ex)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py", line 312, in process_batch
    preds = self.inference_model.predict_on_batch(ex, numpy=True)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py", line 916, in predict_on_batch
    outs = super().predict_on_batch(data, **kwargs)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/keras/engine/training.py", line 1947, in predict_on_batch
    outputs = self.predict_function(iterator)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 924, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3038, in __call__
    filtered_flat_args) = self._maybe_define_function(args, kwargs)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3460, in _maybe_define_function
    args, kwargs, flat_args, filtered_flat_args, cache_key_context)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3382, in _define_function_with_shape_relaxation
    args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3308, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.impl.api.StagingError: in user code:
    /home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/keras/engine/training.py:1586 predict_function *
        return step_function(self, iterator)
    /home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py:1848 call *
        peaks_output = self.instance_peaks(example, crop_output)
    /home/jdelahanty/miniconda3/envs/sleap/lib/python3.7/site-packages/sleap/nn/inference.py:692 call *
        a = tf.expand_dims(
    KeyError: 'instances'
Any tips?
You have to load both models when doing top-down:
predictor = sleap.load_model([
    "/nadata/snlkt/spb_psilicon/Pilot_BehParameters/SLEAP/SLEAP_phase4/models/220801_174806.centroid.n=2092",
    "/nadata/snlkt/spb_psilicon/Pilot_BehParameters/SLEAP/SLEAP_phase4/models/220801_181533.centered_instance.n=2092",
], batch_size=16)
Admittedly, we could probably have a more informative error message :)
Originally posted by @talmo in https://github.com/talmolab/sleap/discussions/833#discussioncomment-3311908