
MLM task is broken in the new update

Open pooryapzm opened this issue 3 years ago • 1 comments

MLM task

To Reproduce

  1. Tell us which version of jiant you're using: 2.2.0
  2. Describe the environment where you're using jiant, e.g., "Macbook CPU"

Expected behavior It should create the model and start training.

Screenshots It throws the following exception when it tries to create the MLM head (see attached screenshot, Jun 09 2021).

Additional context The issue happens in the following line: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/heads.py#L70 It happens because `head_class` (here `JiantMLMHeadFactory`) is called with arguments, while the factory's initializer doesn't accept any arguments; they are meant for its `__call__` method instead. A workaround is to instantiate the factory first by adding the following lines:

        if head_class == JiantMLMHeadFactory:
            head_class = head_class()
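To make the failure mode concrete, here is a minimal, self-contained sketch (not jiant's actual code; the names `FakeMLMHeadFactory` and `create_head` are invented for illustration): calling the factory *class* with head-constructor arguments hits its zero-argument `__init__`, so the workaround instantiates the factory first and lets `__call__` receive the arguments.

```python
class FakeMLMHeadFactory:
    """Stand-in for JiantMLMHeadFactory: __init__ takes no arguments."""

    def __call__(self, task, **kwargs):
        # In the real library this dispatches to the right head class;
        # here we just return a tuple so the call flow is visible.
        return ("mlm_head", task, kwargs)


def create_head(head_class, task, **kwargs):
    # Without the guard, FakeMLMHeadFactory(task, ...) would call __init__
    # with arguments it does not accept and raise TypeError.
    if head_class == FakeMLMHeadFactory:
        head_class = head_class()  # the workaround: instantiate first
    return head_class(task, **kwargs)


head = create_head(FakeMLMHeadFactory, task="mlm", hidden_size=768)
print(head)  # ('mlm_head', 'mlm', {'hidden_size': 768})
```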

After this fix, the following line throws an exception: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/heads.py#L208 It can be fixed by changing it to:

    def __call__(self, task, model_arch, **kwargs):
        """Create an MLM head for the given model architecture.

        Args:
            task (Task): Task used to initialize the task head
            model_arch: Model architecture key used to look up the head class
            **kwargs: Additional arguments required to initialize the task head
        """
        mlm_head_class = self.registry[model_arch]
        mlm_head = mlm_head_class(**kwargs)
        return mlm_head

Then, the next issue is in the following line: https://github.com/nyu-mll/jiant/blob/310f22b515211f2e8fc8d229e196ce520557742d/jiant/proj/main/modeling/taskmodels.py#L283 It can be fixed by changing the line to: input_ids=masked_batch.masked_input_ids,
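A minimal sketch of why this one-line change matters (the `MaskedBatch` dataclass and `mlm_forward` helper below are hypothetical stand-ins, not jiant's real types): the encoder must see the *masked* ids, while the original ids only serve to build the labels.

```python
from dataclasses import dataclass


@dataclass
class MaskedBatch:
    input_ids: list          # original token ids (source of the labels)
    masked_input_ids: list   # ids after [MASK] substitution
    masked_lm_labels: list   # -100 everywhere except masked positions


def mlm_forward(encoder, masked_batch):
    # The fix: feed the masked ids to the encoder. Passing input_ids
    # instead would leak the answers, making the MLM objective trivial.
    return encoder(input_ids=masked_batch.masked_input_ids)


batch = MaskedBatch(
    input_ids=[101, 2023, 102],
    masked_input_ids=[101, 103, 102],   # token 2023 replaced by [MASK] id 103
    masked_lm_labels=[-100, 2023, -100],
)
out = mlm_forward(lambda input_ids: input_ids, batch)
print(out)  # [101, 103, 102]
```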

pooryapzm avatar Jun 09 '21 02:06 pooryapzm

Thanks for the workaround! Additionally, the original line 201 should be further changed to

def __call__(self, task, model_arch, hidden_dropout_prob, **kwargs):

to prevent hidden_dropout_prob from being passed on through **kwargs, which leads to an error.

Also, as of Transformers v4.5, transformers.models.bert.modeling_bert.BertLayerNorm and transformers.models.bert.modeling_bert.gelu are no longer available. My correction is to change the former to torch.nn.LayerNorm and the latter to torch.nn.functional.gelu.
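As a hedged sketch of that swap (assuming only that PyTorch is installed; BertLayerNorm was just an alias of torch.nn.LayerNorm in older Transformers releases):

```python
import torch

# Old (Transformers < 4.5):
#   from transformers.models.bert.modeling_bert import BertLayerNorm, gelu
layer_norm = torch.nn.LayerNorm(768)         # replaces BertLayerNorm(768)
x = torch.randn(2, 768)
y = torch.nn.functional.gelu(layer_norm(x))  # replaces gelu(layer_norm(x))
print(tuple(y.shape))  # (2, 768)
```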

Hope this helps future users!

Pzoom522 avatar Apr 28 '22 12:04 Pzoom522