PiDTLN icon indicating copy to clipboard operation
PiDTLN copied to clipboard

Using multi-channel audio

Open · razor1179 opened this issue 3 years ago · 1 comment

Hi there,

The current code works great with a single-channel audio input, but when the number of channels is set to 2, the code throws the following error:

Exception ignored from cffi callback <function _StreamBase.__init__.<locals>.callback_ptr at 0x7f73d4faf0>:
Traceback (most recent call last):
  File "/home/pi/.local/lib/python3.9/site-packages/sounddevice.py", line 880, in callback_ptr
    return _wrap_callback(
  File "/home/pi/.local/lib/python3.9/site-packages/sounddevice.py", line 2681, in _wrap_callback
    callback(*args)
  File "/home/pi/PiDTLN/ns.py", line 128, in callback
    indata = indata[:, [args.channel]]
IndexError: index 2 is out of bounds for axis 1 with size 2

I modified the code to process the two channels separately, but now I get input-overflow / output-underflow errors because my changes slow the processing down. Do you have a suggestion for how to solve this? The modified code is below:

# --- Processing parameters --------------------------------------------------
# Frame length and hop (both in ms) and the sample rate the stream runs at.
block_len_ms = 32
block_shift_ms = 8
fs_target = 16000

# --- Model setup ------------------------------------------------------------
# Two TFLite models run in sequence per hop: model 1 maps a magnitude spectrum
# to a mask, model 2 maps the resulting time-domain block to an output block.
interpreter_1 = tflite.Interpreter(model_path='./models/dtln_ns_quant_1.tflite', num_threads=args.threads)
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='./models/dtln_ns_quant_2.tflite', num_threads=args.threads)
interpreter_2.allocate_tensors()
# Input/output tensor descriptors; index 0 is the signal, index 1 the state.
input_details_1 = interpreter_1.get_input_details()
output_details_1 = interpreter_1.get_output_details()
input_details_2 = interpreter_2.get_input_details()
output_details_2 = interpreter_2.get_output_details()

# Recurrent states: one independent set per channel, because the two channels
# are separate signals and must not share model memory.
states_1_ch1 = np.zeros(input_details_1[1]['shape'], dtype=np.float32)
states_2_ch1 = np.zeros(input_details_2[1]['shape'], dtype=np.float32)
states_1_ch2 = np.zeros(input_details_1[1]['shape'], dtype=np.float32)
states_2_ch2 = np.zeros(input_details_2[1]['shape'], dtype=np.float32)

# Hop size and frame size in samples.
block_shift = int(np.round(fs_target * (block_shift_ms / 1000)))
block_len = int(np.round(fs_target * (block_len_ms / 1000)))
# Number of rfft bins for a real block of length block_len. Deriving it here
# (instead of hard-coding 512/257) keeps the FFT plans in sync with
# block_len_ms.
fft_bins = block_len // 2 + 1

# Sliding input buffers and overlap-add output buffers, one pair per channel.
in_buffer_ch1 = np.zeros(block_len, dtype=np.float32)
out_buffer_ch1 = np.zeros(block_len, dtype=np.float32)
in_buffer_ch2 = np.zeros(block_len, dtype=np.float32)
out_buffer_ch2 = np.zeros(block_len, dtype=np.float32)

# g_use_fftw is expected to be initialised earlier in the file (not visible
# here); the command-line flag can only turn it off.
if args.no_fftw:
    g_use_fftw = False
if g_use_fftw:
    # One aligned buffer + pre-planned transform per channel and direction.
    fft_buf_ch1 = pyfftw.empty_aligned(block_len, dtype='float32')
    rfft_ch1 = pyfftw.builders.rfft(fft_buf_ch1, threads=args.threads)
    ifft_buf_ch1 = pyfftw.empty_aligned(fft_bins, dtype='complex64')
    irfft_ch1 = pyfftw.builders.irfft(ifft_buf_ch1, n=block_len, threads=args.threads)
    fft_buf_ch2 = pyfftw.empty_aligned(block_len, dtype='float32')
    rfft_ch2 = pyfftw.builders.rfft(fft_buf_ch2, threads=args.threads)
    ifft_buf_ch2 = pyfftw.empty_aligned(fft_bins, dtype='complex64')
    irfft_ch2 = pyfftw.builders.irfft(ifft_buf_ch2, n=block_len, threads=args.threads)

# Ring of the most recent callback processing times (seconds), for --measure.
t_ring = collections.deque(maxlen=100)


