
Pure torch implementation

Open grazder opened this issue 1 year ago • 26 comments

I've created this issue about pure torch reimplementation - https://github.com/Rikorose/DeepFilterNet/issues/430

Sharing the code. This is a draft PR, so work is still in progress and I may make some changes later. You can find my implementation in the torchDF folder; there is also a README.md there with more details.

I'll be glad to hear your feedback.

There are also some changes in deepfilternet3.py, modules.py, and multiframe.py. They were necessary to reach 100% parity between the streaming tract model and the offline enhance method.

This fork was created based on https://github.com/Rikorose/DeepFilterNet/commit/ca46bf54afaf8ace3272aaee5931b4317bd6b5f4. Therefore, some code may be a little outdated.

The offline model's torch implementation is in torchDF/torch_df_offline.py.

The streaming model's torch implementation is in torchDF/torch_df_streaming.py.

To convert the streaming model to ONNX you can use torchDF/model_onnx_export.py.

grazder avatar Sep 20 '23 08:09 grazder

Thank you for your great work!

In the code [torch_df_offline.py], sample_rate is set to 48000, and self.erb_indices = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 5, 7, 7, 8, 10, 12, 13, 15, 18, 20, 24, 28, 31, 37, 42, 50, 56, 67]).

When sample_rate=16000, what should self.erb_indices be set to?

dingchaoyue avatar Oct 07 '23 15:10 dingchaoyue

@dingchaoyue

Here is the original code; you can calculate the ERB filter bank using it:

https://github.com/Rikorose/DeepFilterNet/blob/59789e135cb5ed0eb86bb50e8f1be09f60859d5c/libDF/src/lib.rs#L62
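For reference, the linked Rust `erb_fb` logic can be sketched in Python roughly like this. This is my own paraphrase (the ERB-scale constants 9.265 and 24.7 are the ones used in libDF, but the exact rounding behavior should be double-checked against the Rust source):

```python
import math

def freq2erb(freq_hz: float) -> float:
    # ERB scale as used in libDF
    return 9.265 * math.log(1.0 + freq_hz / (24.7 * 9.265))

def erb2freq(erb: float) -> float:
    # inverse of freq2erb
    return 24.7 * 9.265 * (math.exp(erb / 9.265) - 1.0)

def erb_widths(sr: int, fft_size: int, nb_bands: int, min_nb_freqs: int) -> list:
    """Number of STFT bins per ERB band (cf. libDF/src/lib.rs erb_fb())."""
    freq_width = sr / fft_size            # Hz per STFT bin
    step = freq2erb(sr / 2) / nb_bands    # equal steps on the ERB scale
    widths = []
    prev_bin, over = 0, 0
    for i in range(1, nb_bands + 1):
        fb = round(erb2freq(i * step) / freq_width)
        nb = fb - prev_bin - over
        if nb < min_nb_freqs:             # enforce a minimum band width,
            over = min_nb_freqs - nb      # borrowing bins from the next band
            nb = min_nb_freqs
        else:
            over = 0
        widths.append(nb)
        prev_bin = fb
    widths[-1] += 1                       # include the Nyquist bin
    return widths
```

With sr=48000, fft_size=960, nb_bands=32, min_nb_freqs=2 this reproduces the shape of the widths list quoted above (32 bands summing to 481 bins); for 16 kHz you would call it with the 16 kHz FFT size instead.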

grazder avatar Oct 07 '23 15:10 grazder

Thank you. Problem solved

dingchaoyue avatar Oct 07 '23 18:10 dingchaoyue

Hi, thanks for this PR! I would be interested in merging it. A few things that I want to discuss:

  1. Only keep things related to the pytorch real-time path in here. Make separate PRs for other stuff, e.g. jemalloc for pyDF etc.
  2. If possible, I would prefer not to introduce another python wheel. Do you see any option to include this version as a tool.poetry.scripts entry within the DeepFilterNet wheel?
  3. Can you maybe add a test here?

Rikorose avatar Oct 18 '23 05:10 Rikorose

  1. I'm using the pyDF changes to run a test that compares the original Rust and torch implementations. What should we do with that test then — remove it, or keep it in this PR?
  2. Yeah, I see this implementation more as a way for users to understand the model; we don't need to introduce it as another wheel. I can try to add it as a tool.poetry.scripts entry and figure out how to make it better. Can you describe the logic behind it and how you use tool.poetry.scripts right now?
  3. Yes, I can add it here, but first we need to decide what to do with the tests (as described in point 1).

grazder avatar Oct 19 '23 08:10 grazder

Hey, if anyone can recreate the LADSPA plugin with ONNX, please do, as tract's single-threaded inference needs a pretty big single core.

StuartIanNaylor avatar Oct 20 '23 14:10 StuartIanNaylor

A few notes:

  1.1. Don't remove existing features in pyDF that others depend on (logging, transforms).
  1.2. Don't add unrelated features in this PR (jemalloc).
  1.3. Hide the tract dependency behind a feature flag. It should not be compiled by default:

[features]
tract = ["deep_filter/tract", "deep_filter/default-model"]

However, I am not sure if there is a way to add this as an optional dependency on deepfilterlib in the pyproject file. By default, the tract dependency is not necessary for a standard use case (e.g. pytorch inference or pytorch training). Can you add deepfilterlib without tract as it is currently, and only add deepfilterlib[tract] when DeepFilterNet is installed with an additional tract feature?

  1.4. Is tract_core really necessary? It should be pulled in via libDF/tract.
  2.1. You can just add an additional readme_py_rt.md or so in DeepFilterNet and create a script at DeepFilterNet/scripts. Then you can add this python script to the tool.poetry.scripts section. You can also move the other scripts, e.g. model_onnx_export.py, to the scripts folder. But we need to find a solution for the tract dependency of deepfilterlib.
  2.2. A different option would be to create a submodule, e.g. at DeepFilterNet/py_rt/. Maybe this is more convenient.
  4. Remove the .wav example. Please use the existing samples at assets/noisy_snr0.wav.

Rikorose avatar Oct 20 '23 18:10 Rikorose

@Rikorose I saw your comment; I have more important tasks right now, but I'll come back in a while and fix it.

grazder avatar Oct 31 '23 08:10 grazder

Hello author, when I build the model I get: TypeError: forward() takes 3 positional arguments but 4 were given. This seems to be caused by the missing hidden parameters. How can I solve it? I don't know where the loaded enc module code is located; adding hidden vectors to the encoder module still doesn't solve the problem. The hidden state is not loaded during training. (screenshot attached)

hulucky1102 avatar Jan 03 '24 11:01 hulucky1102

Hello! Can you provide more details? What code are you running?

Encoder in this reimplementation takes 4 parameters

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/DeepFilterNet/df/deepfilternet3.py#L168

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/torch_df_streaming.py#L492

self, feat_erb: Tensor, feat_spec: Tensor, hidden: Tensor

Btw, this code is for inference only; I didn't do anything with training.

Also, I haven't committed to this branch for a long time, so nothing should break if you run the code without changes.

What do you mean by "build"? Do the tests pass? You can see how to run the tests here.

grazder avatar Jan 03 '24 11:01 grazder

Hello, this was because I had the df library installed in my environment, so the calls were going to the installed library's functions instead of the ones you modified.

hulucky1102 avatar Jan 04 '24 02:01 hulucky1102

Hello, I generated the ONNX model with your model_onnx_export.py, and I read audio from the device and enhance it in real time through ONNX. The result is not good. Could it be that I'm not using the code correctly? (screenshot attached)

hulucky1102 avatar Jan 04 '24 11:01 hulucky1102

@hulucky1102 You can look here for an example of how to run inference with the streaming version correctly:

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/test_torchdf.py#L70

Also, check that you export with always_apply_all_stages=True parameter:

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/model_onnx_export.py#L232
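The gist of the streaming loop in test_torchdf.py is: split the audio into hop-size frames and thread the recurrent states through every call. Here is a minimal sketch of that pattern with a stub standing in for the ONNX session (the hop size and the `enhance_frame` signature are illustrative assumptions; the real input/output names come from the exported model):

```python
import numpy as np

HOP_SIZE = 480  # 10 ms at 48 kHz, matching the model's hop size

def enhance_frame(frame, states):
    """Stub for one session.run() call: the real code would feed the ONNX
    model the frame plus all state tensors and get updated states back."""
    return frame, states  # identity "enhancement" for illustration

def stream_enhance(audio: np.ndarray) -> np.ndarray:
    n_frames = len(audio) // HOP_SIZE   # trailing partial frame is dropped
    states = None                       # model-defined zero states on the first call
    out = []
    for i in range(n_frames):
        frame = audio[i * HOP_SIZE:(i + 1) * HOP_SIZE]
        # states MUST be carried from call to call, otherwise the model
        # loses its context and the output degrades badly
        enhanced, states = enhance_frame(frame, states)
        out.append(enhanced)
    return np.concatenate(out)
```

Forgetting to feed the updated states back in (or exporting without `always_apply_all_stages=True`, as noted above) is the most common cause of poor streaming quality compared to the offline enhance path.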

grazder avatar Jan 04 '24 12:01 grazder

Thank you very much. This problem was caused by the atten_lim_db value I was passing in. Is there some setting in the model that causes muting? It makes parts of my speech incoherent, so I would like to remove the muting behavior.

hulucky1102 avatar Jan 05 '24 02:01 hulucky1102

Both use DeepFilterNet3; one runs the API's enhance method and the other runs the ONNX model. The two results are inconsistent: the ONNX result loses some of the vocals. (screenshots attached)

hulucky1102 avatar Jan 05 '24 03:01 hulucky1102

Is it happening only with ONNX inference? Or with torch inference of streaming version too?

grazder avatar Jan 06 '24 08:01 grazder

This situation occurs in both onnx and torch streaming

hulucky1102 avatar Jan 06 '24 10:01 hulucky1102

Hi @grazder,

Thanks for your amazing work! I was hoping you could help clear up something for me.

As part of your change-set, you added a new 'hidden_states' input/output tensor to Encoder, ErbDecoder, and DfDecoder (in deepfilternet3.py). And I see that in your streaming implementation, these are used as part of the state management logic.

What I am confused about is that in the main branch's streaming implementation (i.e. the Rust implementation), it appears to work using these onnx models, DeepFilterNet3_onnx.tar.gz. But I don't see these 'hidden states' tensors exposed as inputs / outputs to these models. So how does the Rust implementation work without the ability to manage these? It seems like it's necessary based on your pure pytorch implementation.

It's very possible that I overlooked something simple..

Thanks again! Ryan

RyanMetcalfeInt8 avatar Mar 06 '24 13:03 RyanMetcalfeInt8

Hi @grazder, thanks for your work. I found a possible issue in the code (or maybe I'm wrong). In torch_df_streaming.py, self.rolling_spec_buf_y_shape = (self.df_order + 2, self.freq_size, 2), but line 706 has current_spec = new_rolling_spec_buf_y[self.df_order - 1]; I think this should be self.df_order + 2 - 1. Would love to hear your opinion. Best wishes!

FisherDom avatar Apr 07 '24 04:04 FisherDom

Hello, we need to store two future frames for a single frame prediction, so self.df_order - 1 is correct

Also you can find it here:

https://github.com/Rikorose/DeepFilterNet/blob/f2445da10ce7760ac41d272ce4699200333a6e32/libDF/src/tract.rs#L586
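In other words, the buffer holds df_order + 2 spectra, the newest two of which are lookahead, so indexing at df_order - 1 selects the frame two steps behind the newest one. A toy model of the rolling buffer (frames reduced to plain integers for illustration):

```python
from collections import deque

DF_ORDER = 5
BUF_LEN = DF_ORDER + 2            # rolling_spec_buf_y holds df_order + 2 frames

buf = deque([None] * BUF_LEN, maxlen=BUF_LEN)
for t in range(20):               # feed frame indices 0..19
    buf.append(t)                 # newest frame goes to the end, oldest falls off

# After feeding frame t, buf[-1] == t (newest) and the "current" frame
# processed by the deep filter is buf[DF_ORDER - 1] == t - 2: two future
# frames (t-1 and t) are already in the buffer when frame t-2 is enhanced.
```

So df_order - 1 is not off by one; it is exactly what produces the two-frame lookahead the model was trained with.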

grazder avatar Apr 07 '24 14:04 grazder

Thanks for your answer! @grazder

Sorry to bother you again. I would like to get a tflite model that can perform real-time inference on a DSP. If I use the ONNX model produced by your model_onnx_export.py, will there be any problems converting it to a TF model? I noticed that you are using a newly registered operator.

Secondly, if I want to use C for inference on the DSP, would you suggest the original author's single complete model or the three separate models? I have relatively little deployment experience.

Wish everything goes well with your work!

FisherDom avatar Apr 10 '24 09:04 FisherDom

@FisherDom

will there be any problems when converting the onnx model to tf model

I think you can face problems with RFFT / IRFFT. I don't know much about tf operations, so I can't say exactly.

I noticed that you are using a newly registered operator

Yeah, the new operator is in the torchDF_main branch. It gave me roughly a ~2.5x speedup. There are also some other graph optimizations in that branch.

You can use the torchDF-changes branch (from this PR). In that variant RFFT / IRFFT is implemented as a matmul, which is suboptimal, but I bet you will not face export problems with RFFT / IRFFT.
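The RFFT-as-matmul trick mentioned here can be sketched as follows. This is a generic DFT-matrix construction, not the exact code from the branch: the forward transform becomes two matrix multiplies against constant cosine/sine matrices, which most ONNX/TF converters handle, unlike a dedicated (I)RFFT op.

```python
import numpy as np

def rfft_via_matmul(x: np.ndarray) -> np.ndarray:
    """Compute rfft(x) along the last axis with two matmuls.
    The cos/sin matrices would be baked into the graph as constants."""
    n = x.shape[-1]
    t = np.arange(n)[:, None]              # time index, shape (n, 1)
    k = np.arange(n // 2 + 1)[None, :]     # non-negative freq bins, shape (1, n//2+1)
    angle = -2.0 * np.pi * t * k / n       # rfft(x)[k] = sum_t x[t] * exp(-2*pi*i*t*k/n)
    return x @ np.cos(angle) + 1j * (x @ np.sin(angle))
```

The price is O(n^2) per frame instead of O(n log n), which is one reason the newer torchDF_main branch switches to a dedicated registered operator instead.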

I want to use C language for inference on DSP, would you suggest that I use the original author's single complete model or three models?

Well, you can use the C API; you can find a build in the GitHub Actions artifacts, or look at the actions config. I haven't benchmarked the original Rust speed against my implementation in a long time, but when I last checked, the original Rust implementation had the same speed as the torchDF-changes branch. torchDF_main is now much faster, so you can try it.

Also, if you want to use C, you can try the onnxruntime C API.

grazder avatar Apr 10 '24 10:04 grazder

@grazder Hello! Thank you very much for providing the pure torch code, which saved me a lot of time that would otherwise have been spent learning Rust. While running torch_df_offline.py (https://github.com/grazder/DeepFilterNet/blob/torchDF-changes/torchDF/torch_df_offline.py) I found a little bug: the durations do not match, i.e. the output audio file is shorter than the input. Inspecting the code, I found that the frame_synthesis function does not account for the second half of the last audio block. Below is my modification for your reference.

    def frame_synthesis(self, input_data, i_last_record, out_chunks):  # two extra arguments added
        """
        Original code - libDF/src/lib.rs - frame_synthesis()
        Inverse rfft for one frame. Every frame is summed with the buffer
        from the previous frame, and a buffer is saved for the next frame.

        Parameters:
            input_data: Complex[F] - Enhanced audio spectrogram

        Returns:
            output:     Float[f] - Enhanced audio
        """
        x = torch.fft.irfft(input_data, norm='forward') * self.window
        x_first, x_second = torch.split(x, [self.frame_size, x.shape[0] - self.frame_size])
        output = x_first + self.synthesis_mem

        self.synthesis_mem = x_second

        if i_last_record == out_chunks:  # flush the buffered tail on the last chunk
            output = output + x_second + x_second

        return output

viki347 avatar Apr 22 '24 09:04 viki347
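For context on viki347's point: frame_synthesis is one half of a standard overlap-add loop, where each inverse FFT frame is windowed and its second half is buffered to be added to the next frame's first half, so when the stream ends one buffered tail is left over. A minimal round trip illustrating the mechanism (generic sqrt-Hann window and 50% overlap chosen for the sketch; the real model's window and sizes live in the Rust/torch sources):

```python
import numpy as np

N, HOP = 960, 480                         # frame and hop size (50% overlap)
win = np.sin(np.pi * np.arange(N) / N)    # sqrt-Hann: win**2 overlap-adds to 1

def analysis_synthesis(x: np.ndarray) -> np.ndarray:
    """STFT -> ISTFT round trip with overlap-add, mimicking
    frame_analysis()/frame_synthesis() without any enhancement."""
    out = np.zeros(len(x))
    mem = np.zeros(HOP)                   # synthesis_mem: second half of prev frame
    for i in range(0, len(x) - N + 1, HOP):
        spec = np.fft.rfft(x[i:i + N] * win)
        frame = np.fft.irfft(spec) * win  # windowed inverse transform
        out[i:i + HOP] = frame[:HOP] + mem
        mem = frame[HOP:]                 # buffered for the next hop
    return out                            # note: the final `mem` is never emitted
```

Interior samples are reconstructed exactly because the shifted squared windows sum to one; only the edges (including the final buffered tail that viki347's patch flushes) are missing.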

@grazder sry to bother you! I want to quantize the model to 8 bits now, but I have only ever quantized CV models. Can you give me some hints or pointers? Should quantization be done on torch or on ONNX? Best wishes!

FisherDom avatar May 20 '24 03:05 FisherDom

@FisherDom

Hello! I've tried quantizing with ONNX here - https://github.com/grazder/DeepFilterNet/blob/torchDF-temp/torchDF/model_onnx_export.py. But it didn't gain me anything: individual ops became faster, but because of the many quantize / dequantize nodes the model didn't become much faster or smaller.

grazder avatar May 20 '24 08:05 grazder