About sim Loss

Open • muyuuuu opened this issue 3 years ago • 0 comments

Hello, thanks for your code. I'm also preparing to use self-supervised learning (SSL) to improve robustness, and I have a few questions:

https://github.com/Kim-Minseon/RoCL/blob/b6d5185e294e8bca5670e9146df81b54d99d0635/src/loss.py#L59

Is this loss function from the SimCLR paper? I have also decided to use it.
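For comparison, the SimCLR paper calls its objective NT-Xent (normalized temperature-scaled cross entropy). For each of the 2N augmented samples in a batch, it divides the cosine similarities by a temperature τ and applies a softmax cross entropy whose target is the other view of the same image. Below is a compact PyTorch sketch of that definition. `nt_xent` and the default `temperature=0.5` are illustrative choices. This is not code taken from RoCL or from the SimCLR release.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss for a batch of positive pairs (z1[k], z2[k]).

    z1, z2: (N, D) projection-head outputs for two augmented views.
    Every other sample in the 2N batch acts as a negative.
    """
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit L2 norm per row
    sim = z @ z.T / temperature                          # temperature-scaled cosine similarities
    # Exclude each sample's similarity with itself from the softmax denominator
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is its other view: i + N in the first half, i - N in the second
    targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)
    return F.cross_entropy(sim, targets)                 # mean over all 2N anchors


if __name__ == "__main__":
    torch.manual_seed(0)
    z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
    print(nt_xent(z1, z2).item())
```

Masking the diagonal with `-inf` removes the self-similarity term from the softmax without reshaping the matrix. Only the positive column is then selected through `targets`.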

https://github.com/Kim-Minseon/RoCL/blob/b6d5185e294e8bca5670e9146df81b54d99d0635/src/rocl_train.py#L154-L155

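However, I found that it converges very slowly and its value is much larger than the classification loss. In the log below, the sim loss stays at about 4.03 to 4.04, while the clean and attack classification losses are below 0.9. Have you encountered this problem?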
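Part of the gap may be arithmetic rather than a training problem. In a contrastive batch, each of the 2N anchors is classified against 2N - 1 candidates: one positive and 2N - 2 negatives. When the similarities carry no information, the loss is therefore about log(2N - 1), not log C as it would be for a C-class classifier.

There is also a floor. If the logits are raw cosine similarities in [-1, 1] (τ = 1), the loss is bounded below by log(1 + (2N - 2)·e^(-2/τ)). So even for perfect embeddings it cannot approach 0.

The issue does not state the batch size. The sketch below therefore evaluates both quantities for a few hypothetical sizes. `chance_level_loss` and `loss_floor` are illustrative helper names.

```python
import math


def chance_level_loss(batch_size: int) -> float:
    """Cross entropy when all 2N - 1 candidate logits are equal."""
    return math.log(2 * batch_size - 1)


def loss_floor(batch_size: int, temperature: float = 1.0) -> float:
    """Lower bound on the loss when logits are cosine similarities / temperature.

    Assumes the best case: positive similarity 1 and every negative -1, so each
    of the 2N - 2 negatives contributes exp((-1 - 1) / temperature) relative to
    the positive. The bound is not tight, because many negatives cannot all
    reach -1 at the same time.
    """
    return math.log1p((2 * batch_size - 2) * math.exp(-2.0 / temperature))


if __name__ == "__main__":
    for n in (32, 64, 128, 256):  # hypothetical batch sizes
        print(f"N={n:4d}  chance={chance_level_loss(n):.3f}  "
              f"floor(tau=1)={loss_floor(n):.3f}  "
              f"floor(tau=0.1)={loss_floor(n, 0.1):.6f}")
```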

Here is my implementation and training log.

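```python
import torch


def InfoNCE(x1, x2, device):
    # x1, x2: (bs, dim) embeddings of two augmented views of the same batch
    bs = x1.shape[0]
    feature = torch.cat((x1, x2), dim=0)                      # (2*bs, dim)
    feature = torch.nn.functional.normalize(feature, dim=1)   # unit L2 norm per row

    # Cosine similarity between every pair of the 2*bs embeddings
    similarity_matrix = torch.matmul(feature, feature.T)      # (2*bs, 2*bs)

    # labels[i, j] = 1 when rows i and j are views of the same image
    labels = torch.cat([torch.arange(bs, device=feature.device) for _ in range(2)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

    # Drop the diagonal (each embedding compared with itself)
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=feature.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # One positive and 2*bs - 2 negatives per row
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # Positive pair in column 0, so the target class is 0 for every row.
    # Note: no temperature scaling is applied to these logits.
    logits = torch.cat((positives, negatives), dim=1).to(device)
    labels = torch.zeros(logits.size()[0], dtype=torch.long).to(device)

    return logits, labels
```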
```
Model: Wed Dec 29 01:10:13 2021 : clean class loss is 0.031855, attack class loss is 0.889221, sim loss is 4.037420.
Model: Wed Dec 29 01:10:48 2021 : clean class loss is 0.034873, attack class loss is 0.386264, sim loss is 4.034396.
Model: Wed Dec 29 01:11:23 2021 : clean class loss is 0.034596, attack class loss is 0.376193, sim loss is 4.035690.
```
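The `InfoNCE` function in this issue returns logits and targets rather than a loss, and it passes raw cosine similarities straight through. The remaining step could look like the minimal sketch below: divide by a temperature, as SimCLR does, then apply cross entropy. It assumes the logits keep that layout, with the positive in column 0. `info_nce_loss_from_logits` is a hypothetical helper, and the temperature values are only illustrative. A good setting depends on batch size and data.

```python
import torch
import torch.nn.functional as F


def info_nce_loss_from_logits(logits: torch.Tensor, labels: torch.Tensor,
                              temperature: float = 0.5) -> torch.Tensor:
    """Cross entropy over temperature-scaled logits.

    logits: (2N, 2N - 1) cosine similarities, positive pair in column 0.
    labels: (2N,) zeros, i.e. the index of the positive column.
    """
    return F.cross_entropy(logits / temperature, labels)


if __name__ == "__main__":
    n = 64  # hypothetical batch size
    # Idealized logits: positive similarity 1.0, every negative 0.0 (orthogonal)
    logits = torch.zeros(2 * n, 2 * n - 1)
    logits[:, 0] = 1.0
    labels = torch.zeros(2 * n, dtype=torch.long)
    for t in (1.0, 0.5, 0.1):
        print(t, info_nce_loss_from_logits(logits, labels, t).item())
```

With these idealized logits, the printed loss at τ = 1 remains well above zero and shrinks as τ decreases. Dividing before the cross entropy is what lets the softmax sharpen around the positive.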

muyuuuu, Dec 29 '21 03:12