flash-linear-attention

Transformer model not learning after adding a classification head

OREYR opened this issue 1 year ago · 12 comments

I added a classification head to the pretrained Transformer++ model from https://huggingface.co/fla-hub/transformer-1.3B-100B/tree/main and fine-tuned it on the SST-2 dataset. However, the validation loss has stayed constant since the beginning of training. A similar architecture works for the GLA model. Could you help me check whether anything is wrong with my code, or whether the problem lies elsewhere? Here's the sequence-classification class I defined:

```python
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.utils import logging

# Module path of TransformerPreTrainedModel assumed from the fla package layout.
from fla.models.transformer.modeling_transformer import TransformerPreTrainedModel

logger = logging.get_logger(__name__)


class TransformerForSequenceClassification(TransformerPreTrainedModel):
    def __init__(self, model_name, num_labels, config):
        super().__init__(config)
        self.num_labels = num_labels
        # Reuse the pretrained backbone (without the LM head) from the causal-LM checkpoint.
        self.model = AutoModelForCausalLM.from_pretrained(model_name).model
        self.config = config
        # Newly initialized classification head on top of the final hidden states.
        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        self.model.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the pretrained backbone; outputs[0] holds the final hidden states.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # Per-token logits of shape (batch, seq_len, num_labels).
        logits = self.classifier(sequence_output)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # Keep only the logits at the last non-padding token of each sequence.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                labels = labels.to(pooled_logits.device)
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
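The pooling step in `forward` picks one position per sequence from the first pad-token match. A standalone sketch of that same formula (the pad id 0 is a hypothetical value chosen for illustration) shows which position it selects under right and left padding:

```python
import torch


def last_token_index(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Same formula as the classification head: first pad position minus one,
    # wrapped with modulo so "no pad found" (argmax == 0 -> -1) maps to the last position.
    idx = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    return idx % input_ids.shape[-1]


PAD = 0  # hypothetical pad id for this example
right_padded = torch.tensor([[5, 6, 7, PAD, PAD],
                             [5, 6, 7, 8, 9]])
left_padded = torch.tensor([[PAD, PAD, 5, 6, 7],
                            [5, 6, 7, 8, 9]])

print(last_token_index(right_padded, PAD).tolist())  # [2, 4]
print(last_token_index(left_padded, PAD).tolist())   # [4, 4]
```

With left padding the first pad sits at index 0, so `0 - 1` wraps around to the last position, which is the last real token. The formula therefore covers both padding sides, as long as the pad id does not also occur inside the real text.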

OREYR · Jun 30 '24 20:06

@OREYR Hi, I'm not sure what happened. Can you provide more details about your experiment settings, scheduler, data, model framework, etc.?

yzhangcs · Jul 01 '24 12:07

Hi @yzhangcs, I used LoRA to fine-tune the model. Here's the notebook with my experiment settings and the problem described above: https://colab.research.google.com/drive/101HrS5Zkib_ortNBoioZnE1AfW4QrxXG?usp=sharing. I'm not sure whether the problem lies in how I define the classification class or in the experiment settings. When I use a similar classification class and the same experiment settings for the GLA model, everything looks fine. If you could take a look, that would be great!

OREYR · Jul 01 '24 17:07

Hi @yzhangcs, do you have access to the notebook now?

OREYR · Jul 03 '24 23:07

@OREYR Does that mean you randomly initialize your model again? [screenshot]

For newly initialized models, a learning rate of 1e-5 is too small.
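One way to act on this while keeping the backbone's learning rate small is to give the newly initialized head its own, larger learning rate through optimizer parameter groups. This is a sketch under assumptions, not the maintainer's recipe: the toy module names mirror the class above (`model`, `classifier`), and the two learning rates are illustrative values, not tuned ones.

```python
import torch
from torch import nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in: `model` plays the pretrained backbone, `classifier` the new head."""

    def __init__(self, hidden_size: int = 16, num_labels: int = 2):
        super().__init__()
        self.model = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels, bias=False)


def build_optimizer(model: nn.Module, backbone_lr: float = 1e-5, head_lr: float = 1e-3):
    # Split trainable parameters by name; the freshly initialized head gets the larger lr.
    head, body = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # e.g. frozen base weights when training with LoRA
        # Substring match (an assumption) so prefixed or wrapped names still count as head params.
        (head if "classifier" in name else body).append(param)
    return torch.optim.AdamW([
        {"params": body, "lr": backbone_lr},
        {"params": head, "lr": head_lr},
    ])


optimizer = build_optimizer(ToyClassifier())
print([group["lr"] for group in optimizer.param_groups])  # [1e-05, 0.001]
```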

yzhangcs · Jul 08 '24 07:07

Hi @yzhangcs, only the classifier weight is newly initialized. I think the problem is with the rotary embedding: q and k contain NaN after the rotary embedding is applied. I'm not sure whether this is caused by left padding or something else.
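Since left padding is one suspect, it may help to check which positions the rotary embedding receives. Below is a plain-PyTorch reference under two assumptions: the common "rotate-half" RoPE convention (fla's fused kernel may use a different layout, so compare outputs only after matching conventions), and position ids derived from the attention mask so that left-pad tokens do not shift the real tokens. It does not reproduce the reported NaNs; it only gives finite, norm-preserving values to compare intermediate q/k against.

```python
import torch


def positions_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: 1 for real tokens, 0 for padding. Counting real tokens makes the
    # first real token position 0 even when the sequence is left-padded.
    return (attention_mask.long().cumsum(-1) - 1).clamp(min=0)


def apply_rotary(x: torch.Tensor, position_ids: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, seq_len, num_heads, head_dim) with even head_dim; "rotate-half" convention.
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = position_ids.float()[..., None] * inv_freq  # (batch, seq_len, d/2)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[:, :, None, :]
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[:, :, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin


torch.manual_seed(0)
mask = torch.tensor([[0, 0, 1, 1, 1],
                     [1, 1, 1, 1, 1]])
pos = positions_from_mask(mask)
print(pos.tolist())  # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]

q = torch.randn(2, 5, 4, 8)
q_rot = apply_rotary(q, pos)
print(torch.isfinite(q_rot).all().item())  # True
```

Because RoPE is a pure rotation, every head vector keeps its norm, so non-finite values right after this step in the real model would suggest checking its inputs (q, k, positions or offsets) or the kernel implementation.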

OREYR · Jul 08 '24 11:07

Thank you for reporting this. I'll take a look.

yzhangcs · Jul 08 '24 11:07

@OREYR It looks like you wrapped the classifier with LoRA as well. Are the original randomly initialized parameters frozen?

yzhangcs · Jul 08 '24 12:07

@yzhangcs The classifier is in `modules_to_save`, so it is not trained with LoRA but is still updated along with the other LoRA layers. I also tried training the full model without LoRA: the q outputs after the rotary embedding, and the gradients, are still NaN. After I removed the rotary embedding, the validation loss started to decrease, but very slowly across various learning rates.
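To confirm that the head really trains, a quick check of which parameters are trainable and whether the head receives a usable gradient can help. This is a plain-PyTorch sketch on a hypothetical toy model; matching on the substring `classifier` (rather than an exact name) is an assumption meant to also catch names that wrappers such as PEFT may prefix or extend.

```python
import torch
from torch import nn


def trainable_parameter_names(model: nn.Module) -> list:
    # Names of every parameter the optimizer can actually update.
    return [name for name, p in model.named_parameters() if p.requires_grad]


def head_gets_gradients(model: nn.Module, loss: torch.Tensor, head_name: str = "classifier") -> bool:
    # Backpropagate once and check that some head parameter received a finite, non-zero gradient.
    loss.backward()
    for name, p in model.named_parameters():
        if head_name in name and p.grad is not None:
            if torch.isfinite(p.grad).all() and p.grad.abs().sum() > 0:
                return True
    return False


# Hypothetical toy setup: frozen "backbone", trainable head.
torch.manual_seed(0)
model = nn.ModuleDict({"model": nn.Linear(8, 8), "classifier": nn.Linear(8, 2, bias=False)})
for p in model["model"].parameters():
    p.requires_grad_(False)

x = torch.randn(4, 8)
logits = model["classifier"](model["model"](x))
loss = nn.functional.cross_entropy(logits, torch.tensor([0, 1, 0, 1]))

print(trainable_parameter_names(model))  # ['classifier.weight']
print(head_gets_gradients(model, loss))
```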

OREYR · Jul 08 '24 12:07

@OREYR One thing to confirm: how is the MLP called in your PEFT modules? I wrote some fused kernels in this module to save memory, so please check the implementation to make sure they are executed properly. [screenshot]

yzhangcs · Jul 08 '24 12:07

@yzhangcs If I understand correctly, it should be called the same way as the other parts. The thing is, when I removed LoRA, the problem persisted. Only removing the rotary embedding stopped the gradients from being NaN.
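One generic way to find where non-finite values first appear in the forward pass is to hook every submodule. This is a plain-PyTorch sketch demonstrated on a toy model, not an fla utility; on a real model it only inspects values at module boundaries, so a NaN created by a function called inline inside a module (rather than by a child module) shows up at that enclosing module's output.

```python
import torch
from torch import nn


def track_non_finite_outputs(model: nn.Module):
    """Register forward hooks that record, in call order, every submodule whose output has NaN/Inf."""
    offenders = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules may return tuples (for example a (q, k) pair), so check every tensor.
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            if any(torch.is_tensor(o) and o.is_floating_point() and not torch.isfinite(o).all()
                   for o in outputs):
                offenders.append(name)
        return hook

    # Skip the root module (empty name); hook every named submodule.
    handles = [m.register_forward_hook(make_hook(name)) for name, m in model.named_modules() if name]
    return offenders, handles


class Log(nn.Module):
    # Toy module that yields NaN for negative inputs, standing in for a misbehaving layer.
    def forward(self, x):
        return torch.log(x)


model = nn.Sequential(nn.Identity(), Log(), nn.Linear(4, 2))
offenders, handles = track_non_finite_outputs(model)
model(-torch.ones(1, 4))
for h in handles:
    h.remove()
print(offenders)  # ['1', '2']
```

The first name in the list is the earliest module to emit non-finite values. For NaNs that only appear in the backward pass, `torch.autograd.set_detect_anomaly(True)` makes autograd raise an error and report the forward operation whose backward produced them, at a noticeable speed cost.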

OREYR · Jul 08 '24 13:07

@OREYR Can you paste a full runnable script here that reproduces the abnormal values?

yzhangcs · Jul 08 '24 13:07

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] · Aug 11 '24 00:08

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] · Aug 25 '24 00:08