def _denoise_channel(samples, in_buffer, out_buffer, states_1, states_2, fftw_plan=None):
    """Push one hop of a single channel through both DTLN stages.

    ``in_buffer`` and ``out_buffer`` are updated in place (sliding input
    window and overlap-add output accumulator).

    Args:
        samples: 1-D array of ``block_shift`` new input samples.
        in_buffer: this channel's input window of length ``block_len``.
        out_buffer: this channel's overlap-add buffer of length ``block_len``.
        states_1: current state tensor for interpreter_1.
        states_2: current state tensor for interpreter_2.
        fftw_plan: ``(fft_buf, rfft, ifft_buf, irfft)`` pyfftw objects for
            this channel, or ``None`` to use ``numpy.fft``.

    Returns:
        ``(out_samples, states_1, states_2)`` where ``out_samples`` is a view
        of the first ``block_shift`` samples of ``out_buffer`` and the states
        are the updated model states to feed into the next hop.
    """
    # Slide the input window left by one hop and append the new samples.
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = samples

    # Spectrum of the current window.
    if fftw_plan is not None:
        fft_buf, rfft_fn, ifft_buf, irfft_fn = fftw_plan
        fft_buf[:] = in_buffer
        in_block_fft = rfft_fn()
    else:
        in_block_fft = np.fft.rfft(in_buffer)
    # Take magnitude/phase immediately: a pyfftw output array may be reused.
    in_mag = np.reshape(np.abs(in_block_fft), (1, 1, -1)).astype('float32')
    in_phase = np.angle(in_block_fft)

    # Stage 1: magnitude -> mask.
    interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
    interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
    interpreter_1.invoke()
    out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])

    # Masked magnitude with the original phase, back to the time domain.
    estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
    if fftw_plan is not None:
        ifft_buf[:] = estimated_complex
        estimated_block = irfft_fn()
    else:
        estimated_block = np.fft.irfft(estimated_complex, n=block_len)
    estimated_block = np.reshape(estimated_block, (1, 1, -1)).astype('float32')

    # Stage 2: time-domain block -> output block.
    interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
    interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
    interpreter_2.invoke()
    out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
    states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])

    # Overlap-add: slide the accumulator, clear the tail, add the new block.
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = 0.0
    out_buffer += np.squeeze(out_block)
    return out_buffer[:block_shift], states_1, states_2


def callback(indata, outdata, frames, buf_time, status):
    """sounddevice stream callback: denoise both channels of one hop.

    Expects the stream to be opened with two channels and
    ``blocksize=block_shift``, so ``indata``/``outdata`` have one column per
    channel. Channel 0 and channel 1 are processed independently, each with
    its own buffers and model states.

    Nothing here prints per block: this function runs every ``block_shift``
    samples on the audio thread, and console I/O there is enough to cause
    input overflows / output underflows.
    """
    global states_1_ch1, states_2_ch1, states_1_ch2, states_2_ch2
    if args.measure:
        start_time = time.time()
    if status:
        print(status)
    # NOTE: indata is used as-is. The stream is opened with
    # channels=args.channels, so indexing column args.channels (the old
    # `indata[:, args.channels]`) is always one past the last column.
    if args.no_denoise:
        outdata[:] = indata
        if args.measure:
            t_ring.append(time.time() - start_time)
        return

    if g_use_fftw:
        plan_ch1 = (fft_buf_ch1, rfft_ch1, ifft_buf_ch1, irfft_ch1)
        plan_ch2 = (fft_buf_ch2, rfft_ch2, ifft_buf_ch2, irfft_ch2)
    else:
        plan_ch1 = plan_ch2 = None

    out_ch1, states_1_ch1, states_2_ch1 = _denoise_channel(
        indata[:, 0], in_buffer_ch1, out_buffer_ch1, states_1_ch1, states_2_ch1, plan_ch1)
    outdata[:, 0] = out_ch1

    out_ch2, states_1_ch2, states_2_ch2 = _denoise_channel(
        indata[:, 1], in_buffer_ch2, out_buffer_ch2, states_1_ch2, states_2_ch2, plan_ch2)
    outdata[:, 1] = out_ch2

    # NOTE(review): two channels means twice the model work per hop. If
    # overflows persist, a larger stream latency or more interpreter threads
    # may be needed — confirm against the measured processing time.
    if args.measure:
        t_ring.append(time.time() - start_time)


def open_stream():
    """Open the duplex audio stream with ``callback`` and block forever.

    With ``--measure``, prints the mean callback processing time (ms) over the
    recent callbacks once per second; otherwise waits on an Event that is
    never set, so only Ctrl-C (or an exception) ends the stream.
    """
    with sd.Stream(device=(args.input_device, args.output_device), samplerate=fs_target, blocksize=block_shift,
                   dtype=np.float32, latency=args.latency, channels=args.channels, callback=callback):
        print('#' * 80)
        print('Ctrl-C to exit')
        print('#' * 80)
        if args.measure:
            while True:
                time.sleep(1)
                # Snapshot the deque in one call: the audio thread appends to
                # it concurrently.
                samples = list(t_ring)
                if not samples:
                    # No callback has completed yet; averaging would give nan.
                    continue
                # '\r' alone does not flush a line-buffered stdout.
                print('Processing time: {:.2f} ms'.format(1000 * np.average(samples)), end='\r', flush=True)
        else:
            threading.Event().wait()


# Entry point: run the stream (optionally detached as a daemon) and turn any
# failure into a clean parser exit with the exception's type and message.
try:
    if not args.daemonize:
        open_stream()
    else:
        with daemon.DaemonContext():
            open_stream()
except KeyboardInterrupt:
    parser.exit('')
except Exception as err:
    parser.exit(f'{type(err).__name__}: {err}')

razor1179 avatar Jul 01 '22 19:07 razor1179