M-ADA

My trained Learner model can't reach the results in the paper

Open zhuyi3625 opened this issue 3 years ago • 12 comments

I trained a Learner model on 'mnist' and tested it on 'mnist_m'. However, the model's outputs look wrong: all 10k 'mnist_m' inputs produce the same output, and the same thing happens on the 'mnist' test set. I did not change any parameters, hyper-parameters, or code. One thing I noticed is that the training loss of the new WAE is negative. What am I doing wrong?

zhuyi3625 avatar Jul 17 '21 09:07 zhuyi3625

Here are the args:

    import argparse

    parser = argparse.ArgumentParser(description='Training on Digits')
    parser.add_argument('--data_dir', default='utils/data', type=str, help='dataset dir')
    parser.add_argument('--dataset', default='mnist', type=str, help='dataset mnist or cifar10')
    parser.add_argument('--num_iters', default=10001, type=int, help='number of total iterations to run')
    parser.add_argument('--start_iters', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=32, type=int, help='mini-batch size (default: 128)')
    parser.add_argument('--lr', '--min-learning-rate', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--lr_max', '--adv-learning-rate', default=1, type=float, help='adversarial learning rate')
    parser.add_argument('--gamma', default=1, type=float, help='coefficient of constraint')
    parser.add_argument('--beta', default=2000, type=float, help='coefficient of relaxation')
    parser.add_argument('--T_adv', default=25, type=int, help='iterations for adversarial training')
    parser.add_argument('--advstart_iter', default=0, type=int, help='iterations for pre-train')
    parser.add_argument('--K', default=3, type=int, help='number of augmented domains')
    parser.add_argument('--T_min', default=100, type=int, help='intervals between domain augmentation')
    parser.add_argument('--print-freq', '-p', default=1000, type=int, help='print frequency (default: 10)')
    parser.add_argument('--resume', default=None, type=str, help='path to saved checkpoint (default: none)')
    parser.add_argument('--name', default='Digits', type=str, help='name of experiment')
    parser.add_argument('--mode', default='train', type=str, help='train or test')
    parser.add_argument('--GPU_ID', default=0, type=int, help='GPU_id')

and the new WAE is trained 25 times per round.

zhuyi3625 avatar Jul 17 '21 09:07 zhuyi3625

I cannot reproduce the results on mnist either, using the above config. I got:

  • svhn: Prec@1 21.045
  • mnist_m: Prec@1 56.951
  • syn: Prec@1 29.289
  • usps: Prec@1 58.317 (label range [0-9])
  • avg acc 48.185

The scores are ridiculously low, even lower than the ERM baseline reported in the paper.

cherise215 avatar Feb 08 '22 15:02 cherise215

Hi @zhuyi3625 ,

Sorry for the late reply. The config looks fine to me. You mentioned that all the outputs are the same, which suggests the model is not being trained properly. Could you check the optimizer to make sure it is working and that gradients are actually reaching the model?
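
For example, a quick check along these lines (a sketch with a toy stand-in for the Learner, not the repo's exact code) can confirm whether gradients are flowing:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Sketch only: a toy stand-in for the Learner; swap in the real model,
    # optimizer, and a real mini-batch from the training loop.
    learner = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32 * 3, 10))
    optimizer = torch.optim.Adam(learner.parameters(), lr=1e-4)
    images, labels = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))

    optimizer.zero_grad()
    loss = F.cross_entropy(learner(images), labels)
    loss.backward()

    # If every gradient norm is zero (or None/NaN), the model is not being
    # updated, which would explain identical outputs for every input.
    for name, p in learner.named_parameters():
        norm = p.grad.norm().item() if p.grad is not None else float('nan')
        print(f'{name}: grad norm = {norm:.6f}')

    optimizer.step()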

Best, Fengchun

joffery avatar Feb 09 '22 19:02 joffery

Hi @cherise215 ,

The results are indeed ridiculously low. In past reports, the most common causes were either the package versions or broken second-order gradients. I can't reach a concrete conclusion without more details.
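
To rule out the second-order gradient issue, a generic probe like the following (illustrative only, not the repo's actual meta-update) should produce a non-zero gradient on your PyTorch install:

    import torch

    # Illustrative check that second-order gradients work in the installed
    # PyTorch; this is not the repo's meta-update, just a generic probe.
    w = torch.randn(5, requires_grad=True)
    x = torch.randn(5)

    inner_loss = ((w * x) ** 2).sum()
    # create_graph=True keeps the graph so the gradient itself is differentiable.
    (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)

    outer_loss = (g ** 2).sum()
    outer_loss.backward()

    # If second-order support is broken, w.grad would be None or all zeros here.
    print('second-order grad norm:', w.grad.norm().item())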

Best, Fengchun

joffery avatar Feb 09 '22 19:02 joffery

Thanks for the prompt reply. I am using the same version of metaNN (0.1.5) to support second-order gradient computation, but a newer PyTorch (1.9.1); I had assumed a newer PyTorch would, if anything, give higher performance. I am happy to replicate your environment exactly for a fair test. Could you please provide a full list of requirements (e.g. a requirements.txt) so I can set up a virtual env? That would be very helpful for me and others.

cherise215 avatar Feb 09 '22 19:02 cherise215

requirements.txt

joffery avatar Feb 09 '22 20:02 joffery

Thanks. It now works. The results are below, a bit higher than reported (maybe due to the upgraded PyTorch):

=> loading checkpoint './runs/mnist/Digits/ckpt_mnist.pth.tar'
=> loaded checkpoint './runs/mnist/Digits/ckpt_mnist.pth.tar' (iter 10001)

  • svhn: Prec@1 45.691
  • mnist_m: Prec@1 68.739
  • syn: Prec@1 49.423
  • usps: Prec@1 78.125 (label range [0-9])
  • avg acc 65.429

cherise215 avatar Feb 13 '22 18:02 cherise215

The reported results are averaged over several runs. Adversarial training brings additional randomness to the experiments.
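
In sketch form, averaging over seeded runs looks something like the snippet below; set_seed and train_and_evaluate are illustrative placeholders, not the repo's actual functions:

    import random
    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        """Fix the usual sources of randomness; adversarial training and the
        WAE still introduce run-to-run variance, so average over several runs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Hypothetical usage: train_and_evaluate stands in for the repo's
    # train/test pipeline returning the average target-domain accuracy.
    # accs = [train_and_evaluate(seed=s) for s in (0, 1, 2)]
    # print('mean acc:', sum(accs) / len(accs))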

joffery avatar Feb 16 '22 07:02 joffery

Sorry, I found it still doesn't work. The previous numbers were posted by mistake: they come from the downloaded checkpoint. When I train from scratch, I still get very low scores:

=> loading checkpoint './runs/mnist/Digits/checkpoint.pth.tar'
=> loaded checkpoint './runs/mnist/Digits/checkpoint.pth.tar' (iter 10001)

  • svhn: Prec@1 20.407
  • mnist_m: Prec@1 51.679
  • syn: Prec@1 27.265
  • usps: Prec@1 54.284 (label range [0-9])
  • avg acc 44.410

cherise215 avatar Feb 16 '22 21:02 cherise215

I think the problem is the MNIST data you used. USPS is very similar to MNIST. Even models without any augmentation can achieve > 70% acc on this dataset.

Please use the processed MNIST data (train/test.pkl https://drive.google.com/drive/folders/1__r_p5W_yCrC_nVxkJB3cuqJf7h0vQqe?usp=sharing) to run the experiments and let me know the results.

Did you use download_and_process_mnist.py to process your mnist data?

joffery avatar Feb 17 '22 06:02 joffery

I did preprocess the data, but could not reproduce your results. Thanks for sharing the processed version. With it, I now get much higher performance:

  • svhn: Prec@1 35.928
  • mnist_m: Prec@1 68.639
  • syn: Prec@1 47.357
  • usps: Prec@1 79.083 (label range [0-9])
  • avg acc 65.026

However, the performance now seems even higher than your newer method, UM-GUD? May I ask why the method is so sensitive to data preprocessing? Does it make any specific assumptions about the data?

cherise215 avatar Feb 23 '22 16:02 cherise215

The train/test.pkl files were generated with download_and_process_mnist.py, and download_and_process_mnist.py is exactly the same as the one used in https://github.com/ricvolpi/generalize-unseen-domains.
There is nothing special about the data pre-processing, so something must have gone wrong when you processed the data, such as scaling the pixels to 0-1 or a label mismatch between mnist and the other datasets. That would explain why you can achieve >95% acc on MNIST while only >50% on USPS.
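
A quick sanity check along these lines can catch both failure modes; the pkl layout below is an assumption (plain image/label arrays), so adjust the unpacking to match whatever download_and_process_mnist.py actually saves:

    import pickle
    import numpy as np

    # Assumed layout: adjust the unpacking to match the actual pkl structure.
    with open('train.pkl', 'rb') as f:
        data = pickle.load(f)
    images, labels = np.asarray(data[0]), np.asarray(data[1])
    labels = labels.argmax(1) if labels.ndim > 1 else labels  # handle one-hot labels

    # Pixel range: images scaled to [0, 1] instead of [0, 255] (or vice versa)
    # will silently break a model trained with the other convention.
    print('pixel min/max:', images.min(), images.max())

    # Label range: all digit datasets should share the same 0-9 label space.
    print('label range:', labels.min(), labels.max())
    print('per-class counts:', np.bincount(labels.astype(int), minlength=10))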

Although I have fixed the random seeds, some randomness remains due to the adversarial training and the training of the WAE. In UM-GUD the problem is more pronounced, since I proposed to perturb the features (I also tried perturbing the weights, but that didn't work). The upper bound of UM-GUD should be higher than M-ADA, and the reported results are averaged over several runs.

If you print the results every 1000 iterations, you will find that the best results usually appear around the 6000th or 7000th iteration and decrease after that. UM-GUD is more sensitive to hyper-parameters such as T_adv, K, T_sample, advstart_iter, and T_min. I am not sure whether they are the same as the ones I used before the submission, since I have many run files and it has been a while. I will try to run more experiments to find more stable hyper-parameters.
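
In sketch form, the evaluation pattern meant above (evaluate every 1000 iterations and keep the best average accuracy rather than the final one) looks like this; evaluate_on_targets is a dummy placeholder standing in for the repo's actual test routine:

    import random

    # Sketch only: `evaluate_on_targets` stands in for the repo's test routine
    # (mean accuracy over svhn/mnist_m/syn/usps); here it is a dummy so the
    # loop structure runs on its own.
    def evaluate_on_targets(it):
        return 40.0 + random.random() * 20.0

    best_avg_acc, best_iter = 0.0, 0
    for it in range(1, 10001):
        # ... one training iteration would go here ...
        if it % 1000 == 0:
            avg_acc = evaluate_on_targets(it)
            if avg_acc > best_avg_acc:
                best_avg_acc, best_iter = avg_acc, it
                # save the checkpoint for this iteration here
    print(f'best avg acc {best_avg_acc:.2f} at iteration {best_iter}')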

joffery avatar Feb 23 '22 19:02 joffery