Gradient Checkpointing causes model to compute junk results (NNX) #4626

Open

Description

System information

  • Flax, JAX, JAXlib versions: 0.10.4
  • Python version: not specified

Problem you have encountered:

When using gradient checkpointing in a flax.nnx model (via nnx.remat), the model generates incorrect or junk results. This happens on both GPU and TPU. The failure appears when two models are loaded in the same process:

  1. Model 1: Uses gradient checkpointing (EasyDeLGradientCheckPointers.NOTHING_SAVEABLE).
  2. Model 2: Does not use gradient checkpointing.

Both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.

What you expected to happen:

Each model should function correctly on its own; applying activation checkpointing (remat) to one model instance should not corrupt the inference outputs of another.

Logs, error messages, etc:

None provided.

Steps to reproduce:

A minimal reproducible example is given below. Changing gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE to ed.EasyDeLGradientCheckPointers.NONE resolves the issue.

Code snippet to reproduce the issue:

import typing as tp

import easydel as ed
import jax
import transformers
from flax import nnx as nn  # assumption: `nn` in auto_remat refers to flax.nnx
from jax import numpy as jnp

# auto_remat below is reproduced from EasyDeL internals. The helpers
# get_gradient_checkpoint_policy and extract_static_parameters also live
# inside EasyDeL and were not imported in the original snippet; the enum is
# re-exported at the package top level:
EasyDeLGradientCheckPointers = ed.EasyDeLGradientCheckPointers
M = tp.TypeVar("M", bound=nn.Module)

def auto_remat(
    *modules: tp.Type[M],
    policy: tp.Union[
        EasyDeLGradientCheckPointers, str
    ] = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
    if policy == EasyDeLGradientCheckPointers.NONE:
        return modules
    if isinstance(policy, str):
        # EasyDeL internal: maps a policy name to a jax.checkpoint policy.
        policy = get_gradient_checkpoint_policy(policy)
    outs = ()
    for module in modules:
        assert issubclass(module, nn.Module)
        # EasyDeL internal: returns the static __call__ argument indices.
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()

        # Rebinds __call__ on the class object itself, not on an instance.
        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        outs += (module,)
    return outs
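# NOTE (suspected leak; see "Possible Cause" below): the assignment above
# mutates the shared class object, so every model later built from these
# module classes, with or without checkpointing requested, dispatches
# through the rematted __call__.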

def main():
    sharding_axis_dims = (1, 1, 1, -1)
    prefill_length = 512
    max_new_tokens = 128
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"

    dtype = param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()
    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
            attn_dtype=jnp.bfloat16,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
        )
    )

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "write 10 lines story about why you love EasyDeL"},
    ]

    ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="jax",
        return_dict=True,
        add_generation_prompt=True,
    )
    print("Start Generation Process.")
    for response in inference.generate(**ids):
        ...  # consume the streamed chunks; `response` keeps the last one
    print(
        tokenizer.batch_decode(
            response.sequences[..., response.padded_length :],
            skip_special_tokens=True,
        )
    )
    print(response.tokens_pre_second)

if __name__ == "__main__":
    main()

Workarounds Tried:

  • Setting gradient_checkpointing=EasyDeLGradientCheckPointers.NONE fixes the issue.
  • Ensuring that only one model (either with or without remat) is instantiated at a time prevents the corruption.
  • The issue only occurs when both models exist in memory simultaneously.

Possible Cause:

  • auto_remat applies nn.remat by rebinding module.__call__ on the class object itself, so the patch is shared by every model built from that class (see the sketch below).
  • Memory corruption or state retention in flax.nnx affecting subsequent inference.
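
The class-mutation hypothesis can be checked outside EasyDeL. Below is a minimal, hypothetical sketch (Block and all shapes are invented; it assumes nnx.remat's wrapper binds as an ordinary method, which is the same assumption auto_remat makes): it patches __call__ on the class the way auto_remat does, and an instance created before the patch is affected too, because Python resolves plain(x) through type(plain).__call__.

import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

plain = Block(nnx.Rngs(0))  # "Model 2": created without checkpointing

# "Model 1" setup: mirrors auto_remat's `module.__call__ = nn.remat(...)`.
Block.__call__ = nnx.remat(Block.__call__)
checkpointed = Block(nnx.Rngs(1))

x = jnp.ones((1, 4))
# Both instances now dispatch through the rematted __call__, because the
# patch lives on the shared class object, not on either instance.
print(plain(x))
print(checkpointed(x))

If this reproduces, the corruption does not require remat itself to be wrong; it is enough that both models unexpectedly share the patched __call__.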

Additional Notes:

  • Needs further debugging of nn.remat handling in flax.nnx.
  • Possible scope leakage between checkpointed and non-checkpointed models.

Expected Fix: Ensure that gradient checkpointing via nnx.remat does not interfere with models that do not use checkpointing in the same session.
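
If the class mutation is indeed the cause, one possible direction (a hypothetical sketch, not EasyDeL's actual API; auto_remat_isolated is an invented name) is to return a fresh subclass that overrides __call__, leaving the original class untouched for non-checkpointed models:

import typing as tp

from flax import nnx as nn

M = tp.TypeVar("M", bound=nn.Module)

def auto_remat_isolated(module: tp.Type[M], **remat_kwargs) -> tp.Type[M]:
    # Build the rematted __call__ once, then attach it to a new subclass
    # instead of rebinding it on `module` itself.
    rematted_call = nn.remat(module.__call__, **remat_kwargs)
    return tp.cast(
        tp.Type[M],
        type(f"Remat{module.__name__}", (module,), {"__call__": rematted_call}),
    )

Checkpointed models would then be built from the returned subclass, while models created from the original class keep the unpatched __call__.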
