Gradient Checkpointing causes model to compute junk results (NNX)
System information
- Flax, JAX, JAXlib versions: (0.10.4)
- Python version: (e.g., 3.10.2)
Problem you have encountered:
When using gradient checkpointing in a flax.nnx model (nnx.remat), the model generates incorrect or junk results. This happens on both GPU and TPU. If two models are loaded:
- Model 1: uses gradient checkpointing (EasyDeLGradientCheckPointers.NOTHING_SAVEABLE).
- Model 2: does not use gradient checkpointing.
Both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.
What you expected to happen:
Each model should independently function correctly, and the activation checkpointing (remat) should not corrupt inference outputs when applied to a separate model instance.
Logs, error messages, etc:
(Provide any logs, traceback, or error messages if available.)
Steps to reproduce:
A minimal reproducible example is given below. Changing gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE to EasyDeLGradientCheckPointers.NONE resolves the issue.
Code snippet to reproduce the issue:
import typing as tp  # assumed: needed for the tp.* annotations in auto_remat below

import easydel as ed
import jax
import transformers
from easydel import EasyDeLGradientCheckPointers  # also available as ed.EasyDeLGradientCheckPointers
from flax import nnx as nn  # assumed: `nn` in auto_remat refers to flax.nnx
from jax import numpy as jnp

# Note: get_gradient_checkpoint_policy and extract_static_parameters are EasyDeL
# helpers that were not imported in the original snippet; their import path is
# omitted here.

M = tp.TypeVar("M")  # assumed: type variable for the module classes


def auto_remat(
    *modules: tp.Type[M],
    policy: tp.Union[
        EasyDeLGradientCheckPointers, str
    ] = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
    if policy == EasyDeLGradientCheckPointers.NONE:
        return modules
    if isinstance(policy, str):
        policy = get_gradient_checkpoint_policy(policy)
    outs = ()
    for module in modules:
        assert issubclass(module, nn.Module)
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()
        # This rebinds __call__ on the module *class*, not on an instance, so every
        # instance of the class ends up calling the remat-wrapped version.
        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        outs += (module,)
    return outs
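# (Hypothetical usage sketch, not part of the original reproducer: auto_remat is
# presumably applied by EasyDeL internally when gradient_checkpointing is not NONE.
# Applied directly to a stand-in nnx module class DecoderLayerCls, the call would
# look roughly like:
#
#     (DecoderLayerCls,) = auto_remat(
#         DecoderLayerCls,
#         policy=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
#     )
#
# after which DecoderLayerCls.__call__ is the remat-wrapped function.)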
def main():
    sharding_axis_dims = (1, 1, 1, -1)
    prefill_length = 512
    max_new_tokens = 128
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
    dtype = param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()

    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
            attn_dtype=jnp.bfloat16,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
        )
    )

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "write 10 lines story about why you love EasyDeL"},
    ]
    ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="jax",
        return_dict=True,
        add_generation_prompt=True,
    )

    print("Start Generation Process.")
    for response in inference.generate(**ids):
        ...

    print(
        tokenizer.batch_decode(
            response.sequences[..., response.padded_length :],
            skip_special_tokens=True,
        )
    )
    print(response.tokens_pre_second)


if __name__ == "__main__":
    main()
Workarounds Tried:
- Setting gradient_checkpointing=EasyDeLGradientCheckPointers.NONE fixes the issue.
- Ensuring that only one model (either with or without remat) is instantiated at a time prevents the corruption.
- The issue only occurs when both models exist in memory simultaneously.
Possible Cause:
- nn.remat might be affecting global state shared across models.
- Memory corruption or state retention in flax.nnx affecting subsequent inference.
Additional Notes:
- Would need further debugging into nn.remat handling in flax.nnx.
- Possible scope leakage between checkpointed and non-checkpointed models (see the sketch below).
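For illustration of the scope-leakage hypothesis: the auto_remat helper shown above rebinds __call__ on the module class itself, so if both models are built from the same layer classes, a wrapper installed while constructing one model could also be picked up by the other. A minimal plain-Python sketch of that mechanism (no Flax or EasyDeL involved; the *10 wrapper just stands in for the remat wrapping):

class Layer:
    def __call__(self, x):
        return x + 1

plain = Layer()  # stands in for the model created without checkpointing

# Stand-in for what auto_remat does: rebind __call__ on the class itself.
original_call = Layer.__call__
Layer.__call__ = lambda self, x: original_call(self, x) * 10

wrapped = Layer()  # stands in for the model created with checkpointing

print(plain(1), wrapped(1))  # prints "20 20": the earlier instance is affected too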
Expected Fix: Ensure that gradient checkpointing via nnx.remat does not interfere with models that do not use checkpointing in the same session.
Any update or help on this?
EasyDeL seems to be the only JAX implementation out there that supports GRPO training.
This is a blocking bug for my research. Is anyone else able to reproduce this, or at least confirm that it is actually a Flax-related bug?
Thanks for reporting this @erfanzar. Would it be possible for you to create a test case where nnx.remat fails?
nnx.remat is not doing a whole lot except forwarding the underlying state to JAX.
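A minimal check along these lines would already help (a sketch using a toy nnx.Linear layer and jax.checkpoint_policies.nothing_saveable, nothing EasyDeL-specific):

import jax
import jax.numpy as jnp
from flax import nnx

layer = nnx.Linear(16, 16, rngs=nnx.Rngs(0))
x = jnp.ones((4, 16))

def forward(module, inp):
    return module(inp)

# Same forward pass, wrapped with nnx.remat and the "nothing saveable" policy.
remat_forward = nnx.remat(
    forward,
    prevent_cse=True,
    policy=jax.checkpoint_policies.nothing_saveable,
)

out_plain = forward(layer, x)
out_remat = remat_forward(layer, x)

# If nnx.remat only forwards state to jax.checkpoint, these should be identical.
print(jnp.max(jnp.abs(out_plain - out_remat)))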
Hi @erfanzar ,
I've encountered similar issues with generating junk results when working with Qwen models, although perhaps for slightly different underlying reasons.
In my case, using my own library implementation, I found that the junk results with Qwen seemed related to numerical stability, particularly within the attention mechanism when using bfloat16. It appeared that some operations were becoming unstable at lower precision.
My workaround was to force key parts of the attention calculation to run in float32 precision, even if the rest of the model used bfloat16. This stabilized the computation and resolved the junk output issue for me.
Here's the relevant snippet from my attention implementation:
import jax.numpy as jnp
import jax.nn as nn
import math
# Assuming query_states, key_states, value_states are initially bf16/fp16
# And attn_mask is prepared appropriately
# Force QK dot product and scaling to float32
attn_weights = (query_states.astype(jnp.float32) @ key_states.swapaxes(-2, -1).astype(jnp.float32)) / math.sqrt(self.head_dim)
# Apply mask in float32
if attn_mask is not None:
    causal_mask = attn_mask  # Assuming mask is already correctly broadcastable
    # Ensure mask is also float32
    attn_weights = attn_weights.astype(jnp.float32) + causal_mask.astype(jnp.float32)
# Softmax in float32
attn_weights = nn.softmax(attn_weights.astype(jnp.float32), axis=-1)
# Weight values in float32
attn_output = attn_weights @ value_states.astype(jnp.float32)
# Cast final output back to the original lower precision if needed
attn_output = attn_output.astype(jnp.bfloat16) # Or original dtype
While your issue seems directly linked to nnx.remat potentially causing state interference, it's possible that the recomputation process within remat might be exacerbating underlying numerical precision sensitivities in the Qwen attention layers.
Perhaps you could try modifying the attention mechanism within your EasyDeL setup (if possible, or by modifying the source temporarily) to use float32 for the intermediate calculations as shown above, even when gradient_checkpointing is enabled. It's a bit of a long shot since the root causes might be different, but forcing higher precision in sensitive areas like attention sometimes helps stabilize things unexpectedly.
Hope this potentially offers another avenue to investigate or might provide some relief!
Thanks @demon2036 for sharing this, but we are using FA3 kernels for GPUs and Splash for TPUs, so attention dtype isn't really the issue. Thanks anyway, I'll double-check.
@erfanzar Have you had a chance to try decoding using greedy sampling?
If you use greedy sampling, is the first token generated correct?
If the first token is correct, but the subsequent tokens are incorrect, that might strongly suggest an underlying numerical precision issue. In that scenario, perhaps it could be worth investigating if forcing operations like RMSNorm (and maybe others) to compute in fp32 makes a difference?
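For reference, a hedged sketch of what forcing RMSNorm into fp32 could look like (illustrative only, not the EasyDeL or Qwen implementation; weight is the learned scale parameter):

import jax.numpy as jnp

def rms_norm_fp32(x, weight, eps=1e-6):
    # Up-cast to float32 for the variance estimate and normalization, then cast
    # back, so bf16 inputs don't lose precision inside the norm.
    orig_dtype = x.dtype
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 / jnp.sqrt(variance + eps)
    return (weight.astype(jnp.float32) * normed).astype(orig_dtype)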
Is there any update on this issue?
I'm also finding suspicious behavior with nnx.remat. In particular, it scales the per-layer gradients by a factor of about 1.3-1.4, which compounds to a factor of close to 10 after several layers and destroys training.
I just switched from Linen to NNX, and the Linen remat/checkpoint does not have this problem.
@twmitchel can you please provide a small reproducer of the issue? I can take a look. Thanks!
@erfanzar your reproducer code does not work with the latest flax (0.12.0). I have the following error:
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
...
ValueError: Mesh requires the ndim of its first argument (`devices`) to equal the length of its second argument (`axis_names`), but got devices.ndim == 4 and len(axis_names) == 5.
I also assumed that nn here means import flax.nnx as nn. Also, how is auto_remat used in main()?
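For what it's worth, the error reports devices.ndim == 4 versus len(axis_names) == 5, so the four-element sharding_axis_dims in the reproducer may simply need a fifth entry with newer EasyDeL versions. This is a guess, not verified against the EasyDeL docs:

# Hedged guess: supply five dims to match the five mesh axis names in the error.
sharding_axis_dims = (1, 1, 1, 1, -1)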