DeepFilterNet
Pure torch implementation
I've created this issue about a pure torch reimplementation: https://github.com/Rikorose/DeepFilterNet/issues/430
Sharing the code here. This is a draft PR, so work is still in progress, and I may make some changes later.
You can find my implementation in the torchDF folder. There is also a README.md there with some details.
I'll be glad to hear your feedback.
There are also some changes in deepfilternet3.py, modules.py, and multiframe.py. They were necessary to reach 100% compatibility between the streaming tract model and the offline enhance method.
This fork was created based on https://github.com/Rikorose/DeepFilterNet/commit/ca46bf54afaf8ace3272aaee5931b4317bd6b5f4. Therefore, some code may be a little outdated.
The offline torch implementation is in torchDF/torch_df_offline.py.
The streaming torch implementation is in torchDF/torch_df_streaming.py.
To convert the streaming model to ONNX, you can use torchDF/model_onnx_export.py.
Thank you for your great work!
In the code [torch_df_offline.py], sample_rate is set to 48000, and self.erb_indices = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 5, 7, 7, 8, 10, 12, 13, 15, 18, 20, 24, 28, 31, 37, 42, 50, 56, 67]).
When sample_rate=16000, what should self.erb_indices be set to?
@dingchaoyue here is the original code; you can calculate the ERB filter bank using it:
https://github.com/Rikorose/DeepFilterNet/blob/59789e135cb5ed0eb86bb50e8f1be09f60859d5c/libDF/src/lib.rs#L62
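For reference, here is a rough Python sketch of that Rust logic: split [0, sr/2] into nb_bands equal steps on the ERB scale and count FFT bins per band, enforcing a minimum band width. It is not a verbatim port, so verify edge cases against the linked lib.rs:

import math

def freq2erb(freq_hz):
    return 9.265 * math.log(1.0 + freq_hz / (24.7 * 9.265))

def erb2freq(erb):
    return 24.7 * 9.265 * (math.exp(erb / 9.265) - 1.0)

def erb_widths(sr, fft_size, nb_bands, min_nb_freqs):
    freq_width = sr / fft_size                 # Hz per FFT bin
    erb_low, erb_high = freq2erb(0.0), freq2erb(sr / 2)
    step = (erb_high - erb_low) / nb_bands
    widths, prev_fb, over = [], 0, 0
    for i in range(1, nb_bands + 1):
        f = erb2freq(erb_low + i * step)       # upper band edge in Hz
        fb = int(f / freq_width + 0.5)         # round half up, like Rust's round()
        nb_freqs = fb - prev_fb - over
        if nb_freqs < min_nb_freqs:            # borrow bins, repaid by later bands
            over = min_nb_freqs - nb_freqs
            nb_freqs = min_nb_freqs
        else:
            over = 0
        widths.append(nb_freqs)
        prev_fb = fb
    widths[-1] += 1                            # account for the Nyquist bin (fft_size/2 + 1 bins total)
    return widths

# erb_widths(48000, 960, 32, 2) should reproduce the 48 kHz widths above; for
# 16 kHz pass your own fft_size, e.g. erb_widths(16000, 320, 32, 1) (these
# 16 kHz parameters are only an illustration).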
Thank you. Problem solved
Hi, thanks for this PR! I would be interested in merging it. A few things that I want to discuss:
- Only keep things related to the pytorch real-time path in here. Make separate PRs for other stuff, e.g. jemalloc for pyDF etc.
- If possible, I would like to not introduce another Python wheel. Do you see any option to include this version as a tool.poetry.script within the DeepFilterNet wheel?
- Can you maybe add a test here?
- I'm using the pyDF changes to run a test that compares the original Rust and torch streaming implementations. What do you think about the tests then - do I need to remove them, or do we keep them in this PR?
- Yeah, I see this implementation more as something for users' understanding. We shouldn't introduce it as another wheel. I can try to add it as a tool.poetry.script and figure out how to make it better. Can you describe the logic behind it and how you use tool.poetry.script right now?
- I can add a test here, yes. But we need to decide what to do with the tests (as I described in point 1).
Hey, if anyone can recreate the LADSPA plugin with ONNX, please do, as tract's single thread needs a pretty big single core.
A few notes:
1.1 Don't remove existing dependency features in pyDF (logging, transforms).
1.2 Don't add features in this PR that are unrelated (jemalloc).
1.3 Hide the tract dependency behind a feature flag. This should not be compiled by default:
[features]
tract = ["deep_filter/tract", "deep_filter/default-model"]
However, I am not sure if there is a way to add this as an optional dependency to the deepfilterlib dependency in the pyproject file, because by default the tract dependency is not necessary for a standard use case (e.g. pytorch inference or pytorch training). Can you add deepfilterlib without tract as it is currently, and only add deepfilterlib[tract] when DeepFilterNet is installed with an additional tract feature?
1.4 Is tract_core really necessary? It should be pulled in via libDF/tract.
2.1 You can just add an additional readme_py_rt.md or so in DeepFilterNet and create a script at DeepFilterNet/scripts. Then you can add this Python script to the tool.poetry.script section (see the sketch after these notes). You can also move the other scripts, e.g. model_onnx_export.py, to the scripts folder. But we need to find a solution for the tract dependency of deepfilterlib.
2.2 A different option would be to create a submodule, e.g. at DeepFilterNet/py_rt/. Maybe this is more convenient.
4. Remove the .wav example. Please use the existing samples at assets/noisy_snr0.wav
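For 2.1, a minimal sketch of what such a console script could look like. The file name, flags, and entry point are made up for illustration; a line such as deepfilter-py-rt = "df.scripts.enhance_rt:main" under [tool.poetry.scripts] would then point at it:

import argparse

# Hypothetical DeepFilterNet/df/scripts/enhance_rt.py - only a sketch of the
# entry-point shape, not actual project code.
def main():
    parser = argparse.ArgumentParser(description="Pure-torch streaming enhancement")
    parser.add_argument("noisy_wav", help="input wav file")
    parser.add_argument("--output", default="enhanced.wav", help="output wav path")
    args = parser.parse_args()
    # ... load the streaming torch model here and process frame by frame ...
    print(f"enhancing {args.noisy_wav} -> {args.output}")

if __name__ == "__main__":
    main()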
@Rikorose I saw your comment; right now I have more important tasks, but I'll come back in a while and fix it.
Hello author, when I build the model I get: TypeError: forward() takes 3 positional arguments but 4 were given. This is caused by the missing hidden parameter. How can I solve it? I don't know where the loaded enc module code is located; adding hidden vectors to the encoder module still doesn't solve the problem. The hidden state is not loaded during training.
Hello! Can you provide more details? What code are you running?
The Encoder in this reimplementation takes 4 positional parameters:
https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/DeepFilterNet/df/deepfilternet3.py#L168
https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/torch_df_streaming.py#L492
self, feat_erb: Tensor, feat_spec: Tensor, hidden: Tensor
Btw, this code is for inference only; I didn't do anything about training.
Also, I haven't committed to this branch for a long time, so nothing should break if you run the code with no changes.
What do you mean by "build"? Do the tests pass? You can see how to run the tests here.
Hello, this was because I had the df library installed in the environment, so the calls resolved to the library's functions instead of the ones you modified.
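For anyone hitting the same thing, a quick way to check which df package is actually being imported; if the printed path points into site-packages rather than this fork's checkout, the installed wheel is shadowing the fork's modules:

import df

# Expect a path inside your DeepFilterNet checkout, not site-packages
print(df.__file__)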
Hello, I generated the ONNX model with your model_onnx_export.py, and I read audio from the device and denoise it in real time through ONNX. The result is not good. Could it be that I'm not using the code correctly?
@hulucky1102 You can look here for an example of how to run inference with the streaming version correctly:
https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/test_torchdf.py#L70
Also, check that you export with the always_apply_all_stages=True parameter:
https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/model_onnx_export.py#L232
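For reference, the loop looks roughly like this with onnxruntime. The model path, the hop size, and the assumption that every input after the first is a zero-initialized state tensor with a static shape are illustrative only; the test linked above shows the exact, tested usage:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("denoiser_model.onnx")  # hypothetical path
inputs = session.get_inputs()
frame_size = 480  # hop size of this implementation at 48 kHz

audio = np.zeros(48000, dtype=np.float32)  # stand-in for a real recording
# assume the first input is the audio frame and the rest are recurrent states
states = [np.zeros(i.shape, dtype=np.float32) for i in inputs[1:]]

enhanced = []
for start in range(0, len(audio) - frame_size + 1, frame_size):
    feed = dict(zip([i.name for i in inputs],
                    [audio[start:start + frame_size], *states]))
    out_frame, *states = session.run(None, feed)
    enhanced.append(out_frame)
enhanced = np.concatenate(enhanced)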
Thank you very much. This problem was caused by me passing in atten_lim_db. Is there some setting in the model that causes muting? It makes some of my speech incoherent, so I would like to remove the muting behavior.
Both use DeepFilterNet3: one run uses the enhance API and the other uses ONNX. The two results are inconsistent; the ONNX result loses some of the vocals.
Is it happening only with ONNX inference, or with torch inference of the streaming version too?
It occurs with both ONNX and torch streaming.
Hi @grazder,
Thanks for your amazing work! I was hoping you could help clear up something for me.
As part of your change-set, you added a new 'hidden_states' input/output tensor to Encoder, ErbDecoder, and DfDecoder (in deepfilternet3.py). And I see that in your streaming implementation, these are used as part of the state management logic.
What I am confused about is that the main branch's streaming implementation (i.e. the Rust implementation) appears to work using the DeepFilterNet3_onnx.tar.gz models, yet I don't see these 'hidden states' tensors exposed as inputs/outputs of those models. So how does the Rust implementation work without the ability to manage them? It seems necessary based on your pure pytorch implementation.
It's very possible that I overlooked something simple.
Thanks again! Ryan
Hi @grazder, thanks for your work. I found a possible issue in the code (or maybe I'm wrong). In torch_df_streaming.py, self.rolling_spec_buf_y_shape = (self.df_order + 2, self.freq_size, 2), but at line 706: current_spec = new_rolling_spec_buf_y[self.df_order - 1]. I think this should be new_rolling_spec_buf_y[self.df_order + 2 - 1]. Would love to hear your opinion. Best wishes!
Hello, we need to store two future frames for a single frame prediction, so self.df_order - 1 is correct.
Also you can find it here:
https://github.com/Rikorose/DeepFilterNet/blob/f2445da10ce7760ac41d272ce4699200333a6e32/libDF/src/tract.rs#L586
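To illustrate the indexing, here is a toy sketch with schematic shapes (not the actual streaming code): the buffer holds df_order + 2 frames, new frames are appended at the end, so the last two entries are lookahead and index df_order - 1 is the frame currently being predicted.

import torch

df_order, freq_size = 5, 481                  # schematic values
buf = torch.zeros(df_order + 2, freq_size, 2)

def push(buf, frame):
    # drop the oldest frame, append the newest
    return torch.cat((buf[1:], frame.unsqueeze(0)))

buf = push(buf, torch.randn(freq_size, 2))
current_spec = buf[df_order - 1]              # two future frames remain at buf[df_order:]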
Thanks for your answer! @grazder
Sorry to bother you again. I would like to get a tflite model that can perform real-time inference on a DSP. If I use the ONNX model obtained from your model_onnx_export.py, will there be any problems converting it to a tf model? I noticed that you are using a newly registered operator.
Secondly, if I want to use C for inference on the DSP, would you suggest using the original author's single complete model or the three separate models? I have relatively little deployment experience.
Hope everything goes well with your work!
@FisherDom
will there be any problems when converting the onnx model to tf model
I think you can face problems with RFFT / IRFFT. I don't know a lot about tf operations, so I can't say exactly.
I noticed that you are using a newly registered operator
Yeah, the new operator is in the torchDF_main branch. It gave me roughly a 2.5x speedup. There are also some other graph optimizations in that branch.
You can use the torchDF-changes branch (from this PR). In that variant RFFT / IRFFT are implemented as matmuls, which is suboptimal, but I bet you will not face problems with RFFT / IRFFT export.
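As an aside, the matmul trick looks roughly like this (a toy sketch, not the torchDF code): an rfft is two matrix products with fixed cos/sin matrices, which exports as plain MatMul ops instead of FFT ops.

import torch

n = 960
k = torch.arange(n // 2 + 1, dtype=torch.float64).unsqueeze(1)  # output bins
t = torch.arange(n, dtype=torch.float64).unsqueeze(0)           # time indices
ang = 2 * torch.pi * k * t / n
cos_m, sin_m = torch.cos(ang), -torch.sin(ang)                  # real / imag DFT rows

x = torch.randn(n, dtype=torch.float64)
real, imag = cos_m @ x, sin_m @ x
ref = torch.fft.rfft(x)
assert torch.allclose(real, ref.real) and torch.allclose(imag, ref.imag)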
I want to use C language for inference on DSP, would you suggest that I use the original author's single complete model or three models?
Well, you can use the C API; you can find a build in the GitHub Actions artifacts or look at the Actions config. I haven't benchmarked the original Rust implementation against mine in a long time, but when I last checked, the original (Rust) implementation had the same speed as the torchDF-changes branch. torchDF_main is now much faster, so you can try it.
Also, if you want to use C, you can try the onnxruntime C API.
@grazder Hello! Thank you very much for your contribution of the pure torch code, which saved me a lot of time that would have been spent learning Rust.
When I ran your code, torch_df_offline.py (https://github.com/grazder/DeepFilterNet/blob/torchDF-changes/torchDF/torch_df_offline.py), I found a little bug: the durations do not match. The output audio file is shorter than the input. Upon inspecting your code, I found that the frame_synthesis function does not consider the second half of the last audio block.
Below is the result of my modification for your reference.
def frame_synthesis(self, input_data, i_last_record, out_chunks):  # added two parameters
    """
    Original code - libDF/src/lib.rs - frame_synthesis()

    Inverse rfft for one frame. Every frame is summed with the buffer from
    the previous frame, and a buffer is saved for the next frame.

    Parameters:
        input_data: Complex[F] - enhanced audio spectrogram
        i_last_record: int - index of the current output chunk
        out_chunks: int - index of the last output chunk
    Returns:
        output: Float[f] - enhanced audio
    """
    x = torch.fft.irfft(input_data, norm='forward') * self.window
    x_first, x_second = torch.split(x, [self.frame_size, x.shape[0] - self.frame_size])
    output = x_first + self.synthesis_mem
    self.synthesis_mem = x_second
    if i_last_record == out_chunks:  # last chunk: append the tail instead of dropping it
        output = torch.cat((output, x_second))
    return output
@grazder Sorry to bother you! I want to quantize the model to 8 bits now, but I have only quantized CV models before. Can you give me some hints or information? Should it be quantized in torch or in ONNX? Best wishes!
@FisherDom
Hello! I've tried quantizing with ONNX here - https://github.com/grazder/DeepFilterNet/blob/torchDF-temp/torchDF/model_onnx_export.py. But it didn't give me anything: it seems like individual ops became faster, but because of the many quantize / dequantize nodes the model didn't become much faster or smaller.
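For what it's worth, the dynamic-quantization path looks roughly like this (paths are placeholders; this is the standard onnxruntime API rather than code from the repo):

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="denoiser_model.onnx",        # placeholder: exported streaming model
    model_output="denoiser_model_int8.onnx",  # quantized result
    weight_type=QuantType.QInt8,              # 8-bit weights
)

In my experiments the resulting graph was dominated by quantize / dequantize nodes, so the end-to-end gain was small.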