
Gradient Checkpointing causes model to compute junk results (NNX)

Open erfanzar opened this issue 9 months ago • 8 comments

System information

  • Flax, JAX, JAXlib versions: (0.10.4)
  • Python version: (e.g., 3.10.2)

Problem you have encountered:

When using gradient checkpointing (nnx.remat) in a flax.nnx model, the model generates incorrect or junk results. This happens on both GPU and TPU. If two models are loaded:

  1. Model 1: Uses gradient checkpointing (EasyDeLGradientCheckPointers.NOTHING_SAVEABLE).
  2. Model 2: Does not use gradient checkpointing.

Both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.

What you expected to happen:

Each model should independently function correctly, and the activation checkpointing (remat) should not corrupt inference outputs when applied to a separate model instance.

Logs, error messages, etc:

(Provide any logs, traceback, or error messages if available.)

Steps to reproduce:

A minimal reproducible example is given below. Changing gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE to EasyDeLGradientCheckPointers.NONE resolves the issue.

Code snippet to reproduce the issue:

import typing as tp

import easydel as ed
import jax
import transformers
from flax import nnx as nn  # assumed: auto_remat below uses flax.nnx as nn
from jax import numpy as jnp

# Assumed TypeVar for auto_remat's signature. EasyDeLGradientCheckPointers,
# get_gradient_checkpoint_policy and extract_static_parameters are EasyDeL
# internals that auto_remat relies on (their imports were omitted in the report).
M = tp.TypeVar("M", bound=nn.Module)

def auto_remat(
    *modules: tp.Type[M],
    policy: tp.Union[
        EasyDeLGradientCheckPointers, str
    ] = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
    if policy == EasyDeLGradientCheckPointers.NONE:
        return modules
    if isinstance(policy, str):
        policy = get_gradient_checkpoint_policy(policy)
    outs = ()
    for module in modules:
        assert issubclass(module, nn.Module)
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()

        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        outs += (module,)
    return outs
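# Note: auto_remat is included above for context; presumably EasyDeL applies it
# internally when gradient_checkpointing != NONE, rather than it being called
# directly from main() below.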

def main():
    sharding_axis_dims = (1, 1, 1, -1)
    prefill_length = 512
    max_new_tokens = 128
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"

    dtype = param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()
    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
            attn_dtype=jnp.bfloat16,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
        )
    )

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "write 10 lines story about why you love EasyDeL"},
    ]

    ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="jax",
        return_dict=True,
        add_generation_prompt=True,
    )
    print("Start Generation Process.")
    for response in inference.generate(**ids):
        ...
    print(
        tokenizer.batch_decode(
            response.sequences[..., response.padded_length :],
            skip_special_tokens=True,
        )
    )
    print(response.tokens_pre_second)

if __name__ == "__main__":
    main()

Workarounds Tried:

  • Setting gradient_checkpointing=EasyDeLGradientCheckPointers.NONE fixes the issue.
  • Ensuring that only one model (either with or without remat) is instantiated at a time prevents the corruption.
  • The issue only occurs when both models exist in memory simultaneously.

Possible Cause:

  • nn.remat might be affecting global state shared across models.
  • Memory corruption or state retention in flax.nnx affecting subsequent inference.

Additional Notes:

  • Needs further debugging of nn.remat handling in flax.nnx.
  • Possible scope leakage between checkpointed and non-checkpointed models.

Expected Fix: Ensure that gradient checkpointing via nnx.remat does not interfere with models that do not use checkpointing in the same session.

erfanzar, Mar 15 '25 11:03

Any update or help on this?

erfanzar, Mar 27 '25 11:03

EasyDeL seems to be the only JAX implementation out there that supports GRPO training.

This is a blocking bug for my research. Is anyone else able to reproduce this, or at least confirm that it is actually a Flax-related bug?

peregilk, Mar 27 '25 12:03

Thanks for reporting this @erfanzar. Would it be possible for you to create a test case where nnx.remat fails? nnx.remat is not doing a whole lot except forwarding the underlying state to JAX.
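
For what it's worth, a minimal isolated test along those lines might look like the sketch below. This is only a sketch under assumptions: a tiny hypothetical model, and class-level __call__ patching that mirrors the auto_remat helper from the issue; names and shapes are made up.

import jax
import jax.numpy as jnp
from flax import nnx

class TinyMLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.l1 = nnx.Linear(din, dmid, rngs=rngs)
        self.l2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        return self.l2(jax.nn.relu(self.l1(x)))

class RematMLP(TinyMLP):
    pass

# Patch the subclass's __call__ the same way auto_remat patches module.__call__.
RematMLP.__call__ = nnx.remat(TinyMLP.__call__)

x = jnp.ones((4, 8))
plain = TinyMLP(8, 16, 8, rngs=nnx.Rngs(0))      # no checkpointing
rematted = RematMLP(8, 16, 8, rngs=nnx.Rngs(0))  # checkpointed

# Both instances start from the same seed, so their outputs should match,
# and the unpatched class should be unaffected by the patched subclass.
print(jnp.allclose(plain(x), rematted(x), atol=1e-5))
print(jnp.allclose(plain(x), TinyMLP(8, 16, 8, rngs=nnx.Rngs(0))(x)))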

cgarciae, Mar 28 '25 00:03

Hi @erfanzar ,

I've encountered similar issues with generating junk results when working with Qwen models, although perhaps for slightly different underlying reasons.

In my case, using my own library implementation, I found that the junk results with Qwen seemed related to numerical stability, particularly within the attention mechanism when using bfloat16. It appeared that some operations were becoming unstable at lower precision.

My workaround was to force key parts of the attention calculation to run in float32 precision, even if the rest of the model used bfloat16. This stabilized the computation and resolved the junk output issue for me.

Here's the relevant snippet from my attention implementation:

import jax.numpy as jnp
import jax.nn as nn
import math

# Assuming query_states, key_states, value_states are initially bf16/fp16
# And attn_mask is prepared appropriately

# Force QK dot product and scaling to float32
attn_weights = (query_states.astype(jnp.float32) @ key_states.swapaxes(-2, -1).astype(jnp.float32)) / math.sqrt(self.head_dim)

# Apply mask in float32
if attn_mask is not None:
    causal_mask = attn_mask # Assuming mask is already correctly broadcastable
    # Ensure mask is also float32
    attn_weights = attn_weights.astype(jnp.float32) + causal_mask.astype(jnp.float32)

# Softmax in float32
attn_weights = nn.softmax(attn_weights.astype(jnp.float32), axis=-1)

# Weight values in float32
attn_output = attn_weights @ value_states.astype(jnp.float32)

# Cast final output back to the original lower precision if needed
attn_output = attn_output.astype(jnp.bfloat16) # Or original dtype

While your issue seems directly linked to nnx.remat potentially causing state interference, it's possible that the recomputation process within remat might be exacerbating underlying numerical precision sensitivities in the Qwen attention layers.

Perhaps you could try modifying the attention mechanism within your EasyDeL setup (if possible, or by modifying the source temporarily) to use float32 for the intermediate calculations as shown above, even when gradient_checkpointing is enabled. It's a bit of a long shot since the root causes might be different, but forcing higher precision in sensitive areas like attention sometimes helps stabilize things unexpectedly.

Hope this potentially offers another avenue to investigate or might provide some relief!

demon2036, Apr 02 '25 04:04

Thanks @demon2036 for sharing this, but we are using FA3 kernels on GPUs and Splash on TPUs, so attention dtype isn't really the issue. But thanks, I'll double-check.

erfanzar, Apr 05 '25 16:04

@erfanzar Have you had a chance to try decoding using greedy sampling?

If you use greedy sampling, is the first token generated correct?

If the first token is correct, but the subsequent tokens are incorrect, that might strongly suggest an underlying numerical precision issue. In that scenario, perhaps it could be worth investigating if forcing operations like RMSNorm (and maybe others) to compute in fp32 makes a difference?
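
For example, a minimal sketch of an RMSNorm that does its reductions in float32 and casts back afterwards (a hypothetical standalone class, not EasyDeL's actual implementation) could look like:

import jax
import jax.numpy as jnp
from flax import nnx

class RMSNormFP32(nnx.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = nnx.Param(jnp.ones((dim,)))
        self.eps = eps

    def __call__(self, x):
        orig_dtype = x.dtype
        x32 = x.astype(jnp.float32)
        # Mean of squares and rsqrt in float32 to avoid bf16 round-off.
        variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        normed = x32 * jax.lax.rsqrt(variance + self.eps)
        # Scale in float32, then cast back to the caller's dtype.
        return (self.weight.value.astype(jnp.float32) * normed).astype(orig_dtype)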

demon2036, Apr 06 '25 15:04

Is there any update on this issue?

I'm also finding suspicious behavior with nnx.remat. In particular, it scales the gradients per layer by a factor of about 1.3-1.4, which ends up scaling the gradients by a factor of close to 10 after several layers, destroying training.

I just switched from Linen to NNX, and the Linen remat/checkpoint does not have this problem.
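
A minimal way to check for that kind of per-layer gradient scaling might be something like the sketch below (hypothetical tiny model; it compares gradients for the same parameters computed with and without nnx.remat around the forward pass):

import jax
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return jax.nn.gelu(self.linear(x))

model = Block(16, rngs=nnx.Rngs(0))
x = jnp.ones((2, 16))

remat_call = nnx.remat(Block.__call__)

def loss_plain(m, batch):
    return jnp.mean(m(batch) ** 2)

def loss_remat(m, batch):
    return jnp.mean(remat_call(m, batch) ** 2)

def grad_norm(grads):
    # Total L2 norm over all gradient leaves.
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))

# nnx.grad differentiates with respect to the first Module argument.
g_plain = nnx.grad(loss_plain)(model, x)
g_remat = nnx.grad(loss_remat)(model, x)

# The two norms should agree; a systematic ratio would reproduce the scaling.
print(grad_norm(g_plain), grad_norm(g_remat))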

twmitchel, Oct 21 '25 19:10

@twmitchel can you please provide a small reproducer of the issue. I can take a look. Thanks!

@erfanzar your reproducer code does not work with the latest Flax (0.12.0). I get the following error:

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained( 
...
ValueError: Mesh requires the ndim of its first argument (`devices`) to equal the length of its second argument (`axis_names`), but got devices.ndim == 4 and len(axis_names) == 5.

I also assumed that nn means import flax.nnx as nn. Also, how is auto_remat used in main()?

vfdev-5, Oct 23 '25 13:10