Contrastive Prediction

Open larry10hhobh opened this issue 4 years ago • 4 comments

Hi

Thank you for your code. I have a question about the contrastive prediction code in student.py:

# train ssp_head
for epoch in range(args.t_epoch):

    t_model.eval()
    loss_record = AverageMeter()
    acc_record = AverageMeter()

    start = time.time()
    for x, _ in train_loader:

        t_optimizer.zero_grad()

        x = x.cuda()
        c,h,w = x.size()[-3:]
        x = x.view(-1, c, h, w)

        _, rep, feat = t_model(x, bb_grad=False)
        batch = int(x.size(0) / 4)
        nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
        aug_index = (torch.arange(4*batch) % 4 != 0).cuda()

        nor_rep = rep[nor_index]
        aug_rep = rep[aug_index]
        nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)
        aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch)
        simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)
        target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
        loss = F.cross_entropy(simi, target)

It looks to me as if nor_rep and aug_rep come from different samples, so this is not the relation between x and its transformations that the paper describes. Is my understanding wrong?

larry10hhobh, Aug 06 '20 06:08

Hi, thanks for running this repo.

The batch from train_loader has shape 64x4x3x32x32. The dimension of size 4 holds one normal image plus three transformed versions of it. After x.view(), the shape is (64x4)x3x32x32. Suppose the output feature has shape (64x4)xF. nor_index and aug_index split the output features into two tensors: 64xF (normal) and 192xF (transformed). These two tensors correspond to nor_rep and aug_rep. Since target repeats each sample index three times, every transformed feature is matched against the normal feature of the same source image.
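The index bookkeeping can be checked without running the model. The following is a minimal pure-Python sketch, not code from the repository: it uses a small hypothetical batch of 3 instead of 64, and the names `rows`, `nor_rows` and `aug_rows` are made up for illustration. Each flattened row is tagged with the sample and view it came from, and the check confirms that `target` points every transformed row at the normal row of the same sample.

```python
# Pure-Python illustration of the index bookkeeping (no torch required).
batch, views = 3, 4  # hypothetical small sizes; the thread uses 64 and 4

# After x.view(-1, c, h, w), flattened row i holds view (i % 4) of sample (i // 4).
rows = [(i // views, i % views) for i in range(batch * views)]

# Same selection rule as nor_index / aug_index in the quoted loop.
nor_rows = [r for i, r in enumerate(rows) if i % views == 0]
aug_rows = [r for i, r in enumerate(rows) if i % views != 0]

# Same values as torch.arange(batch).unsqueeze(1).expand(-1, 3).contiguous().view(-1).
target = [s for s in range(batch) for _ in range(views - 1)]

# Row k of simi compares aug_rows[k] with every normal row; target[k] is the
# column whose normal row comes from the same source sample.
for k, (sample, view) in enumerate(aug_rows):
    assert nor_rows[target[k]][0] == sample
```

With these sizes, `simi` in the quoted loop would have shape 9x3 (in general 3*batch x batch), and the cross-entropy loss pushes each transformed row toward the column of its own source image.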

xuguodong03, Aug 06 '20 06:08

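Hello, my confusion is mainly about where this 4 comes from. PyTorch applies augmentation online, so a sample and its transformation should not appear together within one epoch, right? And even if they do, why is it guaranteed to be one original plus three augmented samples? Where does this 1+3 come from?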

larry10hhobh, Aug 07 '20 06:08

The code does not use torchvision.datasets.CIFAR100; the dataset class was modified instead. See cifar.py.
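To make the idea concrete, here is a hedged sketch of a dataset whose `__getitem__` returns one original plus three transformed views. It is not the contents of cifar.py: the class name `FourViewDataset`, the helper `rot90`, and the choice of 90-degree rotations as the three transforms are all assumptions for illustration, and plain nested lists stand in for image tensors.

```python
def rot90(img):
    """Rotate a 2D list-of-lists image 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]


class FourViewDataset:
    """Hypothetical dataset: each item is [original, 3 transformed views]."""

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        views = [self.images[index]]
        # Assumed transforms: successive 90-degree rotations of the original.
        for _ in range(3):
            views.append(rot90(views[-1]))
        return views, self.labels[index]


ds = FourViewDataset([[[1, 2], [3, 4]]], [7])
views, label = ds[0]
print(len(views), label)
```

If each item were instead a stacked 4xCxHxW tensor, a standard DataLoader would collate items into the Bx4xCxHxW batch described earlier in the thread, which is why the 1+3 grouping is fixed by the dataset rather than by online augmentation.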

xuguodong03, Aug 07 '20 09:08

OK, thanks for the explanation!

larry10hhobh, Aug 07 '20 14:08