Does the TFLite model allow batching?
I am using the TFLite checkpoint https://github.com/kahst/BirdNET-Analyzer/raw/main/checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite
It seems I cannot pass batched samples: passing a NumPy array of any shape other than [1,144000] makes the model complain that the sample shape is not the expected shape. Is there an alternative approach to batching, or does this model not support it? (It's likely I'm simply using the model wrong.)
set-up code, abbreviated
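import numpy as np
from tensorflow import lite as tflite  # assumed import; the traceback below points at tensorflow/lite

# URL shown for reference; the Interpreter loads the model from a local copy of this file
model_path = "https://github.com/kahst/BirdNET-Analyzer/raw/main/checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite"
network = tflite.Interpreter(model_path=model_path, num_threads=1)
network.allocate_tensors()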
# ...load some audio into a 2d np array...
sample = ... #2d np array, shape [2,144000]
input_details = network.get_input_details()[0]
output_details = network.get_output_details()[0]
embedding_idx = output_details["index"] - 1
attempt to call the network on batched samples fails
network.set_tensor(input_details["index"], np.float32(sample))
network.invoke()
logits = network.get_tensor(output_details["index"])
embeddings = network.get_tensor(embedding_idx)
error:
File ~/miniconda3/envs/opso_tf_cuda/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:720, in Interpreter.set_tensor(self, tensor_index, value)
    704 def set_tensor(self, tensor_index, value):
    705   """Sets the value of the input tensor.
    706
    707   Note this copies data in `value`.
    (...)
    718     ValueError: If the interpreter could not set the tensor.
    719   """
--> 720   self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 2 but expected 1 for dimension 0 of input 0.
calling network with batch size=1 works (sample.shape=[1,144000])
sample = np.float32(audio[0])[np.newaxis, :] #shape [1,144000]
network.set_tensor(input_details["index"], sample)
network.invoke()
logits = network.get_tensor(output_details["index"])
embeddings = network.get_tensor(embedding_idx)
Batching is possible. Please have a look at our code to see how we resize the input according to the batch size.
I can't tell from that code how the input should be shaped. I've tried all of these shapes and none of them is valid: [1,2,144000], [2,1,144000], [2,144000], [1,144000,2].
The only shape that works is [1,144000].
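For reference, here is a minimal sketch of how to read the shape the interpreter currently expects, using the `shape` entry of `get_input_details()`. The helper names `expected_input_shape` and `fits_input` are made up for illustration. For the model in this thread the reported shape should be [1, 144000], which is consistent with the error message above.

```python
def expected_input_shape(interpreter):
    """Return the shape the first input tensor is currently allocated with."""
    details = interpreter.get_input_details()[0]
    # "shape" is an array of ints; convert it to a plain tuple for comparison.
    return tuple(int(dim) for dim in details["shape"])


def fits_input(interpreter, sample_shape):
    """True if an array of `sample_shape` can be passed to set_tensor as-is."""
    return expected_input_shape(interpreter) == tuple(sample_shape)


# Usage with the `network` object from the set-up code above:
# print(expected_input_shape(network))
# print(fits_input(network, (2, 144000)))
```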
You have to resize the input tensor to your batch size using resize_tensor_input. Have a look at the documentation and our code.
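A minimal sketch of that approach follows. It is not the BirdNET-Analyzer code itself. The helper name `run_batched` is hypothetical, and it assumes the model has a single input and that the batch is shaped [batch_size, 144000].

```python
import numpy as np


def run_batched(interpreter, batch):
    """Run a TFLite interpreter on a batch shaped [batch_size, n_samples]."""
    batch = np.asarray(batch, dtype=np.float32)
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Tell the interpreter the new input shape, then re-allocate its buffers.
    # Without these two calls, set_tensor raises the dimension-mismatch error.
    interpreter.resize_tensor_input(input_index, list(batch.shape))
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_index, batch)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)


# Usage (assumes TensorFlow is installed and the checkpoint is saved locally):
# from tensorflow import lite as tflite
# network = tflite.Interpreter(model_path="BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite")
# logits = run_batched(network, np.zeros((2, 144000)))
```

Resizing triggers a fresh allocation, so when processing many batches it is probably cheaper to keep the batch size fixed and resize once rather than on every call.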
Thank you, I have it working now. As a PyTorch user, it didn't occur to me that I would need to modify the model before passing batched samples.