gritlm
How to train using hard negatives only?
I understand that GritLM fine-tuning uses both in-batch negatives and hard negatives for contrastive learning. We can use in-batch negatives only by setting the train group size to 1.
However, in my case, I can only use hard negatives, not in-batch negatives. Is there a way to disable in-batch negatives? If not, could you kindly advise which part of the code I should modify to implement this myself?
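To make this concrete, here is a rough sketch of the two scoring schemes I have in mind (the batch size, group size, and hidden size are just illustrative numbers, and the target layout mirrors the usual "positive first in each group" convention):

    import torch

    batch_size, train_group_size, hidden = 4, 2, 8            # each group = 1 positive + 1 hard negative
    q = torch.randn(batch_size, hidden)                        # query embeddings
    p = torch.randn(batch_size * train_group_size, hidden)     # passage embeddings, positive first in each group

    # Standard setup: each query is scored against every passage in the batch,
    # so other queries' passages act as in-batch negatives on top of the hard negatives.
    scores = q @ p.T                                           # [4, 8]
    target = torch.arange(batch_size) * train_group_size       # index of each query's own positive

    # What I want instead: score each query only against its own group of passages.
    p_grouped = p.view(batch_size, train_group_size, hidden)   # [4, 2, 8]
    scores_hn = torch.einsum('bh,bgh->bg', q, p_grouped)       # [4, 2]
    target_hn = torch.zeros(batch_size, dtype=torch.long)      # positive is index 0 in every group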
I think you'd need to modify https://github.com/ContextualAI/gritlm/blob/724df95c0f760249d3581f82cd7ca7f9ad5191c0/gritlm/training/model.py#L41 to only gather hard negatives and only use them for the loss. Would be great if you could share your code changes!
Thanks for the instructions! I’ve made the proposed code changes, which involve avoiding data gathering from other GPUs and reshaping p_reps and q_reps to ensure that the dot products are computed between each query and its corresponding passages. Let me know your thoughts on these changes.
Additionally, I was wondering about the function compute_similarity. Under what circumstances would len(p_reps.size()) not equal 2?
import torch
import torch.distributed as dist
from typing import Optional


class DistributedContrastiveLoss:
    def __init__(self, temperature: float, negatives_cross_device: bool, hard_negatives_only: bool):
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
        self.temperature = temperature
        self.hard_negatives_only = hard_negatives_only
        # Do not gather other GPUs' batches when using hard negatives only
        self.negatives_cross_device = False if self.hard_negatives_only else negatives_cross_device
        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise ValueError('Cannot do negatives_cross_device without distributed training')
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def __call__(self, q_reps, p_reps):
        """
        q_reps: [batch_size, hidden_size]                        # query embeddings
        p_reps: [batch_size * (num_negatives + 1), hidden_size]  # passage embeddings; num_negatives + 1 = train_group_size
        """
        if self.negatives_cross_device:
            # This gathers both negatives and positives.
            # It could likely be optimized by only gathering negatives.
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)
        if self.hard_negatives_only:
            # Reshape `p_reps` to group the passages belonging to each query
            p_reps = p_reps.view(q_reps.size(0), p_reps.size(0) // q_reps.size(0), -1)  # [batch_size, num_negatives + 1, hidden_size]
        scores = self.compute_similarity(q_reps, p_reps) / self.temperature
        scores = scores.view(q_reps.size(0), -1)
        if self.hard_negatives_only:
            # Target is always 0, since the first passage in each group is the positive
            target = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
        else:
            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target *= (p_reps.size(0) // q_reps.size(0))
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        # All tensors have the same shape, since pooling has already been applied to them
        dist.all_gather(all_tensors, t)
        all_tensors[self.rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        if self.hard_negatives_only:
            # Query embeddings:   [batch_size, hidden_size] -unsqueeze-> [batch_size, 1, hidden_size]
            # Passage embeddings: [batch_size, num_negatives + 1, hidden_size] -transpose-> [batch_size, hidden_size, num_negatives + 1]
            # Result:             [batch_size, 1, num_negatives + 1] -squeeze-> [batch_size, num_negatives + 1]
            return torch.matmul(q_reps.unsqueeze(1), p_reps.transpose(-2, -1)).squeeze(1)
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))
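For reference, here is the quick sanity check I used for the hard-negatives-only path; the batch size, group size, and hidden size are arbitrary placeholders:

    import torch

    batch_size, train_group_size, hidden_size = 4, 8, 16   # 1 positive + 7 hard negatives per query

    loss_fn = DistributedContrastiveLoss(
        temperature=0.02,
        negatives_cross_device=False,
        hard_negatives_only=True,
    )

    q_reps = torch.randn(batch_size, hidden_size)                       # [4, 16]
    p_reps = torch.randn(batch_size * train_group_size, hidden_size)    # [32, 16]

    loss = loss_fn(q_reps, p_reps)   # scores come out as [4, 8], targets are all zeros
    print(loss)                      # scalar tensor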
Hi @Muennighoff,
Just checking in on my previous message regarding the code changes and the compute_similarity function. Would appreciate your thoughts when you get a chance.
Let me know if you need any clarifications!
Maybe try training and check if it works?