
Implementation logic for generalized end-to-end (GE2E) loss


Hi, I recently came across the generalized end-to-end (GE2E) loss from a 2018 paper, which computes the loss from the similarity matrix between each embedding vector and all class centroids. I wrote a small Python snippet with the whole GE2E logic below, but I'm not sure whether it's correct. I was also wondering whether this could be implemented on top of BaseMetricLossFunction.
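
For reference, the softmax variant in the paper scores each embedding e_ji against every centroid c_k with a scaled cosine similarity and takes a softmax over centroids (my notation follows the paper):

S_{ji,k} = w * cos(e_ji, c_k) + b,   with w > 0
L(e_ji)  = -S_{ji,j} + log( sum_k exp(S_{ji,k}) )

One detail I may be missing: for the true class k = j, the paper computes the centroid excluding e_ji itself, whereas my snippet below includes it.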

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 4
num_classes = 3
embed_dim = 4

x = torch.randn(batch_size, embed_dim)
# labels deliberately don't start at 0, to exercise the remapping below
y = torch.randint(3, num_classes + 3, (batch_size,))

# learnable scale and bias from the GE2E paper (w is kept positive during training)
w = nn.Parameter(torch.tensor(10.0))
b = nn.Parameter(torch.tensor(-5.0))

classes = torch.unique(y)  # sorted unique labels present in the batch
num_classes = len(classes)

# group sample indices by class, then reorder embeddings/labels class by class
indices = [y.eq(c).nonzero().squeeze(1) for c in classes]
x_ordered = torch.cat([x[idx] for idx in indices], dim=0)
y_ordered = torch.cat([y[idx] for idx in indices], dim=0)
# remap raw labels to 0..num_classes-1 (classes is sorted, so this is an exact lookup)
y_ordered = torch.searchsorted(classes, y_ordered)

# per-class centroids (mean of each class's embeddings)
prototypes = torch.stack([x[idx].mean(0) for idx in indices])

# cosine similarity between every embedding and every centroid, then affine scaling
cossim_matrix = F.normalize(x_ordered, p=2, dim=1) @ F.normalize(prototypes, p=2, dim=1).T
cossim_matrix = cossim_matrix * w + b
loss = F.cross_entropy(cossim_matrix, y_ordered)
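
If it helps the discussion, here's a rough, untested sketch of how I imagined this sitting inside a BaseMetricLossFunction subclass. SoftmaxGE2ELoss is just a name I made up, and I'm assuming the recent compute_loss signature that also receives ref_emb/ref_labels (older versions only pass indices_tuple); the return dict follows the "custom losses" pattern from the docs.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning.losses import BaseMetricLossFunction


class SoftmaxGE2ELoss(BaseMetricLossFunction):
    """Softmax variant of GE2E, sketched on top of BaseMetricLossFunction."""

    def __init__(self, init_w=10.0, init_b=-5.0, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        # remap labels to 0..K-1 and build one centroid per class
        classes, y = torch.unique(labels, return_inverse=True)
        prototypes = torch.stack(
            [embeddings[y == k].mean(0) for k in range(len(classes))]
        )
        # scaled cosine similarity between each embedding and each centroid
        sims = F.normalize(embeddings, dim=1) @ F.normalize(prototypes, dim=1).T
        sims = sims * torch.clamp(self.w, min=1e-6) + self.b  # paper keeps w positive
        loss = F.cross_entropy(sims, y)
        return {
            "loss": {
                "losses": loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

It should then be callable like any other loss in the library, e.g. loss = SoftmaxGE2ELoss()(embeddings, labels).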

penguinwang96825 · Mar 24 '23