triplet-loss-pytorch icon indicating copy to clipboard operation
triplet-loss-pytorch copied to clipboard

你好,请问可以解释一下这一部分的代码吗?

Open KrystalCWT opened this issue 4 years ago • 3 comments

你好,请问可以解释一下这一部分的代码吗?没看懂你triplet loss是怎么计算的。 temp_x = [torch.stack(input[i], dim=0) for i in range(len(input))] temp_y = [torch.stack(target[i], dim=0) for i in range(len(target))] new_x = torch.stack(temp_x, dim=0) new_y = torch.stack(temp_y, dim=0)

    new_x = [new_x[:, i] for i in range(3)]
    new_y = [new_y[:, i] for i in range(3)]
    sample_input = torch.cat(new_x, 0)
    sample_target = torch.cat(new_y, 0)
    # print (sample_target)
    # print (sample_target[:batch_size])
    # print (sample_target[batch_size:(batch_size * 2)])
    # print (sample_target[-batch_size:])
    target = sample_target.cuda(async=True)
    input_var = torch.autograd.Variable(sample_input.cuda())
    target_var = torch.autograd.Variable(target.cuda())
    # compute output
    output = model(input_var)
    anchor = output[:temp_batch_size]
    positive = output[temp_batch_size:(temp_batch_size * 2)]
    negative = output[-temp_batch_size:]

KrystalCWT avatar Dec 09 '20 07:12 KrystalCWT

你可以把每一步的数据size打印出来看一下,大概就明白了,主要就是把dataloader返回的结果拼装成loss能处理的模式

chencodeX avatar Dec 11 '20 06:12 chencodeX

你好 我也有一个疑问 这个anchor positive negative 的获取方式为什么是这样的

c464851257 avatar Aug 16 '22 07:08 c464851257

因为我把三组数据拼接成一个batch,然后做forward

chencodeX avatar Sep 29 '22 02:09 chencodeX