
How to export models to ONNX

Open lanyuer opened this issue 1 year ago • 5 comments

Thank you for your work, the performance of this model is quite good. I would like to deploy and use it. Is there a way to export it to ONNX?

lanyuer avatar Jul 24 '24 05:07 lanyuer

I've made an ONNX version of the denoiser here: https://github.com/skeskinen/resemble-denoise-onnx-inference It's deployed in https://smartmediacutter.com/ and works quite nicely.

This does not do the enhancer, and I'm not sure how easy it would be to get the enhancer model into ONNX. The enhancer is too slow to run for my use case at the moment.

skeskinen avatar Jul 26 '24 10:07 skeskinen

Thank you so much for sharing it! It's really helpful and I appreciate your work.

lanyuer avatar Jul 26 '24 17:07 lanyuer

My workaround led me to these fixes to make it exportable to ONNX:

Comment out the torch.Generator, because it's not exportable to ONNX.

resemble_enhance/enhancer/lcfm/cfm.py

def _sample_ψ0(self, x: Tensor):
    """
    Args:
        x: (b c t), which implies the shape of ψ0
    """
    shape = list(x.shape)
    shape[1] = self.output_dim
    # if self.training:
    #     g = None
    # else:
    #     g = torch.Generator(device=x.device)
    #     g.manual_seed(0)  # deterministic sampling during eval
    ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype) # , generator=g)
    return ψ0
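
Note that with the Generator removed, torch.randn is traced into the exported graph as plain random sampling, so the deterministic eval behaviour is gone and outputs will differ slightly between runs of the ONNX model. That is usually fine for inference, but worth keeping in mind when comparing results.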

Change the .expand call to .repeat.

resemble_enhance/enhancer/univnet/alias_free_torch/resample.py

def forward(self, x):
    _, C, _ = x.shape
    x = F.pad(x, (self.pad, self.pad), mode='replicate')
    weight = self.filter.repeat([int(x.shape[1]), 1, 1])
    x = self.ratio * F.conv_transpose1d(x, weight, stride=self.stride, groups=C) # self.filter.expand(C, -1, -1)
    shape = x.shape
    shape = [int(elem) for elem in shape]
    x = torch.reshape(x, shape)
    x = x[..., self.pad_left:-self.pad_right]
 
    return x
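
As a quick sanity check that the swap doesn't change the math: .expand and .repeat give the same values here, .repeat just materializes the copy instead of creating a zero-stride view. A minimal sketch, assuming the filter buffer has shape (1, 1, kernel_size) as in alias_free_torch (the concrete sizes are only for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 64)      # (batch, channels, time)
filt = torch.randn(1, 1, 12)   # stand-in for self.filter
C = x.shape[1]

y_expand = F.conv_transpose1d(x, filt.expand(C, -1, -1), stride=2, groups=C)
y_repeat = F.conv_transpose1d(x, filt.repeat(C, 1, 1), stride=2, groups=C)
assert torch.allclose(y_expand, y_repeat)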

Write your own replacement for the Tensor.unfold function.

resemble_enhance/enhancer/univnet/lvcnet.py

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
 
        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
 
        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = custom_unfold_dim_2(x, hop_size + 2 * padding, hop_size)
        # x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
 
        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = custom_unfold_dim_3(x, dilation, dilation)
        # x = x.unfold(3, dilation, dilation)  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = custom_unfold_dim_4(x, kernel_size, 1)
        # x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
 
        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)
 
        return o
 
def custom_unfold_dim_2(x: torch.Tensor, window_size: int, step: int):
    dim = 2
    subtensors = [x[:, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result
 
def custom_unfold_dim_3(x: torch.Tensor, window_size: int, step: int):
    dim = 3
    subtensors = [x[:, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result
 
def custom_unfold_dim_4(x: torch.Tensor, window_size: int, step: int):
    dim = 4
    subtensors = [x[:, :, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result

Export script

P.S. I had to rename the resemble_enhance source folder to src to avoid conflicts with the resemble-enhance package that was also installed in my venv.

import logging
import time
import os

import click
import torch
import torchaudio
from torch.nn.functional import pad
from torchaudio.functional import resample

from src.enhancer.inference import load_enhancer
from src.hparams import HParams
from src.inference import merge_chunks, remove_weight_norm_recursively

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@click.command()
@click.option('--wav-path', type=click.Path(exists=True), help='Path to input wav file')
@click.option('--save-path', type=str, default='output.wav', help='Path to save output wav file')
@click.option('--run-dir', type=str, default=None, help='Path to run directory')
@click.option('--device', type=str, default='cuda', help='Device to use for computation')
@click.option('--nfe', type=int, default=32, help='Number of function evaluations')
@click.option('--solver', type=str, default='midpoint', help='Numerical solver to use')
@click.option('--lambd', type=float, default=0.5, help='Denoise strength')
@click.option('--tau', type=float, default=0.5, help='CFM prior temperature')
@click.option('--chunk-seconds', type=float, default=30.0, help='Length of each chunk in seconds')
@click.option('--overlap-seconds', type=float, default=1.0, help='Overlap between chunks in seconds')
@click.option('--export-onnx', type=bool, default=False, help='Do you need to export enhancer model to ONNX?')
@click.option('--onnx-path', type=click.Path(exists=True), default="onnx", help='The path where ONNX files will be saved')
def main(
    wav_path: str,
    save_path: str = "output.wav",
    run_dir: str | None = None,
    device: str = "cuda",
    nfe: int = 32,
    solver: str = "midpoint",
    lambd: float = 0.5,
    tau: float = 0.5,
    chunk_seconds: float = 30.0, 
    overlap_seconds: float = 1.0,
    export_onnx: bool = False,
    onnx_path: str = "onnx",
):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        if device == "cuda":
            torch.cuda.empty_cache()
    else:
        device = "cpu"
    
    enhancer = load_enhancer(run_dir, device)
    enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    
    enhancer.eval()
    enhancer.lcfm.eval()
    remove_weight_norm_recursively(enhancer)
    hp: HParams = enhancer.hp
    enhancer.to(device)
    
    dwav, sr = torchaudio.load(wav_path)
    dwav = dwav.mean(dim=0)
    
    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=hp.wav_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )
    
    result_audio_length = dwav.shape[-1]

    start_time = time.perf_counter()

    chunk_length = int(hp.wav_rate * chunk_seconds)  # dwav was resampled to hp.wav_rate above
    overlap_length = int(hp.wav_rate * overlap_seconds)
    hop_length = chunk_length - overlap_length
    
    chunks = [dwav[i:i + chunk_length] for i in range(0, dwav.shape[-1], hop_length)]
    input_chunks = torch.stack([pad(chunk, (0, chunk_length - len(chunk))) for chunk in chunks], dim=0)
    
    abs_max = input_chunks.abs().max(dim=1, keepdim=True).values
    abs_max[abs_max == 0] = 10e-7
    input_chunks = input_chunks / abs_max
    input_chunks = input_chunks.to(device)
    
    with torch.inference_mode(), torch.no_grad():
        output = enhancer(input_chunks).to("cpu")
        output = output * abs_max
    
    audio = merge_chunks(output, chunk_length, hop_length, sr=hp.wav_rate)
    
    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Elapsed time: {elapsed_time:.3f} s, {audio.shape[-1] / elapsed_time / 1000:.3f} kHz")
    
    torchaudio.save(save_path, audio[None, :result_audio_length], hp.wav_rate)
    
    if export_onnx:
        logger.info("Exporting enhancer model to ONNX")
        with torch.no_grad():
            mel_spectrogram = enhancer.to_mel(input_chunks)
            normalizer_result = enhancer.normalizer(mel_spectrogram)
            lcfm_result = enhancer.lcfm(normalizer_result)
            
            logger.info("Exporting normalizer model to ONNX")
            torch.onnx.export(
                enhancer.normalizer,
                mel_spectrogram,
                os.path.join(onnx_path, "normalizer.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 2 : 'mel_length'}}
            )
            
            logger.info("Exporting lcfm model to ONNX")
            torch.onnx.export(
                enhancer.lcfm,
                normalizer_result,
                os.path.join(onnx_path, "lcfm.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 2 : 'mel_length'}}
            )
            
            logger.info("Exporting vocder model to ONNX")
            torch.onnx.export(
                enhancer.vocoder,
                lcfm_result,
                os.path.join(onnx_path, "vocder.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 1 : 'audio_length'}}
            )
    

if __name__ == "__main__":
    main()
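
Assuming you save the script above as export_onnx.py (a name picked here for illustration), it can be run as e.g. python export_onnx.py --wav-path input.wav --export-onnx True --onnx-path onnx. Note that the directory passed to --onnx-path has to exist beforehand, since the option is declared with exists=True.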

@alexey-laletin-singularitynet have you tried running the exported onnx models with onnxruntime? Here it fails when trying to run lcfm.onnx, with the following error:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Conv node. Name:'/cfm/net/layers.1/dconv/Conv' Status Message: Dilation not supported for AutoPadType::SAME_UPPER or AutoPadType::SAME_LOWER

That corresponds to this line: https://github.com/resemble-ai/resemble-enhance/blob/main/resemble_enhance/enhancer/lcfm/wn.py#L33: self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")

To do a quick check, just add this at the end of your onnx exporting code:

            import onnxruntime as ort
            print("checking normalizer...")
            ort_normalizer = ort.InferenceSession(os.path.join(onnx_path, "normalizer.onnx"))
            result = torch.tensor(ort_normalizer.run(None, {'input': mel_spectrogram.numpy()})).squeeze(0)
            print("checking lcfm...")
            ort_lcfm = ort.InferenceSession(os.path.join(onnx_path, "lcfm.onnx"))
            result = torch.tensor(ort_lcfm.run(None, {'input': normalizer_result.numpy()})).squeeze(0)
            print("checking vocoder...")
            ort_vocoder = ort.InferenceSession(os.path.join(onnx_path, "vocoder.onnx"))
            result = torch.tensor(ort_vocoder.run(None, {'input': lcfm_result.numpy()})).squeeze(0)
            print("checking done")

This is using torch 2.6.0, onnx 1.17.0, and onnxruntime 1.21.0 with Python 3.12 on macOS.

divideconcept avatar Mar 09 '25 08:03 divideconcept

Ok, the trick was to replace instances of padding="same" with padding=int(dilation * (kernel_size - 1) // 2). https://github.com/microsoft/onnxruntime/issues/20582#issuecomment-2099348543
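
For reference, a minimal sketch (with illustrative sizes, not the repo's actual ones) showing that the explicit padding gives the same output length as padding="same" for odd kernel sizes:

import torch
import torch.nn as nn

hidden_dim, local_output_dim, kernel_size, dilation = 64, 64, 3, 2

# what the repo uses (fails in onnxruntime when dilation > 1)
conv_same = nn.Conv1d(hidden_dim, local_output_dim, kernel_size,
                      dilation=dilation, padding="same")
# ONNX-friendly replacement with explicit symmetric padding
conv_explicit = nn.Conv1d(hidden_dim, local_output_dim, kernel_size,
                          dilation=dilation,
                          padding=int(dilation * (kernel_size - 1) // 2))

x = torch.randn(1, hidden_dim, 100)
assert conv_same(x).shape == conv_explicit(x).shape == (1, local_output_dim, 100)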

divideconcept avatar Mar 09 '25 12:03 divideconcept