flash-linear-attention
Transformer model not learning after adding a classification head
I added a classification head to the pretrained Transformer++ model from https://huggingface.co/fla-hub/transformer-1.3B-100B/tree/main and fine-tuned it on the SST-2 dataset. However, the validation loss has remained constant from the very beginning. Here is the sequence-classification class I defined; a similar architecture works fine for the GLA model. Could you help me take a look and see whether anything is wrong with my code or anything else?

```python
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.utils import logging

# adjust this import to wherever TransformerPreTrainedModel lives in your fla install
from fla.models.transformer.modeling_transformer import TransformerPreTrainedModel

logger = logging.get_logger(__name__)


class TransformerForSequenceClassification(TransformerPreTrainedModel):

    def __init__(self, model_name, num_labels, config):
        super().__init__(config)
        self.num_labels = num_labels
        self.model = AutoModelForCausalLM.from_pretrained(model_name).model
        self.config = config
        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        self.model.post_init()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                labels = labels.to(pooled_logits.device)
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
```
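For reference, here is a minimal sketch of how I expect this wrapper to be instantiated and run on an SST-2-style batch. The tokenizer and padding settings below are illustrative, not necessarily my exact training setup:

```python
# Minimal usage sketch; tokenizer/padding settings are illustrative assumptions.
import torch
import fla  # importing fla should register its model/config classes with the transformers Auto* factories
from transformers import AutoConfig, AutoTokenizer

model_name = "fla-hub/transformer-1.3B-100B"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# the pooling logic above picks the last non-pad token, so a pad token must be defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
config.pad_token_id = tokenizer.pad_token_id

model = TransformerForSequenceClassification(model_name, num_labels=2, config=config)

batch = tokenizer(
    ["a gorgeous, witty film", "it never quite comes together"],
    padding=True,
    return_tensors="pt",
)
out = model(**batch, labels=torch.tensor([1, 0]))
print(out.loss, out.logits.shape)  # scalar loss, logits of shape (batch_size, num_labels)
```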
@OREYR Hi, I'm not sure what's happening here yet. Can you provide more details about your experiment settings, scheduler, data, model framework, etc.?
Hi @yzhangcs, I used LoRA to finetune the model. Here's the notebook containing my experiment settings and the problem I described above: https://colab.research.google.com/drive/101HrS5Zkib_ortNBoioZnE1AfW4QrxXG?usp=sharing. I'm not sure whether the problem is related to how I define the classification class or to the experiment settings. When I use a similar classification class and the same experiment settings for the GLA model, everything looks fine. It would be great if you could help me take a look!
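In case the notebook is hard to open, the PEFT setup is along these lines; the rank, alpha, and target modules below are illustrative rather than the exact values in the notebook:

```python
# Illustrative PEFT/LoRA setup; r/alpha/target_modules are assumptions, not the notebook's exact values.
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    # module names depend on the fla Transformer implementation; adjust as needed
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # keep the randomly initialized head fully trainable and saved alongside the adapters
    modules_to_save=["classifier"],
)
peft_model = get_peft_model(model, lora_config)  # `model`: the classification wrapper above
peft_model.print_trainable_parameters()
```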
Hi @yzhangcs, do you have access to the notebook now?
@OREYR Does that mean you randomly initialized your model again?
For newly initialized models, a learning rate of 1e-5 is too small.
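If only the head is newly initialized, one option is to give it its own, larger learning rate via optimizer parameter groups, e.g. (illustrative values):

```python
# Sketch: give the freshly initialized classifier head a larger learning rate than the body.
import torch

head_params = [p for n, p in model.named_parameters() if "classifier" in n]
body_params = [p for n, p in model.named_parameters() if "classifier" not in n]

optimizer = torch.optim.AdamW([
    {"params": body_params, "lr": 1e-5},  # pretrained backbone: small lr
    {"params": head_params, "lr": 1e-3},  # new head: larger lr (illustrative value)
])
```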
Hi @yzhangcs, only the classifier weight is newly initialized. I think the problem is with the rotary embedding: q and k contain NaN after applying it. I'm not sure whether it is caused by left padding or something else.
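For anyone hitting the same issue, a generic way to localize this kind of problem is to register forward hooks that flag the first module producing non-finite outputs; nothing here is fla-specific:

```python
# Generic NaN/Inf localization sketch: flag the first module whose forward output is non-finite.
import torch

def install_nan_hooks(model):
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                    raise RuntimeError(f"non-finite values first seen after module: {name}")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call h.remove() on each handle once you are done debugging

handles = install_nan_hooks(model)  # then run a single forward pass on a real batch
```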
Thank you for reporting this. I will have a check.
@OREYR It looks like you wrap the classifier with LoRA as well, so the original random params are frozen?
@yzhangcs The classifier is in `modules_to_save`, so it is not wrapped with LoRA but is still updated along with the LoRA layers. I also tried training the full model without LoRA; the q values after rotary embedding and their gradients are still NaN. After I removed the rotary embedding, the validation loss started to decrease, but very slowly across various learning rates.
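On the backward side, PyTorch's anomaly detection around a single training step can point at the op that produces the NaN gradients (a generic debugging sketch; it is slow, so only worth a few iterations):

```python
# Debugging sketch: make the backward pass raise at the op that produces NaN gradients.
import torch

with torch.autograd.set_detect_anomaly(True):  # very slow; enable for a few steps only
    out = model(**batch, labels=labels)        # `batch`/`labels`: one real training batch
    out.loss.backward()                        # raises with a traceback to the offending op
```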
@OREYR One thing to confirm: how is the MLP called in your PEFT modules? I wrote some fused kernels in that module to save memory, so please check the implementations to make sure they are executed properly.
@yzhangcs It should be called in the same way as the other parts, if I understood correctly. The thing is, when I removed LoRA the problem persisted; only after removing the rotary embedding did the gradients stop being NaN.
@OREYR Can you paste a full runnable script here from which I can observe the abnormal values?
This issue is stale because it has been open for 30 days with no activity.
This issue was closed because it has been inactive for 7 days since being marked as stale.