
[FEATURE] Add Label Predictions for `Sentence Transformers` within `ArgillaTrainer`

Open · kursathalat opened this issue 1 year ago · 2 comments

Is your feature request related to a problem? Please describe.

For the Sentence Transformers framework within the `ArgillaTrainer`, `predict()` only returns sentence-similarity scores computed with cosine similarity. For an NLI task, however, each sentence pair must be assigned one of a fixed set of labels, as in a text-classification (textcat) task. Similarity also loses the direction of the relationship. An NLI pair consists of a premise and a hypothesis, and the task is to decide whether the hypothesis can be inferred from the premise. Cosine similarity is symmetric, so swapping the two sentences gives the same score, and there is no way to tell which sentence is the premise and which is the hypothesis.
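To illustrate the symmetry problem, here is a minimal pure-Python sketch. It compares cosine similarity with the `(u, v, |u - v|)` pair representation that, as far as I know, `SoftmaxLoss` feeds to its classifier by default. The embeddings are made-up toy vectors, and `pair_features` is a hypothetical helper written for this example, not a library function.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def pair_features(u, v):
    """Hypothetical helper: concatenate (u, v, |u - v|), assumed to mirror
    the default sentence-pair representation used by SoftmaxLoss."""
    return list(u) + list(v) + [abs(x - y) for x, y in zip(u, v)]


# Toy embeddings standing in for a premise and a hypothesis (made-up values)
premise = [0.2, 0.9, 0.1]
hypothesis = [0.7, 0.3, 0.5]

# Cosine similarity ignores order: swapping the sentences gives the same score
print(cosine_similarity(premise, hypothesis) == cosine_similarity(hypothesis, premise))  # True

# The concatenated features depend on order, so a classifier on top of them
# can treat the premise and the hypothesis differently
print(pair_features(premise, hypothesis) == pair_features(hypothesis, premise))  # False
```

Because a linear classifier sees `u` and `v` in fixed positions, it can in principle learn asymmetric relations such as entailment, which a single similarity score cannot express.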

Describe the solution you'd like

Add a method or parameter that predicts labels for NLI data by assigning each sentence pair to one of the categories entailment, contradiction, or neutral. Here is an example:

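```python
import torch
from sentence_transformers.losses import SoftmaxLoss
from sentence_transformers.util import batch_to_device

# `model` is a SentenceTransformer; `dataloader` yields (features, label_ids)
# batches, e.g. a DataLoader with collate_fn=model.smart_batching_collate.
# The classifier head must be the one trained together with the model.
loss_model = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,  # entailment, contradiction, neutral
)
loss_model.to(model.device).eval()

predictions = []
for features, _ in dataloader:
    # Move each tokenized sentence batch (premise, hypothesis) to the model's device
    features = [batch_to_device(f, model.device) for f in features]
    with torch.no_grad():
        # With labels=None, SoftmaxLoss returns (embeddings, logits) instead of a loss
        _, logits = loss_model(features, labels=None)
    predictions.append(torch.argmax(logits, dim=1))

model_prediction = torch.cat(predictions)
```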
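The class ids returned by `argmax` still have to be mapped to category names to give textcat-style output. A minimal pure-Python sketch of that last step follows. `ID2LABEL` and `logits_to_labels` are hypothetical names, and the label order is an assumption: it must match the label ids used when the classifier was trained.

```python
# Hypothetical label order; it must match the label ids used during training
ID2LABEL = {0: "contradiction", 1: "entailment", 2: "neutral"}


def logits_to_labels(logits, id2label=ID2LABEL):
    """Pick the highest-scoring class for each row (the same idea as
    torch.argmax(logits, dim=1)) and map each class id to its label name."""
    labels = []
    for row in logits:
        best_id = max(range(len(row)), key=lambda i: row[i])
        labels.append(id2label[best_id])
    return labels


# Made-up logits for two premise/hypothesis pairs
batch_logits = [[2.1, -0.3, 0.4], [-1.0, 0.2, 1.5]]
print(logits_to_labels(batch_logits))  # ['contradiction', 'neutral']
```

Keeping the mapping as a parameter avoids hard-coding a label order that may differ between datasets.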

Describe alternatives you've considered

Additional context

kursathalat, Oct 09 '23 14:10

@kursathalat thanks for this issue! Once we wrap up the previous PR, we can work on this new feature.

davidberenstein1957, Oct 09 '23 15:10

This issue is stale because it has been open for 90 days with no activity.

github-actions[bot], Jan 30 '24 01:01

This issue was closed because it has been inactive for 30 days since being marked as stale.

github-actions[bot], May 06 '24 01:05