
Base Model Checkpoints

Open Haoxiang-Wang opened this issue 2 years ago • 5 comments

Dear authors,

Could you provide the checkpoints (i.e., saved weights) of the base models used in your paper? I ran your commands on CelebA & Waterbirds (for 5 random seeds), and the performance of both the base models and DFR on top of them is slightly worse than that reported in your paper. I would therefore like to request your trained base models for an exact reproduction and further comparison. I would highly appreciate it if you could provide a downloadable link to a Dropbox/Box/Google Drive folder containing your trained models. Thanks!

@andrewgordonwilson @PolinaKirichenko @izmailovpavel

Haoxiang-Wang avatar Nov 16 '22 08:11 Haoxiang-Wang

Hey Authors, @PolinaKirichenko

Can you please provide us with the checkpoints to match the results shown in the paper? By checkpoints, I mean both the pretrained ones and the final ones with DFR.

sanyalsunny111 avatar Jan 16 '23 05:01 sanyalsunny111

Hey @Haoxiang-Wang, @sanyalsunny111!

I re-trained 5 checkpoints for Waterbirds and CelebA and uploaded them to this Google Drive.

The results are the following:

  • Waterbirds: 92.0 ± 0.9 worst-group accuracy
  • CelebA: 88.02 ± 1.6 worst-group accuracy
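For reference, the worst-group accuracy reported above is the minimum per-group test accuracy. A minimal sketch of the metric (the function name and toy data are illustrative, not taken from the repo):

```python
import numpy as np

def worst_group_accuracy(preds, labels, groups):
    """Return the minimum accuracy over groups
    (e.g. {waterbird, landbird} x {water, land} backgrounds)."""
    accs = []
    for g in np.unique(groups):
        mask = groups == g
        accs.append((preds[mask] == labels[mask]).mean())
    return min(accs)

# Toy example with 4 groups of 2 samples each.
preds  = np.array([1, 1, 0, 0, 1, 0, 1, 1])
labels = np.array([1, 0, 0, 0, 1, 1, 1, 1])
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])
print(worst_group_accuracy(preds, labels, groups))  # 0.5
```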

Note that I used the newer repo here: spurious_feature_learning.

The DFR commands:

python3 dfr_evaluate_spurious.py --data_dir=/datasets/CelebA/ --data_transform=AugWaterbirdsCelebATransform --dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained --ckpt_path=logs/celeba/erm_seed1/final_checkpoint.pt --result_path=celeba_erm_seed1_dfr.pkl --save_linear_model

python3 dfr_evaluate_spurious.py --data_dir=/datasets/waterbirds/ --data_transform=AugWaterbirdsCelebATransform --dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained --ckpt_path=logs/waterbirds/erm_seed1/final_checkpoint.pt --result_path=wb_erm_seed1_dfr.pkl --save_linear_model
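To reproduce the ± numbers above, the same command can be looped over seeds. A hedged sketch for Waterbirds (the seed count and log-directory layout are assumptions based on the erm_seed1 paths above; this version only prints the commands so they can be inspected before running):

```shell
#!/bin/sh
# Dry-run: print the DFR evaluation command for each of 5 seeds.
for seed in 1 2 3 4 5; do
  echo "python3 dfr_evaluate_spurious.py \
    --data_dir=/datasets/waterbirds/ \
    --data_transform=AugWaterbirdsCelebATransform \
    --dataset=SpuriousCorrelationDataset \
    --model=imagenet_resnet50_pretrained \
    --ckpt_path=logs/waterbirds/erm_seed${seed}/final_checkpoint.pt \
    --result_path=wb_erm_seed${seed}_dfr.pkl \
    --save_linear_model"
done
```

Dropping the echo (and the surrounding quotes) runs the evaluations for real.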

In the Google Drive folder, each dataset has 5 subfolders, each containing the base model checkpoint, the training logs, the command used to train it, and the last-layer checkpoint trained by DFR.

Please let me know if you have issues with these checkpoints.

izmailovpavel avatar Jan 17 '23 22:01 izmailovpavel

@Haoxiang-Wang Thank you very much

sanyalsunny111 avatar Jan 18 '23 17:01 sanyalsunny111

@Haoxiang-Wang, I am a bit confused about the evaluation.

I used one of the CelebA checkpoints provided by you and ran this script for evaluation:

python3 dfr_evaluate_spurious.py --data_dir=./data/celebA_v1.0/ --ckpt_path=dfr-ckpts/celeba/erm_seed1/final_checkpoint.pt --result_path=celeba_erm_seed1_dfr.pkl

I see multiple accuracies. Can you please confirm which result I should check for the worst-group accuracy?

[screenshot of evaluation output]

sanyalsunny111 avatar Jan 20 '23 07:01 sanyalsunny111

Hi @sanyalsunny111, if you want the results for DFR_Val, you should look at the results under "DFR on Validation", and then at test_worst_acc, which in your screenshot is 87.7%. Note again that in my previous post I provided results achieved with the updated repo (spurious_feature_learning), and the commands are for that repo. You should be able to get similar results with this repo too, though.
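For anyone scripting this instead of reading the printed output, the saved .pkl from --result_path can be inspected directly. A sketch assuming the result is a nested dict keyed by method and then by metric; the exact key names (e.g. "DFR on Validation", test_worst_acc) and the extra test_mean_acc key here are assumptions that should be verified against the actual file:

```python
import io
import pickle

# Mock of the assumed structure of celeba_erm_seed1_dfr.pkl.
results = {
    "DFR on Validation": {
        "test_worst_acc": 0.877,   # worst-group test accuracy (DFR_Val)
        "test_mean_acc": 0.912,    # assumed additional metric
    },
}

# Round-trip through pickle in memory, standing in for
# open("celeba_erm_seed1_dfr.pkl", "rb") on the real file.
buf = io.BytesIO()
pickle.dump(results, buf)
buf.seek(0)
loaded = pickle.load(buf)

print(loaded["DFR on Validation"]["test_worst_acc"])  # 0.877
```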

izmailovpavel avatar Jan 20 '23 14:01 izmailovpavel