MLM task is broken in the new update
To Reproduce
- Tell us which version of jiant you're using: 2.2.0
- Describe the environment where you're using jiant, e.g., "Macbook CPU"
Expected behavior
It should create a model and start training.
Screenshots
It throws an exception when it tries to create the MLM head.
Additional context
The issue happens in the following line: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/heads.py#L70
It happens because JiantMLMHeadFactory is called there with arguments, while its initializer doesn't accept any. A workaround is to instantiate the factory first by adding the following lines:
# The factory class must be instantiated before it can be called
if head_class == JiantMLMHeadFactory:
    head_class = head_class()
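For context, here is a minimal self-contained sketch of the dispatch pattern involved (every name except JiantMLMHeadFactory is a hypothetical stand-in for the real logic in heads.py): most registry entries are head classes that get instantiated directly, but JiantMLMHeadFactory produces callable factory instances, so the class has to be constructed before it can be called.

# Hypothetical stand-ins; only JiantMLMHeadFactory is a real jiant name.
class FakeMLMHead:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

class JiantMLMHeadFactory:
    registry = {"bert": FakeMLMHead}

    def __call__(self, task, model_arch, **kwargs):
        return self.registry[model_arch](**kwargs)

def build_head(head_class, **kwargs):
    if head_class == JiantMLMHeadFactory:
        head_class = head_class()  # the workaround: instantiate the factory
    return head_class(**kwargs)

head = build_head(JiantMLMHeadFactory, task=None, model_arch="bert", hidden_size=768)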
After this fix, the following line throws an exception: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/heads.py#L208
It can be fixed by changing it to:
def __call__(self, task, model_arch, **kwargs):
    """Summary

    Args:
        task (Task): Task used to initialize task head
        model_arch: Model architecture used to look up the MLM head class
        **kwargs: Additional arguments required to initialize task head
    """
    mlm_head_class = self.registry[model_arch]
    mlm_head = mlm_head_class(**kwargs)
    return mlm_head
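With that change, the factory would be invoked roughly like this (a sketch; the variables and keyword arguments below are illustrative assumptions, not the exact values jiant passes):

# Hypothetical usage of the patched factory; argument values are assumptions.
mlm_head_factory = JiantMLMHeadFactory()
mlm_head = mlm_head_factory(
    task=mlm_task,          # the MLM Task object (placeholder name)
    model_arch=model_arch,  # key into the factory's registry
    hidden_size=768,        # forwarded to the head class via **kwargs
    vocab_size=30522,
)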
Then, the next issue is with the following line: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/taskmodels.py#L283
It can be fixed by changing the line to:
input_ids=masked_batch.masked_input_ids,
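To see why feeding masked_input_ids (rather than the unmasked input_ids) matters, here is a runnable toy version of the MLM loss computation; the tensor names mirror the masked_batch fields above, but the shapes and the model output are stand-ins:

import torch
import torch.nn.functional as F

vocab_size, seq_len = 100, 8
input_ids = torch.randint(5, vocab_size, (1, seq_len))   # original token ids
masked_input_ids = input_ids.clone()
masked_lm_labels = torch.full_like(input_ids, -100)      # -100 = ignored by the loss
masked_input_ids[0, 3] = 4                               # pretend id 4 is [MASK]
masked_lm_labels[0, 3] = input_ids[0, 3]                 # label only at the masked slot

# The encoder must consume masked_input_ids; passing the original input_ids
# (the bug at taskmodels.py#L283) leaks the answers to the model.
logits = torch.randn(1, seq_len, vocab_size)             # stand-in for model(masked_input_ids)
loss = F.cross_entropy(logits.view(-1, vocab_size),
                       masked_lm_labels.view(-1),
                       ignore_index=-100)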
Thanks for the workaround! Additionally, the original line 201 should be further changed to:
def __call__(self, task, model_arch, hidden_dropout_prob, **kwargs):
to prevent hidden_dropout_prob from being passed via **kwargs, which leads to an error.
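Putting the two fixes together, the patched __call__ would look roughly like this (a sketch; it assumes the dropout value can simply be swallowed rather than forwarded):

def __call__(self, task, model_arch, hidden_dropout_prob, **kwargs):
    # Naming hidden_dropout_prob explicitly keeps it out of **kwargs, so it
    # is not forwarded to MLM head classes whose __init__ doesn't accept it.
    mlm_head_class = self.registry[model_arch]
    return mlm_head_class(**kwargs)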
Also, as of Transformers v4.5, transformers.models.bert.modeling_bert.BertLayerNorm and transformers.models.bert.modeling_bert.gelu are no longer available. My correction is to change the former to torch.nn.LayerNorm and the latter to torch.nn.functional.gelu.
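For reference, a minimal sketch of a BERT-style MLM transform after those substitutions (the structure follows the standard dense -> activation -> LayerNorm of BertPredictionHeadTransform; sizes are illustrative, not jiant's exact head):

import torch

class MLMTransformSketch(torch.nn.Module):
    def __init__(self, hidden_size=768, layer_norm_eps=1e-12):
        super().__init__()
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        # was transformers.models.bert.modeling_bert.BertLayerNorm
        self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, x):
        x = self.dense(x)
        x = torch.nn.functional.gelu(x)  # was transformers.models.bert.modeling_bert.gelu
        return self.layer_norm(x)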
Hope this helps future users!