MMD_AAE_PyTorch icon indicating copy to clipboard operation
MMD_AAE_PyTorch copied to clipboard

This is the unofficial PyTorch implementation of Domain Generalization with Adversarial Feature Learning.

Results 3 MMD_AAE_PyTorch issues
Sort by recently updated
recently updated
newest added

The original function calculates pairwise distances by looping over all pairs of instances, which can be slow if the number of instances is large. We can use broadcasting and vectorization...

老哥,非常感谢你的代码,救我狗命! 但我直接下载下来跑不通,报错了在main.py里面行151:adv_loss = advLoss(torch.square(preds), all_labels),preds是300dim,all_labels是600dim,然后说维度不匹配,请问pre应该是300个还是600个呀?》。《

你好 感谢你分享的MMD_AAE复现code! 在使用你代码的过程中我发现MMD loss一直保持在0.0600无法下降,我仔细调完后猜测可能是mmd的函数没有被载入PyTorch计算图的原因,不知是否是我猜测有误~