
Does the TFLite model allow batching?

sammlapp opened this issue.

I am using the TF Lite checkpoint https://github.com/kahst/BirdNET-Analyzer/raw/main/checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite

It seems I cannot pass batched samples: passing a NumPy array of any shape other than [1, 144000] makes the model complain that the sample does not have the expected shape. Is there an alternative approach to batching, or does this model not support batching? (It's likely I'm simply using the model wrong.)
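For reference, each input row is one 3-second BirdNET segment at 48 kHz (3 × 48000 = 144000 samples), so a batch of N segments is a float32 array of shape [N, 144000]. The helper below is a hypothetical sketch (segment_audio is not a BirdNET-Analyzer function) that cuts a mono signal into such rows. It assumes the audio is already resampled to 48 kHz, uses non-overlapping segments, and zero-pads the last one, which may differ from how BirdNET-Analyzer itself splits and pads audio.

```python
import numpy as np

SAMPLE_RATE = 48_000  # assumed: BirdNET V2.4 input sample rate (Hz)
SEGMENT_SECONDS = 3   # assumed: one model input covers 3 s of audio
SEGMENT_LEN = SAMPLE_RATE * SEGMENT_SECONDS  # 144000 samples per row


def segment_audio(signal):
    """Split a 1-D, 48 kHz signal into an [N, 144000] float32 batch.

    Hypothetical helper: segments do not overlap and the last one is zero-padded.
    """
    signal = np.asarray(signal, dtype=np.float32).ravel()
    # Ceiling division; always return at least one row.
    n_segments = max(1, -(-signal.size // SEGMENT_LEN))
    padded = np.zeros(n_segments * SEGMENT_LEN, dtype=np.float32)
    padded[: signal.size] = signal
    return padded.reshape(n_segments, SEGMENT_LEN)
```

For example, 10 s of 48 kHz audio (480000 samples) becomes a batch of shape (4, 144000), with the fourth row partly zero-padded.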

Setup code (abbreviated):

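import numpy as np
from tensorflow import lite as tflite  # or: import tflite_runtime.interpreter as tflite

# Interpreter expects a local file path, not a URL, so download the checkpoint first:
# https://github.com/kahst/BirdNET-Analyzer/raw/main/checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite
model_path = "BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite"
network = tflite.Interpreter(
    model_path=model_path, num_threads=1
)
network.allocate_tensors()

# ...load some audio into a 2d np array...
sample = ...  # 2d np array, shape [2, 144000]

input_details = network.get_input_details()[0]
output_details = network.get_output_details()[0]
# embeddings are read from the tensor just before the logits output
embedding_idx = output_details["index"] - 1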

Attempting to call the network on batched samples fails:

network.set_tensor(
    input_details["index"], np.float32(sample)
)
network.invoke()
logits = network.get_tensor(output_details["index"])
embeddings = network.get_tensor(embedding_idx)

Error:

File ~/miniconda3/envs/opso_tf_cuda/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:720, in Interpreter.set_tensor(self, tensor_index, value)
    704 def set_tensor(self, tensor_index, value):
    705   """Sets the value of the input tensor.
    706
    707   Note this copies data in `value`.
   (...)
    718     ValueError: If the interpreter could not set the tensor.
    719   """
--> 720   self._interpreter.SetTensor(tensor_index, value)

ValueError: Cannot set tensor: Dimension mismatch. Got 2 but expected 1 for dimension 0 of input 0.

Calling the network with batch size 1 (sample.shape = [1, 144000]) works:

sample = np.float32(audio[0])[np.newaxis, :] #shape [1,144000]
network.set_tensor(
    input_details["index"], sample
)
network.invoke()
logits = network.get_tensor(output_details["index"])
embeddings = network.get_tensor(embedding_idx)
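One alternative that needs no resizing is to keep the model's fixed [1, 144000] input and loop over the rows of a batch, exactly as in the working call. This is a hypothetical sketch (predict_one_by_one is not a BirdNET-Analyzer function) that uses only the Interpreter methods already used above and stacks the per-row outputs; since it calls invoke() once per segment, it may be slower than a true batched call.

```python
import numpy as np


def predict_one_by_one(interpreter, batch):
    """Run each row of an [N, 144000] array through an interpreter whose input is [1, 144000].

    Hypothetical helper: no resize_tensor_input, one invoke() per row.
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    outputs = []
    for row in np.asarray(batch, dtype=np.float32):
        # Restore the leading batch dimension of size 1 that the model expects.
        interpreter.set_tensor(input_index, row[np.newaxis, :])
        interpreter.invoke()
        # get_tensor returns a copy, so it is safe to keep across iterations.
        outputs.append(interpreter.get_tensor(output_index))
    # Each output has a leading dimension of 1; concatenating gives N rows.
    return np.concatenate(outputs, axis=0)
```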

(sammlapp, Feb 29 '24 22:02)

Batching is possible. Please have a look at our code to see how we resize the input according to the batch size.

(Josef-Haupt, Mar 04 '24 14:03)

I can't tell from that code how the input should be shaped. I've tried all of these shapes and none are valid: [1, 2, 144000], [2, 1, 144000], [2, 144000], [1, 144000, 2].

The only shape that works is [1, 144000].

(sammlapp, Apr 16 '24 13:04)

You have to resize the input tensor to your batch size using resize_tensor_input. Have a look at the documentation and our code.
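A minimal sketch of that approach, assuming a tf.lite.Interpreter (or tflite_runtime Interpreter) like the one in the setup code: resize_tensor_input changes dimension 0 of the input to the batch size, and allocate_tensors must be called again after resizing, before set_tensor. predict_batch is a hypothetical helper, not the code used in BirdNET-Analyzer; it only re-allocates when the batch shape actually changes.

```python
import numpy as np


def predict_batch(interpreter, batch):
    """Run an [N, 144000] float32 batch through a TFLite interpreter.

    Hypothetical helper: resizes the input tensor only when the batch shape changes.
    """
    batch = np.asarray(batch, dtype=np.float32)
    input_details = interpreter.get_input_details()[0]
    output_index = interpreter.get_output_details()[0]["index"]

    current_shape = tuple(int(d) for d in input_details["shape"])
    if current_shape != batch.shape:
        # Change the input tensor's shape (dimension 0 is the batch size) ...
        interpreter.resize_tensor_input(input_details["index"], list(batch.shape))
        # ... and re-allocate, since resizing invalidates the previous allocation.
        interpreter.allocate_tensors()

    interpreter.set_tensor(input_details["index"], batch)
    interpreter.invoke()
    # get_tensor returns a copy of the output, one row per input segment.
    return interpreter.get_tensor(output_index)
```

With the setup code above, this would be called as logits = predict_batch(network, sample), where sample has shape [2, 144000].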

(Josef-Haupt, Apr 16 '24 15:04)

Thank you, I have it working now. As a PyTorch user, I didn't realize I would need to resize the model's input before passing batched samples.

(sammlapp, Apr 26 '24 17:04)