
Input Feature to TRUNET

Open yugeshav opened this issue 2 years ago • 71 comments

Hi

As per the paper, 4 features must be concatenated as input to TRUNET,

  1. log spectrum
  2. PCEN
  3. real part of demodulated phase
  4. imaginary part of demodulated phase

so the input would become (batch size, 4 features, no. of STFT frames, no. of STFT bins), i.e. a 4-dimensional tensor.

But in the sample code the input is 3-dimensional, (1, 4, 257), since the first layer is Conv1d.

I'm confused: is the input to TRUNet 3-dimensional or 4-dimensional?

Regards Yugesh

yugeshav avatar Jun 07 '22 04:06 yugeshav

I think the input tensor in the sample code is a single-frame feature. If you want to feed a wav into the model, the input dimension might be (B, 4, frames, 257), but I'm not sure. Please email me ([email protected]) if you have any insight.

AmosCch avatar Jun 15 '22 02:06 AmosCch

@yugeshav Hey! Any progress on this? I am also confused about the input shape.

amirpashamobinitehrani avatar Jan 24 '23 09:01 amirpashamobinitehrani

@amirpashamobinitehrani The input shape for the 1D conv is (T, C, F): time frames, channels (the 4 features), frequency bins.

atabakp avatar Feb 08 '23 18:02 atabakp

Thanks for your reply. Interesting! Yes, I had some presumptions. What still remains a mystery to me is how to bring the batch dimension into play:

(Batch, Time frames, Channels (4 features), Frequency bins)

I assume we should refrain from that, right? We are simply processing 4 different features of 1 audio file in time-frame steps, so the time-frame dimension fulfils the batch dimension's role.

amirpashamobinitehrani avatar Feb 09 '23 10:02 amirpashamobinitehrani

Correct! Each frame is a data sample here. If you want to use (Batch, Time, Features, Frequency), you should use 2D convolution and set the filters' dimension to (n, 1).
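For anyone else puzzling over the shapes, here is a minimal sketch of the two layouts discussed above (random tensors stand in for the actual features; shapes assume a 512-point FFT with 257 bins):

    import torch

    # Frame-wise layout for the 1D-conv model in this repo: the time frames act
    # as the batch dimension, so one utterance becomes a stack of T samples.
    T, C, n_freq = 100, 4, 257            # frames, features, frequency bins
    x_1d = torch.randn(T, C, n_freq)      # feed directly to the Conv1d-based network

    # Alternative batched layout mentioned above: keep an explicit batch axis and
    # use 2D convolutions with kernel size 1 along the time axis, so each frame
    # is still processed independently.
    B = 8
    x_2d = torch.randn(B, C, T, n_freq)
    conv2d = torch.nn.Conv2d(C, 64, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2))
    y = conv2d(x_2d)                      # (B, 64, T, 129): convolves along frequency only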

atabakp avatar Feb 09 '23 18:02 atabakp

Hi,

I had the same question. Has anyone been able to successfully train this network? I think that, as @atabakp mentioned, the input has to have shape (time_frames, features, fft_size // 2 + 1), so when a batch is used, the time_frames axis will grow. Since this is assumed to be the N input of a nn.Conv1d, the processing will still be frame-independent, so bigger batch sizes would just mean a bigger stack of frames. Could someone confirm this?

Thanks, Esteban

eagomez2 avatar May 05 '23 07:05 eagomez2


Hi Esteban, I am able to train this model. Yes, you are right.

atabakp avatar May 05 '23 18:05 atabakp

Thanks @atabakp !

As a follow-up question: How are you obtaining the "demodulated phase"?

eagomez2 avatar May 06 '23 23:05 eagomez2

There are a few methods to do this, but I don't know exactly what the authors mean. For example: https://arxiv.org/pdf/1608.01953.pdf

But for my training, I used log magnitude and normalized real/imag as inputs.
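In case it helps, one possible reading of "normalized real/imag" is the real and imaginary parts divided by the magnitude (i.e. the cos and sin of the phase). That is only my assumption, not a confirmed recipe; a sketch:

    import torch

    # Hypothetical feature stack: log magnitude plus magnitude-normalized
    # real/imag parts (equivalently cos/sin of the phase), using a dummy STFT.
    stft = torch.randn(100, 257, dtype=torch.cfloat)   # (frames, bins)
    mag = stft.abs()
    log_mag = torch.log(mag + 1e-8)
    norm_real = stft.real / (mag + 1e-8)               # = cos(phase)
    norm_imag = stft.imag / (mag + 1e-8)               # = sin(phase)
    features = torch.stack([log_mag, norm_real, norm_imag], dim=1)  # (frames, 3, 257)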

atabakp avatar May 08 '23 19:05 atabakp

I managed to implement the demodulated phase, using (log_magnitude, demod_real, demod_imag) as inputs to train the model. For some reason, I am not seeing the model do anything useful. It would be nice to get some insights regarding the implementation if anyone has made promising progress on this!

amirpashamobinitehrani avatar May 08 '23 20:05 amirpashamobinitehrani

Thanks once again @atabakp! I was thinking something similar:

  1. Use log magnitude (as in the paper)
  2. Use PCEN output (as in the paper)

For 3 and 4, "real/imaginary of the demodulated phase" didn't make much sense to me as a term at first, since the phase is real-valued, so I was thinking of using the normalized real/imag STFT instead, since it would still put some emphasis on the phase information.

One last question: how are you using the outputs, @atabakp? I think there are 5 channels initially, but there is no explicit mention of what they exactly are. I was assuming two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.

eagomez2 avatar May 09 '23 07:05 eagomez2

Yes, you are right. The output is 10 channels: 2 sets of 5. One set is for predicting the direct signal, the 2nd set is for predicting the noise, and you can derive the reverberation once you have direct and noise. The 5 channels are:

  1. z^(k)_{t,f}
  2. z^(¬k)_{t,f}
  3. φ
  4. γ^(0)(q_{t,f})
  5. γ^(1)(q_{t,f})

For the last two channels, refer to eq. (3) of https://arxiv.org/pdf/2006.00687.pdf. If my assumption about the channels is correct, then we don't need 2 separate channels for the magnitude (channels 1 and 2); one is the complement of the other.
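If that interpretation is right, splitting the network output would look roughly like this (the names follow the numbering above and are only my reading of the paper, not confirmed):

    import torch

    # Assumed 10-channel output: first 5 channels for the direct source,
    # last 5 for the noise estimate.
    out = torch.randn(100, 10, 257)                    # (frames, channels, bins)
    direct, noise = out[:, :5, :], out[:, 5:, :]

    z_k, z_not_k, phi, q0, q1 = direct.unbind(dim=1)   # each (frames, bins)
    # ...same unpacking for the noise set, then build one PHM per source.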

atabakp avatar May 09 '23 17:05 atabakp

Thanks a lot once again, @atabakp! I'll report back on my progress as soon as I manage to allocate time to work on it.

eagomez2 avatar May 11 '23 12:05 eagomez2

Section 3 of this paper also has some information about phase demodulation: https://www.isca-speech.org/archive_v0/Interspeech_2018/pdfs/1773.pdf
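For what it's worth, one common formulation of phase demodulation removes the expected linear phase advance of each bin between hops. Whether that matches what the TRUNet authors (or the papers above) mean is only my guess, so treat this as an assumption:

    import torch

    # Sketch of phase demodulation under one common formulation: for bin k and
    # frame t, subtract the deterministic carrier phase 2*pi*k*hop*t/n_fft that
    # accumulates purely from the analysis hop.
    n_fft, hop = 512, 128
    stft = torch.randn(100, 257, dtype=torch.cfloat)        # (frames, bins), dummy STFT
    frames = torch.arange(stft.shape[0]).unsqueeze(1)       # (frames, 1)
    bins = torch.arange(stft.shape[1]).unsqueeze(0)         # (1, bins)

    carrier = 2.0 * torch.pi * bins * hop * frames / n_fft  # expected phase advance
    demod_phase = torch.angle(stft) - carrier
    demod_real, demod_imag = torch.cos(demod_phase), torch.sin(demod_phase)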

atabakp avatar May 16 '23 06:05 atabakp

Hi again @atabakp ,

How are you getting the 10 channels? I looked again at the model's code and I'm getting only 5. Here are the I/O shapes of each layer:

module type input_shape output_shape
root TRUNet (1, 4, 257) (1, 5, 257)
down1 StandardConv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d Sequential (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.0 Conv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.1 ReLU (1, 64, 128) (1, 64, 128)
down2 DepthwiseSeparableConv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d Sequential (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.0 Conv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.5 ReLU (1, 128, 128) (1, 128, 128)
down3 DepthwiseSeparableConv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d Sequential (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down3.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down4 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down5 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 32) (1, 128, 32)
down5.DepthwiseSeparableConv1d.5 ReLU (1, 128, 32) (1, 128, 32)
down6 DepthwiseSeparableConv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d Sequential (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.2 ReLU (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 16) (1, 128, 16)
down6.DepthwiseSeparableConv1d.5 ReLU (1, 128, 16) (1, 128, 16)
FGRU GRUBlock (1, 16, 128) (1, 64, 16)
FGRU.GRU GRU (1, 16, 128) ((1, 16, 128), (2, 1, 64))
FGRU.conv Sequential (1, 128, 16) (1, 64, 16)
FGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
FGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
FGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
TGRU GRUBlock (1, 16, 64) (1, 64, 16)
TGRU.GRU GRU (1, 16, 64) ((1, 16, 128), (1, 1, 128))
TGRU.conv Sequential (1, 128, 16) (1, 64, 16)
TGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
TGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
TGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
up1 FirstTrCNN (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN Sequential (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.0 Conv1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.2 ReLU (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.3 ConvTranspose1d (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.4 BatchNorm1d (1, 64, 31) (1, 64, 31)
up1.FirstTrCNN.5 ReLU (1, 64, 31) (1, 64, 31)
up2 TrCNN (1, 64, 31) (1, 64, 65)
up2.TrCNN Sequential (1, 192, 32) (1, 64, 65)
up2.TrCNN.0 Conv1d (1, 192, 32) (1, 64, 32)
up2.TrCNN.1 BatchNorm1d (1, 64, 32) (1, 64, 32)
up2.TrCNN.2 ReLU (1, 64, 32) (1, 64, 32)
up2.TrCNN.3 ConvTranspose1d (1, 64, 32) (1, 64, 65)
up2.TrCNN.4 BatchNorm1d (1, 64, 65) (1, 64, 65)
up2.TrCNN.5 ReLU (1, 64, 65) (1, 64, 65)
up3 TrCNN (1, 64, 65) (1, 64, 66)
up3.TrCNN Sequential (1, 192, 64) (1, 64, 66)
up3.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up3.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up3.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up3.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 66)
up3.TrCNN.4 BatchNorm1d (1, 64, 66) (1, 64, 66)
up3.TrCNN.5 ReLU (1, 64, 66) (1, 64, 66)
up4 TrCNN (1, 64, 66) (1, 64, 129)
up4.TrCNN Sequential (1, 192, 64) (1, 64, 129)
up4.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up4.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up4.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up4.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 129)
up4.TrCNN.4 BatchNorm1d (1, 64, 129) (1, 64, 129)
up4.TrCNN.5 ReLU (1, 64, 129) (1, 64, 129)
up5 TrCNN (1, 64, 129) (1, 64, 130)
up5.TrCNN Sequential (1, 192, 128) (1, 64, 130)
up5.TrCNN.0 Conv1d (1, 192, 128) (1, 64, 128)
up5.TrCNN.1 BatchNorm1d (1, 64, 128) (1, 64, 128)
up5.TrCNN.2 ReLU (1, 64, 128) (1, 64, 128)
up5.TrCNN.3 ConvTranspose1d (1, 64, 128) (1, 64, 130)
up5.TrCNN.4 BatchNorm1d (1, 64, 130) (1, 64, 130)
up5.TrCNN.5 ReLU (1, 64, 130) (1, 64, 130)
up6 LastTrCNN (1, 64, 130) (1, 5, 257)
up6.LastTrCNN Sequential (1, 128, 128) (1, 5, 257)
up6.LastTrCNN.0 Conv1d (1, 128, 128) (1, 5, 128)
up6.LastTrCNN.1 BatchNorm1d (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.2 ReLU (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.3 ConvTranspose1d (1, 5, 128) (1, 5, 257)

eagomez2 avatar May 18 '23 12:05 eagomez2

I also have a question about the TGRU along the same lines. According to the paper:

The decoder is composed of a Time-axis Gated Recurrent Unit (TGRU) block and 1D Transposed Convolutional Neural Network (1D-TrCNN) blocks. The output of the encoder is passed into a unidirectional GRU layer to aggregate the information along the time axis

But then, the input to this layer is (1, 16, 64), and according to PyTorch's GRU documentation, when batch_first=True the 2nd dimension is the sequence length, which is the case here because batch_first defaults to True and is not changed when the TGRU layer is defined: https://github.com/YangangCao/TRUNet/blob/main/TRUNet.py#LL131C26-L131C26

To my understanding (please correct me if I'm wrong), the TGRU layer will not really aggregate information along the time axis, but will instead play a similar role to the FGRU, only with a unidirectional layer. At first I assumed that batch_first should be set to False in order to apply the nn.GRU along the first dimension, which is the original time dimension.
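To make the concern concrete, here is a sketch (my own illustration, not the repository's code) of what aggregating along the actual time axis would look like, with frequency bands as the batch and frames as the sequence:

    import torch
    import torch.nn as nn

    # Encoder output for T frames, shape (T, channels, freq) = (T, 64, 16).
    # To aggregate along time, the frames must become the GRU's sequence
    # dimension, e.g. one sequence per frequency band.
    T, C, n_freq = 100, 64, 16
    enc_out = torch.randn(T, C, n_freq)

    tgru = nn.GRU(input_size=C, hidden_size=128, batch_first=True)
    x = enc_out.permute(2, 0, 1)     # (n_freq, T, C): bands as batch, frames as sequence
    y, _ = tgru(x)                   # (n_freq, T, 128): states aggregated over time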

eagomez2 avatar May 18 '23 12:05 eagomez2


I'll answer my own question about the 10 channels. The paper's config listing for the decoder says:

DecoderConfig = {1-th: (3,2,64), 2-th: (5,2,64), 3-th: (3,1,64), 4-th: (5,2,64), 5-th: (3,1,64), 6-th: (5,2,10)}

where the last number is the number of output channels, so you're right: there should be 10 channels instead of 5.

eagomez2 avatar May 18 '23 16:05 eagomez2


Regarding the TGRU / batch_first question, see: https://github.com/YangangCao/TRUNet/issues/4#issuecomment-1182544756

atabakp avatar May 18 '23 18:05 atabakp

Hi @atabakp ,

Not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains it may become unstable after some epochs. I believe the cos_phase is causing this, because due to the cosine law I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the corresponding sin_phase? I believe I'm missing something somewhere. I already tried clamping the values that could potentially explode, with no luck.

    import torch
    import torch.nn.functional as F

    # Control random seed
    rand_seed = torch.manual_seed(0)

    # Let's assume it has shape (1, 5, 257) (the expected output for a single source)
    # Since the activation function is ReLU, values can be equal or greater
    # than 0
    x_features = torch.rand((1, 5, 257), dtype=torch.float32)
    
    # Extract z_tf for target and residual
    z_tf = x_features[:, 0:1, :]
    z_tf_residual = x_features[:, 1:2, :]

    # Extract phi
    phi = x_features[:, 2:3, :]

    # Estimate beta (due to softplus it will be one or greater)
    beta = 1.0 + F.softplus(phi)

    # Estimate sigmoid of target and residual
    sigmoid_tf = F.sigmoid(z_tf - z_tf_residual)
    sigmoid_tf_residual = F.sigmoid(z_tf_residual - z_tf)

    # Estimate upper bound for beta
    beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)

    # Because of the absolute value in the denominator, the same upper bound
    # can be applied to both betas
    beta = torch.clip(beta, max=beta_upper_bound)

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

    # Now that we have both masks, let's compute the triangle cosine law
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf))
    
    # Use trigonometric identity to obtain the sine
    sin_phase = torch.sqrt(1.0 - cos_phase.square())

    # Now estimate the sign
    q0 = x_features[:, 3:4, :]
    q1 = x_features[:, 4:5, :]
    gamma0 = F.gumbel_softmax(q0, tau=1.0)
    gamma1 = F.gumbel_softmax(q1, tau=1.0)
    sign = torch.where(gamma0 > gamma1, -1.0, 1.0)

    # Finally, estimate the complex mask
    complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)

    # Then it should be applied to the stft and inverted using the istft
    ...

eagomez2 avatar May 23 '23 15:05 eagomez2

  1. I didn't apply ReLU to the last layer (x_features).
  2. sigmoid_tf_residual = 1 - sigmoid_tf always holds, so I don't think you need two outputs for this, but based on the paper your code is correct.
  3. You should always guard divisions, log, sqrt, etc. in your code; for example, in cos_phase there is a chance that mask_tf becomes zero.
  4. For obtaining sin, I am using acos: torch.sin(torch.acos(torch.clamp(cos_phase, min=-1 + eps, max=1 - eps)))
  5. Estimating the sign is a bit confusing, and I believe there is a typo in the formula in the paper (I also believe the sign does not matter much for performance). This is how I implemented it:

    gamma = torch.nn.functional.gumbel_softmax(
        torch.stack([q0, q1], dim=-1), tau=0.5, hard=False,
    )
    gamma_0 = gamma[..., 0]
    gamma_1 = gamma[..., 1]
    sign = torch.where(gamma_0 > gamma_1, -1.0, 1.0)
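Putting points 3 and 4 together, the guarded version would look roughly like this (eps is an arbitrary small constant; dummy masks stand in for the tensors from your snippet):

    import torch

    eps = 1e-8
    mask_tf = torch.rand(1, 1, 257)            # placeholder for your mask_tf
    mask_tf_residual = torch.rand(1, 1, 257)   # placeholder for your mask_tf_residual

    # Guard the division so mask_tf cannot reach zero, then clamp before acos.
    cos_phase = (1.0 + mask_tf.square() - mask_tf_residual.square()) \
        / (2.0 * mask_tf.clamp(min=eps))
    sin_phase = torch.sin(torch.acos(torch.clamp(cos_phase, min=-1.0 + eps, max=1.0 - eps)))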

atabakp avatar May 23 '23 19:05 atabakp

Thank you very much, @atabakp!

  • I tried clamping before with no luck, but I ended up discovering that the problem was apparently in one of my losses, since now it seems to work as expected.
  • I can also corroborate what you mention about simplifying sigmoid_tf_residual; it works with the simpler version.
  • I was also hesitant about the sign prediction, since the phase should already fit the triangle inequality, but I haven't yet done a comparison with and without it. I'll update my conclusions as soon as I can check them.

eagomez2 avatar May 24 '23 10:05 eagomez2

Hi again @atabakp ,

When training the model, are you using 2 s audio clips as the paper states, or are you using gradient accumulation or something similar to pass more data per step?

I'm currently trying to train the model for dereverberation only, but with 2 s per clip it is very slow to train. So far I haven't reached the point of evaluating how well the model performs on the task, but it doesn't seem to be learning quickly.
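For reference, by gradient accumulation I mean the generic PyTorch pattern below (nothing TRUNet-specific; the model, loss, and data are dummy placeholders):

    import torch
    import torch.nn as nn

    # Accumulate gradients over `accum_steps` clips before each optimizer step,
    # emulating a larger effective batch while keeping per-step memory low.
    model = nn.Linear(257, 257)                    # dummy stand-in for the network
    loss_fn = nn.MSELoss()
    loader = [(torch.randn(100, 257), torch.randn(100, 257)) for _ in range(32)]

    accum_steps = 8
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer.zero_grad()
    for step, (noisy, clean) in enumerate(loader):
        loss = loss_fn(model(noisy), clean) / accum_steps
        loss.backward()                            # gradients accumulate across iterations
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()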

eagomez2 avatar Jun 01 '23 07:06 eagomez2

I am using random-length sequences, a single sequence per iteration (batch size = 1).

atabakp avatar Jun 01 '23 07:06 atabakp

Sorry for necroposting here, but I'm trying to train this model with no luck so far. I managed to add trainable PCEN (as described in the paper) and I'm training on spectrograms. I construct the input features from PCEN (the output of the trainable layer), the log magnitude, and the real and imaginary parts of the STFT, and feed them to the rest of the model described here. I also implemented 2D convs, since I wanted to train on batches. The losses are the same as in the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE. The model trains very weirdly (the loss decreases for the first couple hundred steps, then increases, then decreases again). @atabakp @eagomez2 I assume you managed to train this model - can you elaborate on what you changed compared to the paper? Or maybe share your pipeline. Thanks in advance.
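For completeness, this is roughly the kind of trainable PCEN layer I mean, following the standard PCEN formulation with learnable alpha, delta and r (my own sketch, not code from this repository; the smoothing coefficient is kept fixed and parameter clamping is omitted for brevity):

    import torch
    import torch.nn as nn

    class TrainablePCEN(nn.Module):
        """Trainable per-channel energy normalization for a (T, F) magnitude spectrogram."""
        def __init__(self, n_bins: int, s: float = 0.025, eps: float = 1e-6):
            super().__init__()
            self.s, self.eps = s, eps
            self.alpha = nn.Parameter(torch.full((n_bins,), 0.98))
            self.delta = nn.Parameter(torch.full((n_bins,), 2.0))
            self.r = nn.Parameter(torch.full((n_bins,), 0.5))

        def forward(self, mag: torch.Tensor) -> torch.Tensor:
            # IIR smoothing along time: M[t] = (1 - s) * M[t-1] + s * E[t]
            m, smoothed = mag[0], []
            for t in range(mag.shape[0]):
                m = (1.0 - self.s) * m + self.s * mag[t]
                smoothed.append(m)
            m = torch.stack(smoothed)                              # (T, F)
            return (mag / (self.eps + m) ** self.alpha + self.delta) ** self.r \
                - self.delta ** self.r

    pcen_feat = TrainablePCEN(n_bins=257)(torch.rand(100, 257))   # (T, F) magnitudes in, PCEN out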

JBloodless avatar Feb 05 '24 15:02 JBloodless

Hi @JBloodless ,

In my case I am using Conv1d, but I decided to swap the original losses for a GAN setup, since they weren't quite working for me (the model was converging, but not with the expected quality, and it sometimes exploded after that).

eagomez2 avatar Feb 05 '24 16:02 eagomez2

Can you elaborate a bit on what you mean by a GAN loss?

JBloodless avatar Feb 06 '24 08:02 JBloodless

Check section 2 of this paper: https://arxiv.org/pdf/2010.10677.pdf
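For context, the adversarial part is just the usual generator/discriminator objective; a generic least-squares variant looks like the sketch below (illustration only, with a placeholder discriminator; see the linked paper for the exact losses used there):

    import torch
    import torch.nn.functional as F

    def discriminator_loss(disc, clean, enhanced):
        # Push discriminator outputs towards 1 on clean audio and 0 on enhanced audio.
        real = disc(clean)
        fake = disc(enhanced.detach())
        return F.mse_loss(real, torch.ones_like(real)) + \
            F.mse_loss(fake, torch.zeros_like(fake))

    def generator_loss(disc, enhanced):
        # Push the enhancer (generator) to make the discriminator output 1.
        fake = disc(enhanced)
        return F.mse_loss(fake, torch.ones_like(fake))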

eagomez2 avatar Feb 06 '24 14:02 eagomez2

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

JBloodless avatar Feb 12 '24 13:02 JBloodless

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.
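In other words, something like the sketch below, assuming a complex mask built as in the PHM snippet earlier in this thread (dummy tensors stand in for the real mask and STFT):

    import torch

    n_fft, hop = 512, 128
    window = torch.hann_window(n_fft)

    # Dummy noisy STFT, (frames, 257); in practice this comes from the input audio.
    noisy_stft = torch.stft(torch.randn(32000), n_fft, hop_length=hop,
                            window=window, return_complex=True).transpose(0, 1)
    complex_mask_direct = torch.randn_like(noisy_stft)   # placeholder for the direct-source PHM

    # Apply the mask and invert; repeat with the second set of 5 output
    # channels to get the residual waveform.
    direct_stft = complex_mask_direct * noisy_stft
    direct_wav = torch.istft(direct_stft.transpose(0, 1), n_fft, hop_length=hop, window=window)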

eagomez2 avatar Feb 12 '24 13:02 eagomez2

What do you mean by repeating? I thought the network (in this implementation) returns one set of features for the PHM, (time, 5, bins), and the corresponding PHM is the mask for the direct source. Since I only need the direct source (clean speech), I just multiply this PHM with the input spectrum and get the clean output. What did I assume wrong?

JBloodless avatar Feb 12 '24 13:02 JBloodless