UMGUD
Cannot reproduce the results
Hi, I was running your code on the MNIST dataset with the default setting (10,000 iterations). However, the results are pretty poor:
Loading MNIST dataset.  - Prec@1 95.363
svhn  Loading SVHN dataset.  - Prec@1 8.510
mnist_m  Loading MNIST_M dataset.  - Prec@1 27.202
syn  Loading SYN dataset.  - Prec@1 24.518
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 55.343
It is quite weird as I didn't touch any part of your code. Can you please share the config file you used for MNIST training? If possible, I am also interested in the implementation of the segmentation part. My email address is [email protected].
Thanks!
Hi @cherise215,
The results you report are much lower than expected. I noticed that you had a similar problem with M-ADA. The config in "main.py" is exactly the one I used for MNIST training.
May I ask whether you are using the specified versions of PyTorch and MetaNN? Could you please provide more detailed information?
Best, Fengchun
I am using PyTorch 1.9.1 and MetaNN 0.1.5.
Would you please try PyTorch 1.1.0?
Sure. Could you please share your environment config (requirements.txt) so that I can replicate the environment? For example, you can use pipreqs to generate the dependency list. Thanks.
Sure. I recently upgraded PyTorch to 1.6.0 and tested it with both M-ADA and UMGUD. requirements.txt
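If it helps, here is a small, generic check (not part of the repo; the package names torch and metann and the target versions are assumptions based on this thread) to confirm that the installed packages match the shared requirements.txt before running main.py:

```python
# check_env.py -- generic sanity check, not part of the UMGUD repo.
# Verifies that the installed packages match the versions discussed in this
# thread (PyTorch 1.6.0 and MetaNN 0.1.5); adjust to whatever requirements.txt lists.
from importlib.metadata import version  # Python 3.8+

import torch

expected = {"torch": "1.6.0", "metann": "0.1.5"}

print("torch :", torch.__version__, "| CUDA available:", torch.cuda.is_available())

for pkg, want in expected.items():
    got = version(pkg)
    status = "OK" if got.startswith(want) else f"MISMATCH (expected {want})"
    print(f"{pkg}: {got} -> {status}")
```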
Thanks for providing the file. I have replicated the environment, yet I still get very low performance. I am posting the results here for reference.
Loading MNIST dataset.  - Prec@1 96.474
svhn  Loading SVHN dataset.  - Prec@1 9.098
mnist_m  Loading MNIST_M dataset.  - Prec@1 27.302
syn  Loading SYN dataset.  - Prec@1 24.664
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 54.587
avg acc 28.912854143601493
mnist  Loading MNIST dataset.  - Prec@1 96.474
svhn  Loading SVHN dataset.  - Prec@1 9.094
mnist_m  Loading MNIST_M dataset.  - Prec@1 27.302
syn  Loading SYN dataset.  - Prec@1 24.685
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 54.788
avg acc 28.967539710889078
This is the full output during training:
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
Loading MNIST dataset.
Loading MNIST dataset.
Training
Iter: [0][0/10001]  Loss 2.3055 (2.3055)  Prec@1 12.500 (12.500)
validation set acc 10.606971153846153
Iter: [1000][1000/10001]  Loss 0.3086 (0.3086)  Prec@1 84.375 (84.375)
validation set acc 88.65184294871794
mnist  Loading MNIST dataset.  - Prec@1 88.692
svhn  Loading SVHN dataset.  - Prec@1 7.949
mnist_m  Loading MNIST_M dataset.  - Prec@1 20.463
syn  Loading SYN dataset.  - Prec@1 15.866
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 60.081
avg acc 26.089606094490073
Iter: [2000][2000/10001]  Loss 0.2688 (0.2688)  Prec@1 90.625 (90.625)
validation set acc 94.24078525641026
mnist  Loading MNIST dataset.  - Prec@1 94.251
svhn  Loading SVHN dataset.  - Prec@1 7.388
mnist_m  Loading MNIST_M dataset.  - Prec@1 21.041
syn  Loading SYN dataset.  - Prec@1 18.561
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 62.147
avg acc 27.284276418652695
Iter: [3000][3000/10001]  Loss 0.0595 (0.0595)  Prec@1 100.000 (100.000)
validation set acc 94.90184294871794
mnist  Loading MNIST dataset.  - Prec@1 94.892
svhn  Loading SVHN dataset.  - Prec@1 7.134
mnist_m  Loading MNIST_M dataset.  - Prec@1 19.851
syn  Loading SYN dataset.  - Prec@1 17.691
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 57.107
avg acc 25.445690132951977
Iter: [4000][4000/10001]  Loss 0.0470 (0.0470)  Prec@1 100.000 (100.000)
validation set acc 95.9735576923077
mnist  Loading MNIST dataset.  - Prec@1 95.954
svhn  Loading SVHN dataset.  - Prec@1 7.811
mnist_m  Loading MNIST_M dataset.  - Prec@1 24.066
syn  Loading SYN dataset.  - Prec@1 20.627
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 57.006
avg acc 27.377390026811618
Iter: [5000][5000/10001]  Loss 0.0116 (0.0116)  Prec@1 100.000 (100.000)
validation set acc 94.05048076923077
mnist  Loading MNIST dataset.  - Prec@1 94.040
svhn  Loading SVHN dataset.  - Prec@1 7.392
mnist_m  Loading MNIST_M dataset.  - Prec@1 19.495
syn  Loading SYN dataset.  - Prec@1 17.638
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 50.958
avg acc 23.870699009183006
Iter: [6000][6000/10001]  Loss 0.0084 (0.0084)  Prec@1 100.000 (100.000)
validation set acc 96.14383012820512
mnist  Loading MNIST dataset.  - Prec@1 96.144
svhn  Loading SVHN dataset.  - Prec@1 8.191
mnist_m  Loading MNIST_M dataset.  - Prec@1 23.710
syn  Loading SYN dataset.  - Prec@1 20.984
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 53.831
avg acc 26.678840918707657
Iter: [7000][7000/10001]  Loss 0.0040 (0.0040)  Prec@1 100.000 (100.000)
validation set acc 96.33413461538461
mnist  Loading MNIST dataset.  - Prec@1 96.344
svhn  Loading SVHN dataset.  - Prec@1 12.915
mnist_m  Loading MNIST_M dataset.  - Prec@1 28.414
syn  Loading SYN dataset.  - Prec@1 25.378
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 57.359
avg acc 31.01641570125756
Iter: [8000][8000/10001]  Loss 0.0019 (0.0019)  Prec@1 100.000 (100.000)
validation set acc 96.24399038461539
mnist  Loading MNIST dataset.  - Prec@1 96.244
svhn  Loading SVHN dataset.  - Prec@1 10.028
mnist_m  Loading MNIST_M dataset.  - Prec@1 28.125
syn  Loading SYN dataset.  - Prec@1 26.143
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 57.560
avg acc 30.464241204539942
Iter: [9000][9000/10001]  Loss 0.0015 (0.0015)  Prec@1 100.000 (100.000)
validation set acc 95.60296474358974
mnist  Loading MNIST dataset.  - Prec@1 95.593
svhn  Loading SVHN dataset.  - Prec@1 12.315
mnist_m  Loading MNIST_M dataset.  - Prec@1 25.267
syn  Loading SYN dataset.  - Prec@1 21.005
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 49.294
avg acc 26.970342750560427
Iter: [10000][10000/10001]  Loss 0.0001 (0.0001)  Prec@1 100.000 (100.000)
validation set acc 96.47435897435898
mnist  Loading MNIST dataset.  - Prec@1 96.474
svhn  Loading SVHN dataset.  - Prec@1 9.098
mnist_m  Loading MNIST_M dataset.  - Prec@1 27.302
syn  Loading SYN dataset.  - Prec@1 24.664
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 54.587
avg acc 28.912854143601493
mnist  Loading MNIST dataset.  - Prec@1 96.474
svhn  Loading SVHN dataset.  - Prec@1 9.094
mnist_m  Loading MNIST_M dataset.  - Prec@1 27.302
syn  Loading SYN dataset.  - Prec@1 24.685
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 54.788
avg acc 28.967539710889078
According to your log, even the validation accuracy on MNIST is lower than expected (it should be around 98%). I suggest you set K=0 and run the experiments; this setting should reproduce the results of ERM. If the results are still far from the expected values, the problem may lie in the data you used. You can use the processed data I used in the experiments: https://drive.google.com/drive/folders/1__r_p5W_yCrC_nVxkJB3cuqJf7h0vQqe?usp=sharing
Here are several logs of mine for reference: 1.log 2.log 3.log
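To make the K=0 suggestion concrete: conceptually, K controls how many perturbed copies of the source batch are added to the loss, so with K=0 the training reduces to plain ERM on MNIST. The sketch below is only an illustration with made-up names (train_step, perturb), not the actual code in main.py:

```python
# Illustrative sketch only -- names (train_step, perturb, K) are hypothetical
# and do not map one-to-one to the code in this repository.
import torch
import torch.nn.functional as F

def train_step(model, optimizer, x, y, K=0, perturb=None):
    """One update: ERM on the source (MNIST) batch plus K perturbed copies.

    With K=0 this is plain ERM, which is the baseline the reproduced
    numbers should first be checked against.
    """
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)      # source-domain loss
    for _ in range(K):                       # K > 0: add perturbed views
        loss = loss + F.cross_entropy(model(perturb(x)), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```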
Thanks. I downloaded the data. The performance is now higher (with K=20 by default) but, surprisingly, still lower than your previous method M-ADA.
Iter: [10000][10000/10001]  Loss 0.0001 (0.0001)  Prec@1 100.000 (100.000)
validation set acc 98.58774038461539
mnist  Loading MNIST dataset.  - Prec@1 98.598
svhn  Loading SVHN dataset.  - Prec@1 34.064
mnist_m  Loading MNIST_M dataset.  - Prec@1 65.414
syn  Loading SYN dataset.  - Prec@1 52.936
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 72.631
avg acc 56.26116105049551
mnist  Loading MNIST dataset.  - Prec@1 98.588
svhn  Loading SVHN dataset.  - Prec@1 34.067
mnist_m  Loading MNIST_M dataset.  - Prec@1 65.391
syn  Loading SYN dataset.  - Prec@1 52.978
usps  Loading USPS dataset.  label range [0-9]  - Prec@1 72.833
avg acc 56.31745130237201
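For reference, the "avg acc" printed above appears to be the mean Prec@1 over the four unseen target domains (svhn, mnist_m, syn, usps), with the source domain mnist excluded; a quick check against the last pass:

```python
# Quick check: 'avg acc' is consistent with the mean Prec@1 over the four
# unseen target domains (svhn, mnist_m, syn, usps); the source mnist is excluded.
target_prec1 = {"svhn": 34.067, "mnist_m": 65.391, "syn": 52.978, "usps": 72.833}

avg = sum(target_prec1.values()) / len(target_prec1)
print(f"computed avg acc: {avg:.3f}")
# Prints 56.317; the logged 56.31745... differs only because the per-domain
# values above are rounded to three decimals in the printed log.
```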
I can't link to the corresponding issue in M-ADA, so I will just paste the same answer here.
In M-ADA, although I fixed the random seeds, there is still randomness due to the adversarial training and the training of the WAE. In UM-GUD, the problem is more pronounced, since I propose to perturb the features (I also tried perturbing the weights, but that did not work). The upper bound of UM-GUD should be higher than that of M-ADA, and the reported results are averaged over several runs. If you print the results every 1000 iterations, you will find that the best results usually appear around the 6000th or 7000th iteration and decrease after that. UM-GUD is also more sensitive to hyper-parameters such as T_adv, K, T_sample, advstart_iter, and T_min. I am not sure whether these are the same as the ones I used before the submission, since I have many running files and it has been a while. I will try to run more experiments to find more stable hyper-parameters.
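If it helps while investigating this, a generic way to pin down the remaining sources of randomness in PyTorch (not specific to this repo, and some variance from the adversarial feature perturbation will remain) is:

```python
# Generic PyTorch seeding recipe, not specific to UM-GUD; it reduces
# run-to-run variance but does not remove the randomness introduced by
# the adversarial training itself.
import random
import numpy as np
import torch

def set_seed(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True   # may slow training down
    torch.backends.cudnn.benchmark = False

set_seed(0)
```

Even then, comparing a single run against the reported numbers can be misleading, since those numbers are averages over several runs.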
Understood. Thanks for your answer and your time.