
train function issue

Open JimLin1005 opened this issue 2 years ago • 3 comments

Sorry, I would like to ask for your help. In line 357 of main.py, "loss_value = train(training_loader,...)", the training_loader consists of our datasets weak_data, unlabel_data, and train_synth_data. However, in line 81 of the train function "def train(train_loader,...)", "for i, ((batch_input, ema_batch_input), target) in enumerate(train_loader):", I think we should extract the dataset from train_loader, but it seems that batch_input/ema_batch_input/target are not extracted from weak_data/unlabel_data/train_synth_data. May I ask what batch_input/ema_batch_input/target are? Thanks.

JimLin1005 avatar Oct 28 '22 03:10 JimLin1005

Hi, I'm not sure I understood the question well. batch_input is the batch that we input into the student model:

  • Composed of a number of examples equal to batch_size, distributed as follows (see the sketch below):
    • Defined by batch_sizes = [cfg.batch_size//4, cfg.batch_size//2, cfg.batch_size//4]
    • Meaning 1/4 is weakly labeled data, 1/2 is unlabeled data, and 1/4 is synthetic data
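
In code, the split boils down to something like this (batch_size = 24 is just an example value, and the slice names are illustrative, not the repository's):

```python
batch_size = 24  # stands in for cfg.batch_size; assumed divisible by 4
batch_sizes = [batch_size // 4, batch_size // 2, batch_size // 4]

# Index ranges of the three portions inside one concatenated batch:
#   [0:6)  weakly labeled, [6:18) unlabeled, [18:24) synthetic
weak_slice = slice(0, batch_sizes[0])
unlabel_slice = slice(batch_sizes[0], batch_sizes[0] + batch_sizes[1])
strong_slice = slice(batch_sizes[0] + batch_sizes[1], batch_size)
```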

ema_batch_input has the same distribution as batch_input (defined by batch_sizes). It is the batch that will be the input of the teacher model. The examples are the same as in batch_input, but with a small perturbation (see the sketch below).
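
A minimal sketch of how such a perturbed teacher input could be built, assuming additive Gaussian noise (the repository's actual perturbation and feature shapes may differ):

```python
import torch

def make_teacher_input(batch_input, noise_std=0.5):
    # Same examples as the student input, plus a small random perturbation
    # so the teacher sees a slightly different view of each sample.
    return batch_input + noise_std * torch.randn_like(batch_input)

batch_input = torch.randn(24, 1, 628, 128)  # hypothetical (batch, channel, frames, mels)
ema_batch_input = make_teacher_input(batch_input)
```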

target is also of size batch_size and holds the labels of the input data. The labels are masked depending on their nature (weak, unlabeled, strong).

Does it help?

turpaultn avatar Jan 04 '23 15:01 turpaultn

Hi, sorry for my unclear description. The training_loader is composed of [1/4*(weak_data), 1/2*(unlabel_data), 1/4*(train_synth_data)].

Lines 63/64 of main.py say:

#63# train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
#64# Should return a tuple: ((teacher input, student input), labels)

So in line 81 of main.py, "for i, ((batch_input, ema_batch_input), target) in enumerate(train_loader):", that would mean:

batch_input => teacher input
ema_batch_input => student input
target => labels

This is not the same as your comment "batch_input is the batch that we input into the student model".

According to your comment, may I understand it as below?

[batch_input] => 1/4*(weak_data), 1/2*(unlabel_data), 1/4*(train_synth_data)
weak_data_1.wav weak_label_1
weak_data_2.wav weak_label_2
...
unlabel_data_1.wav
unlabel_data_2.wav
...
train_synth_data_1.wav train_synth_label_1
train_synth_data_2.wav train_synth_label_2
...

[ema_batch_input] => 1/4*(weak_data_have_a_small_perturbation), 1/2*(unlabel_data_have_a_small_perturbation), 1/4*(train_synth_data_have_a_small_perturbation)
weak_data_have_a_small_perturbation_1.wav weak_label_1
weak_data_have_a_small_perturbation_2.wav weak_label_2
...
unlabel_data_have_a_small_perturbation_1.wav
unlabel_data_have_a_small_perturbation_2.wav
...
train_synth_data_have_a_small_perturbation_1.wav train_synth_label_1
train_synth_data_have_a_small_perturbation_2.wav train_synth_label_2
...

[target]
weak_label_1
weak_label_2
...
?? => For unlabeled data, what is the label?
train_synth_label_1
train_synth_label_2
...

Thanks.

JimLin1005 avatar Jan 05 '23 02:01 JimLin1005

Hi, indeed, there is a mistake in the docstring: it should be "((student input, teacher input), labels)".

So it is:

batch_input => student input
ema_batch_input => teacher input
target => labels
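
With the fix, the docstring and the loop from line 81 line up like this (the remaining train() arguments are elided):

```python
def train(train_loader, *args):  # other arguments elided
    """
    train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should return a tuple: ((student input, teacher input), labels)
    """
    for i, ((batch_input, ema_batch_input), target) in enumerate(train_loader):
        # batch_input feeds the student model; ema_batch_input feeds the teacher.
        ...
```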

Could you make a pull request, please?

PS: ema stands for "exponential moving average" (it was a way to define the teacher; I realise "teacher" would have been more explicit).
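
For reference, a standard mean-teacher EMA update looks roughly like this; the function name and the smoothing constant alpha are illustrative, not necessarily what the repository uses:

```python
import torch

def update_ema_variables(model, ema_model, alpha=0.999):
    # The teacher's weights are an exponential moving average of the
    # student's; the teacher is never updated through backpropagation.
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)
```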

Tensor explanation

What you'll get is big tensors: the inputs look like [[[...], [...]], [[...], [...]]] and the labels like [[...], [...]],

where the inputs and labels always have the same dimensions.

Indeed, [batch_input] => 1/4*(weak_data), 1/2*(unlabel_data), 1/4*(train_synth_data). So, portion by portion, a batch looks like this:

weak_data_1.wav weak_label_1 --- weak_label_1 has the same shape as strongly labeled data, but with 1 over the full time column
...
unlabel_data_1.wav unlabel_label_1 --- unlabel_label_1 has the same shape as strongly labeled data, but the value is -1 everywhere
...
train_synth_data_1.wav train_synth_label_1 --- here, just strongly labeled data
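
A sketch of these three label encodings (the frame and class counts are made-up example values):

```python
import torch

n_frames, n_classes = 157, 10  # hypothetical (frames, classes) of a strong label matrix

# Weak labels: the clip-level classes are set to 1 over the full time column.
weak_target = torch.zeros(n_frames, n_classes)
weak_target[:, [2, 7]] = 1          # e.g. classes 2 and 7 are present in the clip

# Unlabeled data: same shape, filled with -1 so it can be masked out later.
unlabel_target = torch.full((n_frames, n_classes), -1.0)

# Synthetic data: genuine frame-level (strong) labels.
synth_target = torch.zeros(n_frames, n_classes)
synth_target[30:80, 4] = 1          # class 4 active between frames 30 and 80
```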

Why do all this?

To make batches, you need tensors of the same dimension, so I decided to keep the strongly labeled shape for all the data and make "masks" to use only the data needed for each loss (see the sketch below).
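
A sketch of how such masks can drive the supervised losses, reusing the slice boundaries from the earlier sketch; the max-pooling used for the weak loss is an assumption, the repository may pool differently:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

def supervised_losses(predictions, target, weak_slice, strong_slice):
    # predictions/target: (batch, frames, classes); predictions are sigmoid
    # outputs in [0, 1].

    # Strong (frame-level) loss: only the synthetic portion of the batch.
    strong_loss = bce(predictions[strong_slice], target[strong_slice])

    # Weak (clip-level) loss: pool the weak portion over the time axis first.
    weak_pred = predictions[weak_slice].max(dim=1)[0]
    weak_true = target[weak_slice].max(dim=1)[0]
    weak_loss = bce(weak_pred, weak_true)

    # The unlabeled portion (targets filled with -1) is excluded from both
    # supervised losses; it only contributes to the consistency loss.
    return strong_loss + weak_loss
```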

Hope this helps.

turpaultn avatar Jan 10 '23 09:01 turpaultn