fish
Error in training on cdsprites with ERM
There is a filter on the indices that keeps only samples where latents[:, 0] != 3:
https://github.com/YugeTen/fish/blob/333efa24572d99da0a4107ab9cc4af93a915d2a9/src/models/datasets.py#L266-L271
In the code above, this condition is applied to the val and test splits but not to the train split. This causes an error when training with ERM. Could you explain the correct way to run ERM on cdsprites?
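For illustration only, here is a minimal sketch of what applying the same condition to the training split might look like. It assumes the images and latents for a split are aligned NumPy arrays. `filter_split` is a hypothetical helper, not a function from the repository, and the toy arrays stand in for the real cdsprites data. Whether filtering train this way is the intended fix is exactly what this issue asks.

```python
import numpy as np

def filter_split(images, latents, excluded_value=3):
    # Hypothetical helper: keep only samples whose first latent factor
    # (column 0) differs from `excluded_value`, mirroring the
    # latents[:, 0] != 3 condition currently used for val and test.
    mask = latents[:, 0] != excluded_value
    return images[mask], latents[mask]

# Toy arrays standing in for the real cdsprites data (assumption).
images = np.arange(6).reshape(6, 1)
latents = np.array([[0, 1], [3, 2], [1, 0], [3, 3], [2, 1], [0, 0]])

# Apply the same filter to the train split as well.
train_images, train_latents = filter_split(images, latents)
print(train_latents[:, 0])  # [0 1 2 0]
```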
Thanks.