
Fix BLT training_ci overfit test

preetam1407 opened this pull request 2 weeks ago • 8 comments

What does this PR do?

This PR fixes the BLT entry in the new training CI by making the tiny BLT model both:

  • reliably overfit the fixed batch for generation, and
  • pass the training gradient-norm reduction checks with BLT-specific thresholds.

In the current setup, the tiny BLT config used in BltModelTest::test_training_overfit shows:

  • loss going from ~3.46 to ~0.18 (~94.8% reduction),
  • grad norm going from ~1.28 to ~0.24 (~81.1% reduction),

but the test fails because:

  • the generic training test expects a grad-norm reduction ≥ 90%, and
  • generation with the default KV cache does not always reproduce the training sequence, while generation with use_cache=False does.

This PR makes two changes (both sketched in code after the list):

  1. BLT config

    • Add a use_cache argument to BltConfig.__init__ with default False, and forward it into super().__init__.
    • BLT now defaults to use_cache=False (matching the recommended generation settings in the BLT model card), while still respecting any explicit use_cache value in existing configs.
    • With this change, model.generate(...) uses the non-cache path by default for BLT, which fixes the generation mismatch in the training overfit test.
  2. BLT tests

    • In BltModelTest (in tests/models/blt/test_modeling_blt.py), override the training thresholds used by TrainingTesterMixin:
      • keep training_loss_reduction_threshold = 0.9,
      • set training_grad_norm_reduction_threshold = 0.8 for BLT only.
    • Remove the previous BLT-specific skip of test_training_overfit, so the shared test_training_overfit from TrainingTesterMixin now runs with BLT thresholds.
    • Empirically, the tiny BLT test config consistently reaches ~81% grad-norm reduction with gradient clipping, so 0.8 is a stable but still strict threshold, while the loss overfits very strongly (~95% reduction).
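
A minimal sketch of both changes, for illustration only (only the relevant lines are shown; the real classes have more bases and arguments, and the exact attribute placement is an assumption, so the actual diff in this PR is authoritative):

    # Sketch of the BltConfig change: default to the non-cache generation path.
    # The real config takes many more arguments; only the relevant part is shown.
    from transformers import PretrainedConfig


    class BltConfig(PretrainedConfig):
        model_type = "blt"

        def __init__(self, use_cache=False, **kwargs):
            # Defaulting to False makes model.generate() take the non-cache path by default,
            # while any explicit use_cache value in an existing config is still respected.
            super().__init__(use_cache=use_cache, **kwargs)


    # Sketch of the test-side change in tests/models/blt/test_modeling_blt.py:
    # BLT-specific thresholds picked up by the shared TrainingTesterMixin test.
    class BltModelTest:  # the real class also mixes in TrainingTesterMixin and friends
        training_loss_reduction_threshold = 0.9       # shared default, kept as-is
        training_grad_norm_reduction_threshold = 0.8  # relaxed for BLT only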

Verification (local):

  • Command: pytest tests/models/blt/test_modeling_blt.py::BltModelTest::test_training_overfit -s -vv

  • Results:

    • loss_reduction: ~94.8% (> 90% threshold),
    • grad_norm_reduction: ~81.1% (> 80% BLT threshold),
    • generated sequence exactly matches the fixed training pattern.

Fixes #42629

preetam1407 avatar Dec 07 '25 14:12 preetam1407

Quick update:

  • test_training_overfit for BLT now passes locally with the overridden thresholds in BltModelTest (loss ~95% reduction, grad norm ~81% reduction), and generation overfits the fixed pattern.
  • CI “check_code_quality” was failing because of trailing whitespace in configuration_blt.py – I’ve fixed that and pushed.

The remaining CI failures are in tests/models/blt/test_modeling_blt.py::*assisted_decoding*:

  • AttributeError: 'DynamicCache' object has no attribute 'self_attention_cache'

These come from the assisted decoding tests using the new DynamicCache API. That looks like a separate issue in BLT’s cache handling, not related to the training-overfit thresholds this PR changes. Happy to help look into that in a follow-up if needed, but I wanted to keep this PR focused on the training overfit test.

preetam1407 avatar Dec 07 '25 14:12 preetam1407

run-slow: blt

3outeille avatar Dec 08 '25 12:12 3outeille

💔 This comment contains run-slow, but unknown error occurred and the workflow run aborted!

github-actions[bot] avatar Dec 08 '25 12:12 github-actions[bot]

BLT now defaults to use_cache=False (matching the recommended generation settings in the BLT model card), while still respecting any explicit use_cache value in existing configs.

I found it weird that the generation is not working with use_cache=True. I think it is worth investigating why (cc: @itazap if you have time to guide @preetam1407 )

Empirically, the tiny BLT test config consistently reaches ~81% grad-norm reduction with gradient clipping, so 0.8 is a stable but still strict threshold, while the loss overfits very strongly (~95% reduction)

As for lowering the grad_norm threshold, I am against it; I think the reason it doesn't reach a 90% reduction is that we don't have proper weight initialization. Maybe worth checking how they do it (cf. https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py#L1052) and implementing something like this within transformers (https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py#L167).
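
For illustration, a rough sketch of that kind of depth-scaled truncated-normal init (the exact std and depth scaling should be taken from the two links above; _init_weights, initializer_range, and num_hidden_layers are assumed hook points here, not the actual BLT code):

    import math

    import torch.nn as nn


    def _init_weights(self, module):
        """Rough sketch only: depth-scaled truncated-normal init, meant as a method on the pretrained model class."""
        std = self.config.initializer_range  # typically 0.02
        if isinstance(module, nn.Linear):
            # Scale the init down with depth, in the spirit of the torchtitan example above.
            depth_scale = 1.0 / math.sqrt(2 * self.config.num_hidden_layers)
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std * depth_scale, a=-3 * std, b=3 * std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)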

3outeille avatar Dec 08 '25 12:12 3outeille

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

💔 This comment contains run-slow, but unknown error occurred and the workflow run aborted!

github-actions[bot] avatar Dec 08 '25 12:12 github-actions[bot]

use_cache=False is correct, I believe, since we have the BltPatcher, which requires the full sequence to know how to patch / group tokens.

edit (more info): The BltPatcher will group the raw byte sequence into groups of bytes, which become the patches / "tokens" used by the BltModel. If the cache window is smaller than the length of this byte sequence (note: consider the length in bytes, so roughly 4x the number of characters to be safe), then the BltPatcher will only see a smaller window of the sequence (at generation time) and patch it differently than if the whole sequence is being considered at once.

So use_cache=False is correct, or I would try forcing the cache window to be much larger than the max byte sequence length.
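
For example, a tiny illustrative snippet (assuming model is a BLT model and input_ids is an already-prepared byte-level prompt):

    # Illustrative only: keep the KV cache off so the BltPatcher always sees the
    # full byte sequence when deciding patch boundaries.
    generated = model.generate(
        input_ids,
        max_new_tokens=32,
        do_sample=False,
        use_cache=False,
    )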

itazap avatar Dec 08 '25 16:12 itazap

I found it weird that the generation is not working with use_cache=True. I think it is worth investigating why (cc: @itazap if you have time to guide @preetam1407 )

Thanks @itazap for the help!

As for lowering the grad_norm threshold, I am against it; I think the reason it doesn't reach a 90% reduction is that we don't have proper weight initialization. Maybe worth checking how they do it (cf. https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py#L1052) and implementing something like this within transformers (https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py#L167).

@3outeille I’ve implemented the custom BLT weight initialization, and the overfit test now behaves as expected with proper grad_norm reduction.

Since the modular-to-modeling conversion auto-generated BltTextModel and BltVisionModel, I added placeholder classes and updated the config + repo checks to explicitly ignore them.

preetam1407 avatar Dec 11 '25 09:12 preetam1407

nice job ! Will review tomorrow

3outeille avatar Dec 11 '25 12:12 3outeille

Checking Monday, it is weird to me that make fixup doesn't work as expected. You shouldn't have to add those placeholders to begin with @ArthurZucker

3outeille avatar Dec 12 '25 22:12 3outeille

Checking Monday, it is weird to me that make fixup doesn't work as expected. You shouldn't have to add those placeholders to begin with @ArthurZucker

This is fixed now; the placeholder classes are removed.

preetam1407 avatar Dec 13 '25 07:12 preetam1407

@3outeille, will be waiting for your review! I think we have resolved all the issues mentioned last week.

preetam1407 avatar Dec 15 '25 10:12 preetam1407

run-slow: blt

3outeille avatar Dec 15 '25 12:12 3outeille

This comment contains run-slow, running the specified jobs:

models: ["models/blt"] quantizations: []

github-actions[bot] avatar Dec 15 '25 12:12 github-actions[bot]

Alright, just one last issue to address and it will be good to merge. Good job overall! 🚀

3outeille avatar Dec 15 '25 12:12 3outeille

Alright, just one last issue to address and it will be good to merge. Good job overall! 🚀

Hey @3outeille, could you please point me to the last remaining issue you mentioned? I can’t seem to locate it.

preetam1407 avatar Dec 15 '25 13:12 preetam1407

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

github-actions[bot] avatar Dec 15 '25 15:12 github-actions[bot]

The recurrent_gemma test is failing because use_cache=False is passed. Maybe use use_cache=True in test_training_mixin.py (only for recurrent_gemma):

            with torch.no_grad():
>               generated_ids = model.generate(
                    prompt_ids,
                    max_new_tokens=num_tokens_to_generate,
                    do_sample=False,
                    pad_token_id=config.pad_token_id if hasattr(config, "pad_token_id") else 0,
                    eos_token_id=0,
                    use_cache=False,
                )

tests/test_training_mixin.py:351: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:120: in decorate_context
    return func(*args, **kwargs)
/usr/local/lib/python3.10/site-packages/transformers/generation/utils.py:2684: in generate
    result = decoding_method(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(99, 32, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm((32,), eps=1e-06)
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=32, out_features=32, bias=True)
          (linear_x): Linear(in_features=32, out_features=32, bias=True)
          (linear_out): Linear(in_features=32, out_features=32, bias=True)
          (conv_1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): GELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm((32,), eps=1e-06)
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=32, out_features=18, bias=True)
          (up_proj): Linear(in_features=32, out_features=18, bias=True)
          (down_proj): Linear(in_features=18, out_features=32, bias=True)
          (act_fn): GELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm((32,), eps=1e-06)
  )
  (lm_head): Linear(in_features=32, out_features=99, bias=False)
)
input_ids = tensor([[1]]), logits_processor = []
stopping_criteria = [<transformers.generation.stopping_criteria.MaxLengthCriteria object at 0x7fb9b22c5d50>, <transformers.generation.stopping_criteria.EosTokenCriteria object at 0x7fb99d5ffd90>]
generation_config = <[TypeError('Object of type Tensor is not JSON serializable') raised in repr()] GenerationConfig object at 0x7fb99d97e230>
synced_gpus = False, streamer = None
model_kwargs = {'attention_mask': tensor([[1, 1]]), 'cache_position': tensor([0, 1]), 'logits_to_keep': 1, 'use_cache': False}
pad_token_id = tensor(0), output_attentions = False
output_hidden_states = False, output_scores = False, output_logits = None
return_dict_in_generate = False, has_eos_stopping_criteria = True

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> GenerateNonBeamOutput | torch.LongTensor:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
    
        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.
    
        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
    
        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )
    
        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    
        model_forward = (
            self.get_compiled_call(generation_config.compile_config)
            if self._valid_auto_compile_criteria(model_kwargs, generation_config)
            else self.__call__
        )
    
        prefill_consumed = False
        outputs = self._prefill(input_ids, generation_config, model_kwargs)
    
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            if prefill_consumed:
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
                outputs = model_forward(**model_inputs, return_dict=True)
            prefill_consumed = True
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue
    
            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
    
            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
    
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
    
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
    
            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)
    
            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    
            # update generated ids, model inputs, and length for next step
>           input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
E           RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 2 for tensor number 1 in the list.

/usr/local/lib/python3.10/site-packages/transformers/generation/utils.py:2932: RuntimeError

3outeille avatar Dec 15 '25 16:12 3outeille

@3outeille, I updated tests/test_training_mixin.py to set use_cache=True only for model_type == "recurrent_gemma".
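
Roughly like this (a sketch against the existing test code shown in the traceback above; the exact condition in the final diff may differ):

    # RecurrentGemma's recurrent state handling breaks with use_cache=False in this test,
    # so only that model keeps the cache enabled during the overfit generation check.
    use_cache = config.model_type == "recurrent_gemma"

    with torch.no_grad():
        generated_ids = model.generate(
            prompt_ids,
            max_new_tokens=num_tokens_to_generate,
            do_sample=False,
            pad_token_id=config.pad_token_id if hasattr(config, "pad_token_id") else 0,
            eos_token_id=0,
            use_cache=use_cache,
        )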

preetam1407 avatar Dec 15 '25 17:12 preetam1407

A few CI checks are still failing. The CI tests_tokenization failures look infra-related, similar to some earlier CI failures in this PR.

I ran the failing tests locally, and they all pass on this branch.

preetam1407 avatar Dec 15 '25 20:12 preetam1407

[For maintainers] Suggested jobs to run (before merge)

run-slow: blt

github-actions[bot] avatar Dec 16 '25 04:12 github-actions[bot]

@3outeille, all requested changes are done. Whenever you get time, I’d appreciate it if you could take a look and merge.

Thanks a lot!

preetam1407 avatar Dec 16 '25 05:12 preetam1407

Thank you again for your work!

3outeille avatar Dec 16 '25 10:12 3outeille