SimCTG BART training
Hi @yxuansu, thanks for the wonderful library!
I am trying to use the SimCTG framework to train a BART model for a question generation task, and I run into the following error when training with the SimCTG loss:
```
  File "experiment-5/simctg_train.py", line 224, in <module>
    train()
  File "experiment-5/simctg_train.py", line 122, in train
    mle_loss, cl_loss = simctgloss(last_hidden_states=last_hidden_states, logits=logits,
  File "/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "simctg/lossfunction.py", line 93, in forward
    assert labels.size() == input_ids.size()
AssertionError
```
Looking at the loss function, I can see why this happens: it asserts that `labels` and `input_ids` have the same size. Is the loss function designed to support only decoder-only models such as GPT? How can it be adapted for BART and T5? For encoder-decoder models like BART and T5, the encoder `input_ids` and the decoder `labels` generally have different lengths, so this assertion need not hold.
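One possible adaptation, sketched here under my own assumptions: for an encoder-decoder model the contrastive term only needs the decoder's last hidden states plus decoder-side token IDs, and both share the target length with `labels`, so the encoder `input_ids` never enter it. `contrastive_token_loss` below is a hypothetical standalone reimplementation of a SimCTG-style token-level contrastive objective (margin hinge on cosine similarity between distinct tokens of the same sequence). It is not the library's `SimCTGLoss` API, and the padding mask and batch-wide averaging are assumptions.

```python
import torch
import torch.nn.functional as F


def contrastive_token_loss(
    hidden: torch.Tensor,     # (bsz, seqlen, dim) decoder last hidden states (assumed input)
    token_ids: torch.Tensor,  # (bsz, seqlen) decoder-side ids, e.g. labels (assumed input)
    pad_token_id: int,        # id marking padded target positions (e.g. -100 for labels)
    margin: float = 0.5,      # hinge margin rho
) -> torch.Tensor:
    """Hypothetical SimCTG-style contrastive loss over decoder tokens only."""
    # Only the decoder length has to match; encoder input_ids are not used at all.
    assert hidden.shape[:2] == token_ids.shape
    normed = F.normalize(hidden, dim=-1)
    # Pairwise cosine similarities within each sequence: (bsz, seqlen, seqlen).
    sim = torch.matmul(normed, normed.transpose(1, 2))
    # Self-similarity s(h_i, h_i) taken from the diagonal: (bsz, seqlen, 1).
    gold = torch.diagonal(sim, dim1=1, dim2=2).unsqueeze(-1)
    # Hinge: penalize tokens j whose similarity to i is within `margin` of s(h_i, h_i).
    loss = F.relu(margin - (gold - sim))
    # Keep only pairs (i, j) with i != j where neither position is padding.
    seqlen = token_ids.size(1)
    not_self = ~torch.eye(seqlen, dtype=torch.bool, device=hidden.device)
    valid = token_ids.ne(pad_token_id)
    pair_mask = (valid.unsqueeze(2) & valid.unsqueeze(1) & not_self).to(loss.dtype)
    # Average over all valid pairs in the batch (clamp avoids division by zero).
    return (loss * pair_mask).sum() / pair_mask.sum().clamp(min=1.0)
```

With Hugging Face `BartForConditionalGeneration`, calling the model with `labels=labels` and `output_hidden_states=True` exposes the decoder's last layer as `outputs.decoder_hidden_states[-1]`, whose length matches `labels`, and `outputs.loss` already holds the usual cross-entropy (MLE) term. Passing `labels` with `pad_token_id=-100` (the standard ignore index) would then mask padded targets; `T5ForConditionalGeneration` returns the same fields.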