keras
Add: `MelSpectrogram` layer
The MelSpectrogram layer was long overdue. Work on it stalled during the Keras v2-to-v3 port and was stuck for a while because of some missing dependencies. This PR adds the layer along with its dependencies (linear_to_mel_weight_matrix, frame, etc.).
- issue (keras-v3): https://github.com/keras-team/keras/issues/18405
- issue (keras-v2): https://github.com/keras-team/tf-keras/issues/55
- PR (keras-v2): https://github.com/keras-team/keras/pull/17717
cc: @fchollet, I would appreciate your review on this. In particular, I'm not sure how this layer should fit into the current Keras 3 codebase. For now, I'm keeping everything in layers/preprocessing/audio_preprocessing.py.
Also, while porting from TensorFlow I skipped the dynamic-shape parts, as I wasn't sure how they translate to Keras 3, so those still need testing.
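The `frame` helper listed among the dependencies splits a 1-D signal into overlapping windows before the FFT is taken. A minimal numpy sketch of the idea, assuming no end padding (the behavior of `tf.signal.frame` with `pad_end=False`); `frame_signal` is a hypothetical name for illustration, not the PR's helper:

```python
import numpy as np


def frame_signal(signal, sequence_length, sequence_stride):
    """Split a 1-D signal into overlapping frames, without end padding.

    Hypothetical helper for illustration only. Returns an array of shape
    (num_frames, sequence_length), where
    num_frames = 1 + (len(signal) - sequence_length) // sequence_stride.
    """
    signal = np.asarray(signal)
    # Every window of length sequence_length, one per starting sample...
    windows = np.lib.stride_tricks.sliding_window_view(signal, sequence_length)
    # ...keeping only every sequence_stride-th starting position.
    return windows[::sequence_stride]


frames = frame_signal(np.arange(10), sequence_length=4, sequence_stride=2)
print(frames.shape)  # (4, 4)
```

Each row is one analysis window; a spectrogram layer would then apply a window function and an FFT along the last axis.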
Codecov Report
Attention: Patch coverage is 98.61111% with 1 line in your changes missing coverage. Please review.
Project coverage is 75.50%. Comparing base (c8700f4) to head (d45517e). Report is 64 commits behind head on master.
| Files | Patch % | Lines |
|---|---|---|
| keras/layers/preprocessing/audio_preprocessing.py | 98.59% | 0 Missing and 1 partial :warning: |
Coverage Diff (master vs. #19194):

| | master | #19194 | +/- |
|---|---|---|---|
| Coverage | 80.14% | 75.50% | -4.64% |
| Files | 341 | 363 | +22 |
| Lines | 36163 | 39221 | +3058 |
| Branches | 7116 | 7583 | +467 |
| Hits | 28982 | 29614 | +632 |
| Misses | 5578 | 7971 | +2393 |
| Partials | 1603 | 1636 | +33 |
| Flag | Coverage Δ | |
|---|---|---|
| keras | 75.35% <98.61%> (-4.64%) | :arrow_down: |
| keras-jax | 59.73% <98.61%> (-3.33%) | :arrow_down: |
| keras-numpy | 54.15% <88.88%> (-2.93%) | :arrow_down: |
| keras-tensorflow | 60.89% <98.61%> (-3.76%) | :arrow_down: |
| keras-torch | 60.36% <98.61%> (-3.51%) | :arrow_down: |
Visual results seem fine:
import keras
import librosa
import librosa.display
import matplotlib.pyplot as plt
num_mel_bins = 256
fft_stride = 128
num_fft_bins = 1024
fmin = 512
fmax = 8000
# load the audio at its native sampling rate
y, sr = librosa.load("/cardinal-37075.mp3", sr=None)
audio = keras.ops.convert_to_tensor(y)
spec = keras.layers.MelSpectrogram(
num_mel_bins=num_mel_bins,
sampling_rate=sr,
sequence_stride=fft_stride,
fft_length=num_fft_bins,
min_freq=fmin,
max_freq=fmax,
)(audio)
spec = keras.ops.convert_to_numpy(spec)
# display mel spectrogram with librosa
plt.figure(figsize=(10, 3 * 2))
plt.subplot(211)
librosa.display.specshow(
spec,
hop_length=fft_stride,
sr=sr,
n_fft=num_fft_bins,
y_axis="mel",
x_axis="time",
fmin=fmin,
fmax=fmax,
)
plt.xlabel(None)
# plot the waveform
plt.subplot(212)
librosa.display.waveshow(y, sr=sr)
plt.xlim([0, len(y) / sr])  # set x-axis limits to the clip duration in seconds
plt.tight_layout()
plt.show()
Looking good! Please add a set of unit tests as well.
Thanks. I will add the unit tests soon.
@fchollet, while running correctness tests, I noticed that the Keras output matches torchaudio but not librosa, even when the parameters are the same. Visually, however, the results look similar. Here is a numerical comparison:
Keras:
Code:
spec_keras = keras.layers.MelSpectrogram(
num_mel_bins=num_mel_bins,
sampling_rate=sr,
sequence_stride=fft_stride,
fft_length=num_fft_bins,
min_freq=fmin,
max_freq=fmax,
ref_power=1.0,
)(y)
spec_keras = keras.ops.convert_to_numpy(spec_keras)
Output:
array([[-38.097645, -38.097645, -38.097645, ..., -31.054651, -31.315052,
-32.428814],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -33.124752, -38.097645,
-38.06842 ],
...,
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645]], dtype=float32)
Librosa:
Code:
spec_librosa = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=num_fft_bins,
hop_length=fft_stride, n_mels=num_mel_bins,
fmin=fmin, fmax=fmax)
spec_librosa = librosa.power_to_db(spec_librosa, ref=1.0, top_db=80.0)
Output:
array([[-53.241608, -53.241608, -53.241608, ..., -41.145687, -42.29181 ,
-46.312866],
[-53.241608, -53.241608, -53.241608, ..., -48.96924 , -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -43.543972, -53.241608,
-51.195637],
...,
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608]], dtype=float32)
Torchaudio:
Code:
import torch
import torchaudio
mel_spec_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sr,
n_fft=num_fft_bins,
hop_length=fft_stride,
n_mels=num_mel_bins,
f_min=fmin,
f_max=fmax,
)
# torchaudio transforms operate on torch tensors, so convert the NumPy waveform first
spec_torch = mel_spec_transform(torch.from_numpy(y))
spec_torch = torchaudio.transforms.AmplitudeToDB(top_db=80.0)(spec_torch).numpy()
Output:
array([[-38.098427, -38.098427, -38.098427, ..., -31.064793, -31.325191,
-32.438957],
[-38.098427, -38.098427, -38.098427, ..., -38.098427, -38.098427,
-38.098427],
[-38.098427, -38.098427, -38.098427, ..., -33.12678 , -38.098427,
-38.070446],
...,
[-38.098427, -38.098427, -38.098427, ..., -38.098427, -38.098427,
-38.098427],
[-38.098427, -38.098427, -38.098427, ..., -38.098427, -38.098427,
-38.098427],
[-38.098427, -38.098427, -38.098427, ..., -38.098427, -38.098427,
-38.098427]], dtype=float32)
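All three outputs are in decibels relative to a power of 1.0. The librosa and torchaudio calls clip anything more than 80 dB below the peak (top_db=80.0), and the Keras output shows the same kind of repeated floor value, which is why so many entries are identical. A minimal numpy sketch of that conversion, following the formula librosa documents for `power_to_db` (the function here is an illustration, not any library's code):

```python
import numpy as np


def power_to_db(power, ref=1.0, amin=1e-10, top_db=80.0):
    """Convert a power spectrogram to decibels (librosa-style formula)."""
    power = np.asarray(power, dtype=np.float64)
    # 10 * log10(S / ref), with S and ref floored at amin to avoid log(0)
    log_spec = 10.0 * np.log10(np.maximum(amin, power))
    log_spec -= 10.0 * np.log10(np.maximum(amin, ref))
    if top_db is not None:
        # Clip everything more than top_db below the peak to that floor
        log_spec = np.maximum(log_spec, log_spec.max() - top_db)
    return log_spec


result = power_to_db([1.0, 1e-3, 1e-12])
print(result)  # approximately [0, -30, -80]
```

The last entry would be -100 dB after the `amin` floor, but the `top_db` clip raises it to 80 dB below the peak.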
This appears to be correct and normal: https://github.com/pytorch/audio/issues/1058
I found that librosa applies Slaney normalization to the mel filterbank by default, while torchaudio applies no normalization by default.
How would one go about adding slaney normalization in Keras? Is it possible for the user to add it externally from the layer?
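To make the Slaney option concrete, here is a numpy sketch modeled on how librosa builds its filterbank with `htk=True`: triangles are laid out between band edges spaced evenly on the mel scale, and Slaney normalization then scales each triangle by 2 / (upper edge - lower edge), with the edges in Hz. The helper names are hypothetical and the sketch is not Keras or librosa code.

```python
import numpy as np


def hz_to_mel_htk(freq_hz):
    # HTK mel scale: 2595 * log10(1 + f / 700)
    return 2595.0 * np.log10(1.0 + np.asarray(freq_hz, dtype=np.float64) / 700.0)


def mel_to_hz_htk(mel):
    return 700.0 * (10.0 ** (np.asarray(mel, dtype=np.float64) / 2595.0) - 1.0)


def mel_filterbank(sr, n_fft, n_mels, fmin, fmax, norm=None):
    """Triangular mel filters of shape (n_mels, n_fft // 2 + 1).

    Hypothetical helper modeled on librosa.filters.mel(..., htk=True).
    """
    fft_freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
    # n_mels + 2 band edges, evenly spaced in mel, converted back to Hz
    mel_edges = np.linspace(hz_to_mel_htk(fmin), hz_to_mel_htk(fmax), n_mels + 2)
    edges_hz = mel_to_hz_htk(mel_edges)
    fdiff = np.diff(edges_hz)
    # ramps[i, k] = edge_i - fft_freq_k
    ramps = edges_hz[:, None] - fft_freqs[None, :]
    lower = -ramps[:-2] / fdiff[:-1, None]  # rising slope of each triangle
    upper = ramps[2:] / fdiff[1:, None]  # falling slope of each triangle
    weights = np.maximum(0.0, np.minimum(lower, upper))
    if norm == "slaney":
        # Divide each triangle by half its bandwidth in Hz, so every filter
        # has roughly the same area (approx. constant energy per band).
        enorm = 2.0 / (edges_hz[2:] - edges_hz[:-2])
        weights *= enorm[:, None]
    elif norm is not None:
        raise ValueError(f"Unsupported norm={norm!r}")
    return weights


weights = mel_filterbank(16000, 512, 40, 0.0, 8000.0)
weights_slaney = mel_filterbank(16000, 512, 40, 0.0, 8000.0, norm="slaney")
# Slaney normalization only rescales each filter (row) by its own factor
scale = weights_slaney.sum(axis=1) / weights.sum(axis=1)
```

Because the normalization is a per-filter scale factor, it could in principle be applied to an existing weight matrix from outside the layer, as long as the band edges in Hz are known.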
I tried to implement "slaney" normalization in Keras based on the librosa reference, but the values still don't match.
if norm == "slaney":
    # Slaney-style mel is scaled to be approx constant energy per channel
    mel_f = ops.math.extract_sequences(
        ops.linspace(
            _hertz_to_mel(lower_edge_hertz),
            _hertz_to_mel(upper_edge_hertz),
            num_mel_bins + 2,
        ),
        sequence_length=1,
        sequence_stride=1,
    )
    enorm = 2.0 / (mel_f[2 : num_mel_bins + 2] - mel_f[:num_mel_bins])
    mel_weights_matrix *= ops.transpose(enorm)
elif norm is not None:
    raise NotImplementedError(f"Unsupported norm={norm}")
Output:
Before:
array([[-38.097645, -38.097645, -38.097645, ..., -31.054651, -31.315052,
-32.428814],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -33.124752, -38.097645,
-38.06842 ],
...,
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645],
[-38.097645, -38.097645, -38.097645, ..., -38.097645, -38.097645,
-38.097645]], dtype=float32)
After:
array([[-47.464535, -47.464535, -47.464535, ..., -40.421543, -40.68195 ,
-41.795708],
[-47.464535, -47.464535, -47.464535, ..., -47.464535, -47.464535,
-47.464535],
[-47.464535, -47.464535, -47.464535, ..., -42.491627, -47.464535,
-47.435295],
...,
[-47.464535, -47.464535, -47.464535, ..., -47.464535, -47.464535,
-47.464535],
[-47.464535, -47.464535, -47.464535, ..., -47.464535, -47.464535,
-47.464535],
[-47.464535, -47.464535, -47.464535, ..., -47.464535, -47.464535,
-47.464535]], dtype=float32)
Librosa:
array([[-53.241608, -53.241608, -53.241608, ..., -41.145687, -42.29181 ,
-46.312866],
[-53.241608, -53.241608, -53.241608, ..., -48.96924 , -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -43.543972, -53.241608,
-51.195637],
...,
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608],
[-53.241608, -53.241608, -53.241608, ..., -53.241608, -53.241608,
-53.241608]], dtype=float32)
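One observation on the outputs above, offered as a hypothesis: each "After" value equals the matching "Before" value minus roughly 9.37 dB (for example, -38.097645 becomes -47.464535, and -31.054651 becomes -40.421543). A uniform shift like that is what a single scale factor applied to every filter would produce. In the snippet, `enorm` is computed from band edges spaced evenly in mel, so `mel_f[2:] - mel_f[:-2]` is identical for every band; librosa's Slaney normalization takes these differences after converting the edges back to Hz, where the bands widen with frequency. The numpy check below assumes the HTK-style mel formula 1127 * ln(1 + f / 700) and the settings used in this thread (256 bins, 512 to 8000 Hz). It reproduces the offset and shows that Hz-based widths give a different factor per band. It may not account for every remaining difference from librosa.

```python
import numpy as np


def hertz_to_mel(freq_hz):
    # HTK-style mel scale (assumed): 1127 * ln(1 + f / 700)
    return 1127.0 * np.log1p(np.asarray(freq_hz, dtype=np.float64) / 700.0)


def mel_to_hertz(mel):
    return 700.0 * np.expm1(np.asarray(mel, dtype=np.float64) / 1127.0)


num_mel_bins, fmin, fmax = 256, 512.0, 8000.0
edges_mel = np.linspace(hertz_to_mel(fmin), hertz_to_mel(fmax), num_mel_bins + 2)

# Widths measured in mel units: every band is equally wide, so enorm is one constant.
enorm_mel = 2.0 / (edges_mel[2:] - edges_mel[:-2])
# Widths measured in Hz (librosa's approach): bands widen with frequency.
edges_hz = mel_to_hertz(edges_mel)
enorm_hz = 2.0 / (edges_hz[2:] - edges_hz[:-2])

# A constant power scale k shifts every dB value by 10 * log10(k).
offset_db = 10.0 * np.log10(enorm_mel[0])
print(round(offset_db, 2))  # -9.37
```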
Did you try to compare the output of mel_weights_matrix side by side between the two implementations?
If it matches, then the implication is that there's an additional difference with Librosa.
In any case, I think the match with torchaudio is sufficient validation of the implementation's correctness.
@fchollet We can plug in an external mel matrix like this if we want norm="slaney". I've checked the result and it matches librosa, which means the mismatch is in the mel weights.
import keras
import librosa

class MelSpectrogram(keras.layers.MelSpectrogram):
    # Override only the mel-scaling step; the rest of the layer is unchanged.
    def _melscale(self, inputs):
        # librosa returns shape (num_mel_bins, fft_length // 2 + 1)
        weights = librosa.filters.mel(
            sr=self.sampling_rate, n_fft=self.fft_length,
            n_mels=self.num_mel_bins, fmin=self.min_freq,
            fmax=self.max_freq, htk=True, norm="slaney",
        )
        # transpose to (freq, mel) so (..., freq) x (freq, mel) -> (..., mel)
        weights = keras.ops.convert_to_tensor(weights.T)
        return keras.ops.tensordot(inputs, weights, axes=1)
I've tried comparing the mel matrix from Keras and Librosa side by side with the following code:
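import numpy as np
import keras
import librosa
matrix_librosa = librosa.filters.mel(
sr=sr, n_fft=num_fft_bins, n_mels=num_mel_bins,
fmin=fmin, fmax=fmax, norm="slaney", htk=True,
)
matrix_keras = linear_to_mel_weight_matrix(
num_mel_bins=num_mel_bins,
num_spectrogram_bins=num_fft_bins // 2 + 1,
sampling_rate=sr,
lower_edge_hertz=fmin,
upper_edge_hertz=fmax,
norm="slaney",
)
# transpose to librosa's (num_mel_bins, num_spectrogram_bins) layout
matrix_keras = keras.ops.convert_to_numpy(matrix_keras).T
# largest element-wise difference; a plain .sum() of signed differences can cancel out
np.abs(matrix_keras - matrix_librosa).max()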
When htk=True and norm=None, the librosa mel matrix matches the Keras one exactly, and the final outputs of the two mel-spectrogram implementations also match. The mismatch appears only when norm="slaney" is set in both Keras and librosa; then the mel matrices differ.
So, if we want identical results across librosa, Keras, and torchaudio, all we need to do is set htk=True and norm=None in librosa.
@fchollet, should I add a correctness test against torchaudio?
Also, do we need tf.data compatibility? I tried adding a tf_data compatibility test, but it fails.
code:
import numpy as np
import tensorflow as tf
import keras
input_data = np.random.random((2, 8000))
layer = keras.layers.MelSpectrogram(num_mel_bins=80,
sampling_rate=8000,
sequence_stride=128,
fft_length=2048)
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2)
ds = ds.map(layer)
error:
NotImplementedError: Exception encountered when calling MelSpectrogram.call().
Cannot convert a symbolic tf.Tensor (Cast:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
Arguments received by MelSpectrogram.call():
• inputs=<tf.Tensor 'Cast:0' shape=(None, 8000) dtype=float32>
But this runs fine:
layer(next(iter(ds.take(1))))
It is feasible, but the current layer code is not set up for it. Take a look at the other layers that inherit from TFDataLayer: they need to use self.backend.numpy (etc.) instead of ops for op access, so that we can swap out the backend.
As for whether we should do it -- how do you think the layer will be used? Inside a model? Externally?
Using this layer inside a model is definitely the best choice: later, all we have to do is load the model and pass the raw audio directly to it. If we use this layer in the data pipeline instead, we have to preprocess the audio separately, with the right hyperparameters, before passing it to the model. I think YAMNet on TensorFlow Hub does the same.
But yes, having the flexibility to use this layer both in a model and in a data pipeline would be great. I will give it one last shot.
Ok -- then it's definitely better to make it work with tf.data, even if that's not the main use case. If you run into major blockers, we can revert and merge a version that only works with tf.data when using the TF backend. That's ok.
@fchollet I've added support for tf.data, but since it required using self.backend instead of keras.ops, I had to move the standalone functions into the MelSpectrogram layer. I've checked it, and it runs fine both inside a model and in tf.data.
I also tried using keras.backend instead of self.backend, but it throws the same error.
Right -- if you want tf.data support with a backend other than TF, all ops applied by the layer must go through self.backend. Anything else won't work.
Is this layer compatible/convertible to TFLite?