Differentiable filtering using a cascade of second order IIR filters

Open SuperKogito opened this issue 1 year ago • 0 comments

🚀 The feature

A differentiable PyTorch implementation of sosfilt(), like https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfilt.html, would allow filtering data along one dimension using cascaded second-order sections. This would provide better support for stable high-order filtering.
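For reference, the requested operation can be sketched in plain Python. Each row of `sos` is `[b0, b1, b2, a0, a1, a2]`, the same layout scipy.signal.sosfilt uses, and the sections are applied one after another in transposed direct form II (the structure SciPy's sosfilt documentation describes), so no high-order polynomial is ever formed. This is a hypothetical, non-differentiable reference: `sosfilt_reference` is not a library function, and a PyTorch version would have to run the same per-section recursion on tensors so that gradients flow through the coefficients.

```python
def sosfilt_reference(sos, x):
    """Filter x through a cascade of biquads, one row of sos per section.

    Each row is [b0, b1, b2, a0, a1, a2]. Sections run one after another,
    each in transposed direct form II, so the sections are never multiplied
    out into a single high-order polynomial.
    """
    y = list(x)
    for b0, b1, b2, a0, a1, a2 in sos:
        # normalize so the leading denominator coefficient is 1
        b0, b1, b2, a1, a2 = b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0
        z1 = z2 = 0.0  # the two state variables of this section
        out = []
        for xn in y:
            yn = b0 * xn + z1
            z1 = b1 * xn - a1 * yn + z2  # uses the previous z2
            z2 = b2 * xn - a2 * yn
            out.append(yn)
        y = out  # the output of this section feeds the next one
    return y
```

For example, `sosfilt_reference([[1.0, 0.0, 0.0, 1.0, -0.5, 0.0]], [1.0, 0.0, 0.0, 0.0])` returns `[1.0, 0.5, 0.25, 0.125]`, the impulse response of a single one-pole section.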

Motivation, pitch

The current alternative is to convert the cascade of biquads (second-order IIR filters) into a single high-order filter and then apply it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this only works up to a certain order (order < 6). The following code illustrates the stability issues that arise when using lfilter with a high-order filter. Hence, an option for cascaded filtering that maintains stability would be of great advantage.


import torch
import scipy.signal as signal
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchaudio.functional import lfilter

hardware = "cpu"
device = torch.device(hardware)
eps = 1e-8
    
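def coeff_product(polynomials):
    """Multiply the rows of `polynomials` (shape [n_sections, n_coeffs]) into one polynomial.

    The rows are split in half recursively and the two partial products are
    multiplied with conv1d; the result has shape [1, n_sections * (n_coeffs - 1) + 1].
    """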
    n = polynomials.shape[0]
    if n == 1:
        return polynomials

    c1 = coeff_product(polynomials[n // 2 :])
    c2 = coeff_product(polynomials[: n // 2])
    if c1.shape[1] > c2.shape[1]:
        c1, c2 = c2, c1
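    # conv1d computes a cross-correlation, so flipping the kernel turns it into
    # the full convolution, i.e. the polynomial multiplication we need
    weight = c1.unsqueeze(1).flip(2)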
    prod = F.conv1d(
        c2.unsqueeze(0),
        weight,
        padding=weight.shape[2] - 1,
        groups=c2.shape[0],
    ).squeeze(0)
    return prod

if __name__ == "__main__":
    for order in range(2, 12, 2):
        # Design the same elliptic low-pass filter as (b, a) and as SOS, then print its zeros and poles
        b, a = signal.ellip(order, 0.009, 80, 0.05, output='ba')
        sos = signal.ellip(order, 0.009, 80, 0.05, output='sos')
        zeros, poles, gain = signal.sos2zpk(sos)
        
        print("-" * 52)
        print("Zeroes : ", zeros)        
        print("Poles  : ", poles)        
        print("-" * 52)
        
        print("sos : ", sos)
        print("-" * 52)        
        
        print("b : ", b)
        print("a : ", a)
        print("-" * 52)
        # unit impulse of length fs, used to measure the impulse response
        fs = 500
        dirac = torch.tensor(signal.unit_impulse(fs), dtype=torch.float32)
        # PYTORCH IMPLEMENTATION
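        # prepare coeffs: each SOS row is [b0, b1, b2, a0, a1, a2]; multiplying the
        # sections together gives the equivalent high-order b and a polynomials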
        torch_sos = torch.tensor(sos, dtype=torch.float32)
        torch_a = torch_sos[:, 3:]
        torch_b = torch_sos[:, :3]
        high_order_a = coeff_product(torch_a)
        high_order_b = coeff_product(torch_b)
        
        print("sos : ", torch_sos)
        print("-" * 52)        
        
        print("torch_b : ", torch_b)
        print("torch_a : ", torch_a)
        print("-" * 52)
        
        print("high_order_b : ", high_order_b)
        print("high_order_a : ", high_order_a)
        print("-" * 52)
        
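        # compute the impulse response with the high-order (b, a) filter
        # (note: torchaudio's lfilter clamps its output to [-1, 1] by default, clamp=True)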
        y_torch_ba = lfilter(dirac.unsqueeze(0), high_order_a, high_order_b)
        
        ## SCIPY IMPLEMENTATION
        freq, freq_response = signal.sosfreqz(sos)
        x     = signal.unit_impulse(fs)
        y_tf  = signal.lfilter(high_order_b.squeeze(0).detach().numpy(), high_order_a.squeeze(0).detach().numpy(), x)
        y_sos = signal.sosfilt(sos, x)
        
        # plotting
        plt.figure(figsize=(15, 30))
        plt.subplot(3, 1, 1)
        plt.plot(y_sos, 'g', label='SOS')
        plt.legend(loc='best')
    
        plt.subplot(3, 1, 2)
        plt.plot(y_tf, 'k', label='TF')
        plt.legend(loc='best')
    
        plt.subplot(3, 1, 3)
        plt.plot(y_torch_ba.squeeze(0).detach().numpy(), "r", label="torch")
        plt.legend(loc='best')
        plt.show()
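The `coeff_product` helper above multiplies the section polynomials with conv1d. The same idea in plain Python, together with a direct-form filter that applies the expanded coefficients, makes the equivalence between the two approaches easy to check at low order in float64. `poly_mul`, `sos_to_ba`, and `lfilter_direct` are hypothetical helper names for this sketch, not library functions, and the two-section example filter is made up for illustration.

```python
from functools import reduce


def poly_mul(p, q):
    """Multiply two polynomials in z^-1 given as coefficient lists (lowest power first)."""
    out = [0.0] * (len(p) + len(q) - 1)
    for i, pi in enumerate(p):
        for j, qj in enumerate(q):
            out[i + j] += pi * qj
    return out


def sos_to_ba(sos):
    """Expand a cascade of biquads into one high-order (b, a) pair."""
    b = reduce(poly_mul, [row[:3] for row in sos])
    a = reduce(poly_mul, [row[3:] for row in sos])
    return b, a


def lfilter_direct(b, a, x):
    """Direct-form difference equation:
    a[0]*y[n] = sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]."""
    y = []
    for n in range(len(x)):
        acc = sum(b[k] * x[n - k] for k in range(len(b)) if n - k >= 0)
        acc -= sum(a[k] * y[n - k] for k in range(1, len(a)) if n - k >= 0)
        y.append(acc / a[0])
    return y


# hypothetical two-section example: [b0, b1, b2, a0, a1, a2] per row
sos = [[1.0, 2.0, 1.0, 1.0, -0.5, 0.0], [1.0, 0.0, 0.0, 1.0, -0.25, 0.0]]
impulse = [1.0, 0.0, 0.0, 0.0]

b, a = sos_to_ba(sos)               # single high-order filter (the "TF" path)
y_tf = lfilter_direct(b, a, impulse)

y_cascade = impulse                 # cascade: filter with each biquad in turn
for row in sos:
    y_cascade = lfilter_direct(row[:3], row[3:], y_cascade)
```

Here `b` is `[1.0, 2.0, 1.0, 0.0, 0.0]`, `a` is `[1.0, -0.75, 0.125, 0.0, 0.0]`, and both paths return `[1.0, 2.75, 2.9375, 1.859375]`. At this low order and precision the results match; the script above is meant to show that they stop matching as the order grows and the coefficients are stored in float32.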

This feature would allow users to apply high-order filtering (order > 6) within loss functions and training loops.

Alternatives

Since no filtering based on a cascade of biquads is currently available, the alternatives are:

  • Use a for loop and feed the biquad coefficients, one section at a time, to https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this is very slow, which makes it impractical within a loss function or a training loop.
  • Convert the cascade of biquads into a high-order filter and use it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. As illustrated above, this results in unstable output, which is expected: the same happens with SciPy when using lfilter instead of sosfilt().
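The instability of the second option can be explained with the transfer function. In exact arithmetic, a cascade of K biquads equals a single filter whose numerator and denominator are the products of the section polynomials (this is what coeff_product computes):

$$
H(z) = \prod_{k=1}^{K} \frac{b_{0,k} + b_{1,k} z^{-1} + b_{2,k} z^{-2}}{a_{0,k} + a_{1,k} z^{-1} + a_{2,k} z^{-2}} = \frac{B(z)}{A(z)}, \qquad A(z) = \prod_{k=1}^{K} \left( a_{0,k} + a_{1,k} z^{-1} + a_{2,k} z^{-2} \right)
$$

Numerically, however, the roots of the expanded order-2K polynomial A(z) can be very sensitive to small coefficient errors when the poles are clustered, as they tend to be for a narrow low-pass filter like the one in the script (Wn = 0.05). Rounding the expanded coefficients to float32 can therefore move poles onto or outside the unit circle, while in the cascade each pole pair stays in its own quadratic, which is much less sensitive.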

Additional context

https://dsp.stackexchange.com/questions/31457/multiple-biquads-vs-higher-order-filtering

SuperKogito, Jul 10 '24 09:07