
Just a loss function for discussion

Open 072jiajia opened this issue 2 years ago • 3 comments

Hi! Thanks for sharing the great work! I have some questions I'd like to discuss with the authors and anyone interested in this work. In brief: do you think the following loss function could make training more stable, or let the model use the joint multi-modal embedding space more efficiently?

Numpy-like pseudocode

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# W_std_i[d_i, ] - learned proj of image feature to log standard deviation
# W_std_t[d_t, ] - learned proj of text feature to log standard deviation

# extract feature representations of each modality
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)  # [n, d_t]
# joint multimodal embedding [n, d_e]
I_e = np.dot(I_f, W_i)
T_e = np.dot(T_f, W_t)
# log standard deviation in the joint multimodal embedding space
I_ln_std = np.expand_dims(np.dot(I_f, W_std_i), axis=1) # [n, 1]
T_ln_std = np.expand_dims(np.dot(T_f, W_std_t), axis=0) # [1, n]
# variance in joint multimodal embedding space
I_var = np.exp(I_ln_std)**2 # [n, 1]
T_var = np.exp(T_ln_std)**2 # [1, n]
# pairwise squared distances in the joint embedding space [n, n]
distances_sq = np.sum((np.expand_dims(I_e, axis=1) - np.expand_dims(T_e, axis=0)) ** 2, axis=2)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(-I_ln_std - distances_sq / I_var, labels, axis=0)
loss_t = cross_entropy_loss(-T_ln_std - distances_sq / T_var, labels, axis=1)
loss = (loss_i + loss_t)/2

PyTorch-like pseudocode

import torch
import torch.nn as nn

def my_loss(I_f, I_ln_std, T_f, T_ln_std):
    '''
        I_f:      [n, d_e]  image embeddings in the joint space
        T_f:      [n, d_e]  text embeddings in the joint space
        I_ln_std: [n, ]     per-image log standard deviation
        T_ln_std: [n, ]     per-text log standard deviation
    '''
    n = I_f.size(0)
    ce_loss = nn.CrossEntropyLoss()

    # I_f: [n, 1, d_e]
    I_f = I_f.unsqueeze(1)
    # T_f: [1, n, d_e]
    T_f = T_f.unsqueeze(0)
    # [n, n]
    distances = torch.sum((I_f - T_f) ** 2, dim=2)

    # [n, ]
    I_var = I_ln_std.exp() ** 2
    T_var = T_ln_std.exp() ** 2

    # reshape to [1, n] so each broadcasts over the class (column) dimension
    T_ln_std = T_ln_std.unsqueeze(0)
    T_var = T_var.unsqueeze(0)
    I_ln_std = I_ln_std.unsqueeze(0)
    I_var = I_var.unsqueeze(0)

    label = torch.arange(n, device=I_f.device)
    loss_t = ce_loss(-T_ln_std - distances / T_var, label)
    loss_i = ce_loss(-I_ln_std - distances.t() / I_var, label)

    return (loss_t + loss_i) / 2


I_f, I_ln_std = image_encoder(image)
T_f, T_ln_std = text_encoder(text)

loss = my_loss(I_f, I_ln_std, T_f, T_ln_std)
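
For completeness, a minimal self-contained shape check with random tensors (the tensors below are stand-ins for the encoder outputs, not real features):

import torch

n, d_e = 8, 512
I_f = torch.randn(n, d_e)   # stand-in image embeddings in the joint space
T_f = torch.randn(n, d_e)   # stand-in text embeddings in the joint space
I_ln_std = torch.zeros(n)   # log std = 0  ->  std = 1
T_ln_std = torch.zeros(n)

loss = my_loss(I_f, I_ln_std, T_f, T_ln_std)
print(loss)  # scalar tensor, differentiable with respect to all four inputs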

I'll explain it in an example and in math in the following comments.

072jiajia · Oct 24 '22 07:10

Consider the following two sentences: "A dog in the garden" and "A shiba dog in the garden". The second sentence is clearly more specific (it imposes more conditions), so its predicted variance should be smaller, while the first sentence's variance should be larger. By the same reasoning, an image that carries more specific information should be assigned a lower output variance.
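
To connect this intuition with the pseudocode above: up to constant factors on the two terms, the logit for an (image i, text t) pair can be read as the log-density of an isotropic Gaussian centered at the text embedding (a sketch of the intuition only; the exact constants differ slightly from the code):

$$\log \mathcal{N}\big(I_e^{(i)};\ T_e^{(t)},\ \sigma_t^2 I\big) = \text{const} - d_e \log \sigma_t - \frac{\lVert I_e^{(i)} - T_e^{(t)} \rVert^2}{2\sigma_t^2}$$

A text with a small $\sigma_t$ ("A shiba dog in the garden") scores high when the image embedding lands close to it but is penalized sharply when it does not, while a vaguer text ("A dog in the garden") keeps a flatter, larger-variance distribution over the embedding space.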

072jiajia · Oct 24 '22 07:10

From a probabilistic point of view, you may view the explanation here: https://hackmd.io/@jS3Mpow5SSOX1M0cbyflTw/Hy9w75-Qs

072jiajia · Oct 24 '22 08:10

Thank you for reading this far. Please tell me if there are any mistakes in my reasoning. Since I don't have enough GPUs or the corresponding computing equipment to run the experiment, if you also think the idea makes sense, please try it. I'm looking forward to finding out whether this method actually works.

072jiajia · Oct 24 '22 08:10