CLIP
Just a loss function for discussion
Hi! Thanks for sharing the great work! I have some questions that I would like to discuss with the authors and anyone interested in this work. In brief, do you think the following loss function could make the training process more stable, or help the model use the joint multi-modal embedding space more efficiently?
Numpy-like pseudocode
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# W_std_i[d_i, ] - learned proj of image feature to log standard deviation
# W_std_t[d_t, ] - learned proj of text feature to log standard deviation
# extract feature representations of each modality
I_f = image_encoder(I) # [n, d_i]
T_f = text_encoder(T) # [n, d_t]
# joint multimodal embedding [n, d_e]
I_e = np.dot(I_f, W_i)
T_e = np.dot(T_f, W_t)
# log standard deviation in joint multimodal embedding space
I_ln_std = np.expand_dims(np.dot(I_f, W_std_i), axis=1) # [n, 1]
T_ln_std = np.expand_dims(np.dot(T_f, W_std_t), axis=0) # [1, n]
# variance in joint multimodal embedding space
I_var = np.exp(I_ln_std)**2 # [n, 1]
T_var = np.exp(T_ln_std)**2 # [1, n]
# pairwise squared distances in the joint embedding space [n, n]
distances = np.sum((np.expand_dims(I_e, axis=1) - np.expand_dims(T_e, axis=0)) ** 2, axis=2)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(-I_ln_std - distances / I_var, labels, axis=0)
loss_t = cross_entropy_loss(-T_ln_std - distances / T_var, labels, axis=1)
loss = (loss_i + loss_t)/2
PyTorch-like pseudocode
import torch
import torch.nn as nn

def my_loss(I_f, I_ln_std, T_f, T_ln_std):
    '''
    I_f: [n, d_e]       image embeddings in the joint space
    T_f: [n, d_e]       text embeddings in the joint space
    I_ln_std: [n, ]     log standard deviation per image
    T_ln_std: [n, ]     log standard deviation per text
    '''
    n = I_f.size(0)
    ce_loss = nn.CrossEntropyLoss()
    # I_f: [n, 1, d_e]
    I_f = I_f.unsqueeze(1)
    # T_f: [1, n, d_e]
    T_f = T_f.unsqueeze(0)
    # pairwise squared distances [n, n]
    distances = torch.sum((I_f - T_f) ** 2, dim=2)
    # variances [n, ]
    I_var = I_ln_std.exp() ** 2
    T_var = T_ln_std.exp() ** 2
    # broadcast the per-sample terms to [1, n]
    T_ln_std = T_ln_std.unsqueeze(0)
    T_var = T_var.unsqueeze(0)
    I_ln_std = I_ln_std.unsqueeze(0)
    I_var = I_var.unsqueeze(0)
    label = torch.arange(n, device=I_f.device)
    loss_t = ce_loss(-T_ln_std - distances / T_var, label)
    loss_i = ce_loss(-I_ln_std - distances.t() / I_var, label)
    return (loss_t + loss_i) / 2
I_f, I_ln_std = image_encoder(image)
T_f, T_ln_std = text_encoder(text)
loss = my_loss(I_f, I_ln_std, T_f, T_ln_std)
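As a quick sanity check (not a real experiment), the function above can also be run on random tensors standing in for encoder outputs; n = 8 and d_e = 512 below are arbitrary placeholder sizes I chose for illustration.

# Sanity check with random tensors in place of real encoder outputs (placeholder sizes).
n, d_e = 8, 512
I_f = torch.randn(n, d_e)
T_f = torch.randn(n, d_e)
I_ln_std = torch.zeros(n)   # log std = 0, i.e. unit variance for every image
T_ln_std = torch.zeros(n)   # log std = 0, i.e. unit variance for every text
print(my_loss(I_f, I_ln_std, T_f, T_ln_std))   # prints a single scalar loss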
I'll explain it in an example and in math in the following comments.
Consider the two sentences "A dog in the garden" and "A shiba dog in the garden": the second sentence imposes more conditions. We may say that the first sentence's embedding should have a larger variance and the second sentence's a smaller one. More generally, an input (image or text) that carries more specific information should be mapped to an embedding with lower variance.
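Here is a small numerical illustration of that intuition (the distances are toy numbers I made up, not from any trained model): with the same squared distances to three candidate embeddings, a smaller learned standard deviation makes the softmax over candidates sharper, while a larger one flattens it.

# Toy illustration (made-up numbers): effect of the learned std on softmax sharpness.
import torch

distances = torch.tensor([0.1, 1.0, 2.0])      # squared distances to three candidates

for ln_std in (-1.0, 0.0, 1.0):                # small std (specific input) ... large std (vague input)
    var = torch.tensor(ln_std).exp() ** 2
    logits = -ln_std - distances / var         # same form as the logits in the loss above
    probs = torch.softmax(logits, dim=0)
    print(f"ln_std={ln_std:+.1f}  probs={probs.numpy().round(3)}")

With ln_std = -1.0 almost all the probability mass lands on the nearest candidate, while with ln_std = +1.0 the distribution over the three candidates becomes much flatter.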
From a probabilistic point of view, you can find my derivation here: https://hackmd.io/@jS3Mpow5SSOX1M0cbyflTw/Hy9w75-Qs
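As a rough sketch of that probabilistic reading (the linked note is the actual reference; the 1/2 factors and the embedding dimension d_e are dropped in the pseudocode above), each logit resembles an isotropic Gaussian log-likelihood, where \mathbf{I} is the d_e-dimensional identity:

\text{logit}_{ij} = -\log\sigma_j - \frac{\lVert I_i - T_j \rVert^2}{\sigma_j^2},
\qquad
\log\mathcal{N}\!\left(I_i;\, T_j,\, \sigma_j^2 \mathbf{I}\right)
  = -\,d_e \log\sigma_j - \frac{\lVert I_i - T_j \rVert^2}{2\sigma_j^2} - \frac{d_e}{2}\log 2\pi

So, up to those constants, each term of the cross-entropy asks which Gaussian the other modality's embedding most plausibly came from, with wider Gaussians assigned to vaguer inputs.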
Thank you for reading this far. Please tell me if my reasoning has any mistakes. Since I don't have enough GPUs and corresponding equipment to run the experiment, if you also think the idea makes sense, please try it. I'm looking forward to knowing whether this method actually works.