SimCTG BART training

Open rahulseetharaman opened this issue 1 year ago • 0 comments

Hi @yxuansu, thanks for the wonderful library.

I am trying to use the SimCTG framework to train a BART model for a question generation task. When I train the BART model with the SimCTG loss, I get the following error:

  File "experiment-5/simctg_train.py", line 224, in <module>
    train()
  File "experiment-5/simctg_train.py", line 122, in train
    mle_loss, cl_loss = simctgloss(last_hidden_states=last_hidden_states, logits=logits,
  File "/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "simctg/lossfunction.py", line 93, in forward
    assert labels.size() == input_ids.size()
AssertionError

Looking at the loss function, I can see why this happens: it asserts that `labels` and `input_ids` have the same size. Is the loss function designed to support only decoder-only models such as GPT? How can it be adapted for BART and T5? For encoder-decoder models like these, `input_ids` holds the source sequence and `labels` holds the target sequence, so their dimensions need not match.
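To make the mismatch concrete, here is a minimal sketch in plain Python (no model involved). The token ids, `PAD_ID`, `DECODER_START_ID`, and the helpers `shape` and `shift_right` are hypothetical names chosen for illustration; they are not part of SimCTG. It shows that the encoder `input_ids` and the `labels` have different shapes, while decoder inputs built by shifting the labels right have the same shape as the labels. Passing such decoder inputs in place of `input_ids` is only an assumption about one possible adaptation, not something confirmed by the library.

```python
# Illustrative sketch only: all ids and helper names below are hypothetical.
PAD_ID = 1            # assumed padding token id
DECODER_START_ID = 2  # assumed decoder start token id
IGNORE_INDEX = -100   # label value commonly used to exclude positions from the loss


def shape(batch):
    """Return (batch_size, seq_len) for a rectangular list of lists."""
    return (len(batch), len(batch[0]))


def shift_right(labels):
    """Build decoder inputs: prepend the start token, drop the last label,
    and replace ignored label positions with the padding id."""
    shifted = []
    for row in labels:
        new_row = [DECODER_START_ID] + row[:-1]
        shifted.append([PAD_ID if t == IGNORE_INDEX else t for t in new_row])
    return shifted


# Encoder input (e.g. a passage) and labels (e.g. a question) differ in length.
input_ids = [[0, 11, 12, 13, 14, 15, 2], [0, 21, 22, 23, 2, 1, 1]]
labels = [[0, 31, 32, 2], [0, 41, 2, IGNORE_INDEX]]

decoder_input_ids = shift_right(labels)

# The check from the traceback compares labels with the encoder input_ids.
print(shape(input_ids), shape(labels))
# Decoder inputs derived from the labels share the labels' shape.
print(shape(decoder_input_ids) == shape(labels))
```

If this reading is right, the decoder's last hidden states and logits would also have the target length, so the shape check would hold against the decoder inputs rather than the encoder inputs.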

rahulseetharaman, May 06 '23 16:05