resemble-enhance
How to export models to ONNX
Thank you for your work; the performance of this model is quite good. I would like to deploy and use it. Is there a way to export it to ONNX?
I've made an ONNX version of the denoiser here: https://github.com/skeskinen/resemble-denoise-onnx-inference It's deployed in https://smartmediacutter.com/ and works quite nicely.
This does not cover the enhancer, and I'm not sure how easy it would be to get the enhancer model into ONNX. The enhancer is too slow to run for my use case at the moment.
Thank you so much for sharing it! It's really helpful and I appreciate your work.
My workaround led me to the following fixes to make it exportable to ONNX.
Comment out the torch.Generator, since it is not supported by the ONNX exporter
resemble_enhance/enhancer/lcfm/cfm.py
def _sample_ψ0(self, x: Tensor):
"""
Args:
x: (b c t), which implies the shape of ψ0
"""
shape = list(x.shape)
shape[1] = self.output_dim
# if self.training:
# g = None
# else:
# g = torch.Generator(device=x.device)
# g.manual_seed(0) # deterministic sampling during eval
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype) # , generator=g)
return ψ0
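Note that this removes the deterministic eval-time sampling. If reproducibility still matters for your comparisons, a minimal workaround (cfm below is a hypothetical instance of the CFM module) is to seed the global RNG before the call:

import torch

torch.manual_seed(0)  # stand-in for the per-call Generator; touches global RNG state
ψ0 = cfm._sample_ψ0(x)  # now reproducible across runs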
Change the .expand call to .repeat.
resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
weight = self.filter.repeat([int(x.shape[1]), 1, 1])
x = self.ratio * F.conv_transpose1d(x, weight, stride=self.stride, groups=C) # self.filter.expand(C, -1, -1)
shape = x.shape
shape = [int(elem) for elem in shape]
x = torch.reshape(x, shape)
x = x[..., self.pad_left:-self.pad_right]
return x
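As a quick sanity check that the switch changes nothing numerically: .expand returns a broadcasted view, while .repeat materializes a copy with the same values, which the exporter traces through conv_transpose1d without trouble. A standalone sketch with made-up shapes:

import torch

filt = torch.randn(1, 1, 12)  # stand-in for self.filter
C = 8                         # stand-in for the channel count
assert torch.equal(filt.expand(C, -1, -1), filt.repeat(C, 1, 1))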
Write your own custom Tensor.unfold function
resemble_enhance/enhancer/univnet/lvcnet.py
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (
kernel_length * hop_size
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = custom_unfold_dim_2(x, hop_size + 2 * padding, hop_size)
# x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = custom_unfold_dim_3(x, dilation, dilation)
# x = x.unfold(3, dilation, dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = custom_unfold_dim_4(x, kernel_size, 1)
# x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def custom_unfold_dim_2(x: torch.Tensor, window_size: int, step: int):
dim = 2
subtensors = [x[:, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
result = torch.stack(subtensors, dim=dim)
return result
def custom_unfold_dim_3(x: torch.Tensor, window_size: int, step: int):
dim = 3
subtensors = [x[:, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
result = torch.stack(subtensors, dim=dim)
return result
def custom_unfold_dim_4(x: torch.Tensor, window_size: int, step: int):
dim = 4
subtensors = [x[:, :, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
result = torch.stack(subtensors, dim=dim)
return result
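Since these helpers are meant as drop-in replacements, it is worth asserting they match Tensor.unfold on random input before exporting. The code above always unfolds the last dimension, which is exactly what this standalone check exercises:

import torch

x2 = torch.randn(2, 3, 20)
assert torch.equal(custom_unfold_dim_2(x2, 5, 3), x2.unfold(2, 5, 3))

x3 = torch.randn(2, 3, 4, 20)
assert torch.equal(custom_unfold_dim_3(x3, 5, 3), x3.unfold(3, 5, 3))

x4 = torch.randn(2, 3, 4, 5, 20)
assert torch.equal(custom_unfold_dim_4(x4, 5, 3), x4.unfold(4, 5, 3))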
Export script
P.S. I had to rename the resemble-enhance source folder to src to avoid conflicts with the resemble-enhance package that was also installed in my venv.
import logging
import time
import os
import click
import torch
import torchaudio
from torch.nn.functional import pad
from torchaudio.functional import resample
from src.enhancer.inference import load_enhancer
from src.hparams import HParams
from src.inference import merge_chunks, remove_weight_norm_recursively
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@click.command()
@click.option('--wav-path', type=click.Path(exists=True), help='Path to input wav file')
@click.option('--save-path', type=str, default='output.wav', help='Path to save output wav file')
@click.option('--run-dir', type=str, default=None, help='Path to run directory')
@click.option('--device', type=str, default='cuda', help='Device to use for computation')
@click.option('--nfe', type=int, default=32, help='Number of function evaluations')
@click.option('--solver', type=str, default='midpoint', help='Numerical solver to use')
@click.option('--lambd', type=float, default=0.5, help='Denoise strength')
@click.option('--tau', type=float, default=0.5, help='CFM prior temperature')
@click.option('--chunk-seconds', type=float, default=30.0, help='Length of each chunk in seconds')
@click.option('--overlap-seconds', type=float, default=1.0, help='Overlap between chunks in seconds')
@click.option('--export-onnx', type=bool, default=False, help='Export the enhancer model to ONNX')
@click.option('--onnx-path', type=click.Path(exists=True), default="onnx", help='The path where ONNX files will be saved')
def main(
wav_path: str,
save_path: str = "output.wav",
run_dir: str | None = None,
device: str = "cuda",
nfe: int = 32,
solver: str = "midpoint",
lambd: float = 0.5,
tau: float = 0.5,
chunk_seconds: float = 30.0,
overlap_seconds: float = 1.0,
export_onnx: bool = False,
onnx_path: str = "onnx",
):
if torch.cuda.is_available():
torch.cuda.synchronize()
if device == "cuda":
torch.cuda.empty_cache()
else:
device = "cpu"
enhancer = load_enhancer(run_dir, device)
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
enhancer.eval()
enhancer.lcfm.eval()
remove_weight_norm_recursively(enhancer)
hp: HParams = enhancer.hp
enhancer.to(device)
dwav, sr = torchaudio.load(wav_path)
dwav = dwav.mean(dim=0)
dwav = resample(
dwav,
orig_freq=sr,
new_freq=hp.wav_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
result_audio_length = dwav.shape[-1]
start_time = time.perf_counter()
    chunk_length = int(hp.wav_rate * chunk_seconds)  # dwav was resampled to hp.wav_rate above
    overlap_length = int(hp.wav_rate * overlap_seconds)
    hop_length = chunk_length - overlap_length
chunks = [dwav[i:i + chunk_length] for i in range(0, dwav.shape[-1], hop_length)]
input_chunks = torch.stack([pad(chunk, (0, chunk_length - len(chunk))) for chunk in chunks], dim=0)
abs_max = input_chunks.abs().max(dim=1, keepdim=True).values
    abs_max[abs_max == 0] = 1e-7  # guard against division by zero
input_chunks = input_chunks / abs_max
input_chunks = input_chunks.to(device)
    with torch.inference_mode():  # inference_mode already implies no_grad
output = enhancer(input_chunks).to("cpu")
output = output * abs_max
audio = merge_chunks(output, chunk_length, hop_length, sr=hp.wav_rate)
elapsed_time = time.perf_counter() - start_time
logger.info(f"Elapsed time: {elapsed_time:.3f} s, {audio.shape[-1] / elapsed_time / 1000:.3f} kHz")
torchaudio.save(save_path, audio[None, :result_audio_length], hp.wav_rate)
if export_onnx:
logger.info("Exporting enhancer model to ONNX")
with torch.no_grad():
mel_spectrogram = enhancer.to_mel(input_chunks)
normalizer_result = enhancer.normalizer(mel_spectrogram)
lcfm_result = enhancer.lcfm(normalizer_result)
logger.info("Exporting normalizer model to ONNX")
torch.onnx.export(
enhancer.normalizer,
mel_spectrogram,
os.path.join(onnx_path, "normalizer.onnx"),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
'output' : {0 : 'batch_size', 2 : 'mel_length'}}
)
logger.info("Exporting lcfm model to ONNX")
torch.onnx.export(
enhancer.lcfm,
normalizer_result,
os.path.join(onnx_path, "lcfm.onnx"),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
'output' : {0 : 'batch_size', 2 : 'mel_length'}}
)
logger.info("Exporting vocder model to ONNX")
torch.onnx.export(
enhancer.vocoder,
lcfm_result,
os.path.join(onnx_path, "vocder.onnx"),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
'output' : {0 : 'batch_size', 1 : 'audio_length'}}
)
if __name__ == "__main__":
main()
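For reference, a hypothetical invocation of the script above (the filename is a placeholder; the flags match the click options): python export.py --wav-path input.wav --device cpu --export-onnx true --onnx-path onnx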
@alexey-laletin-singularitynet have you tried running the exported ONNX models with onnxruntime? Here it fails when trying to run lcfm.onnx, with the following error:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Conv node. Name:'/cfm/net/layers.1/dconv/Conv' Status Message: Dilation not supported for AutoPadType::SAME_UPPER or AutoPadType::SAME_LOWER
Which corresponds to this line:
https://github.com/resemble-ai/resemble-enhance/blob/main/resemble_enhance/enhancer/lcfm/wn.py#L33
self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
To do a quick check, just add this at the end of your ONNX export code:
import onnxruntime as ort
print("checking normalizer...")
ort_normalizer = ort.InferenceSession(os.path.join(onnx_path, "normalizer.onnx"))
result = torch.tensor(ort_normalizer.run(None, {'input': mel_spectrogram.numpy()})).squeeze(0)
print("checking lcfm...")
ort_lcfm = ort.InferenceSession(os.path.join(onnx_path, "lcfm.onnx"))
result = torch.tensor(ort_lcfm.run(None, {'input': normalizer_result.numpy()})).squeeze(0)
print("checking vocoder...")
ort_vocoder = ort.InferenceSession(os.path.join(onnx_path, "vocoder.onnx"))
result = torch.tensor(ort_vocoder.run(None, {'input': lcfm_result.numpy()})).squeeze(0)
print("checking done")
This is using torch 2.6.0, onnx 1.17.0, onnxruntime 1.21.0 with Python 3.12 on macOS.
Ok, the trick was to replace instances of padding="same" with padding=int(dilation * (kernel_size - 1) // 2).
https://github.com/microsoft/onnxruntime/issues/20582#issuecomment-2099348543
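For anyone applying the same fix, a sketch of what the wn.py line becomes (the widths and hyperparameters below are placeholders; the real values come from the model config). For stride 1 and an odd kernel_size this explicit padding gives the same output length as padding="same", but it exports as a constant pads attribute instead of the auto_pad=SAME_UPPER that onnxruntime's Conv rejects when dilation > 1:

import torch.nn as nn

hidden_dim, local_output_dim = 512, 512  # placeholder widths
kernel_size, dilation = 3, 2             # placeholder hyperparameters

# same output length as padding="same" for stride 1 and odd kernel_size
dconv = nn.Conv1d(
    hidden_dim,
    local_output_dim,
    kernel_size,
    dilation=dilation,
    padding=int(dilation * (kernel_size - 1) // 2),
)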