Description
System information
- Flax, JAX, JAXlib versions: (0.10.4)
- Python version: (e.g., 3.10.2)
Problem you have encountered:
When using gradient checkpointing in a `flax.nnx` model (`nnx.remat`), the model generates incorrect or junk results. This happens on both GPU and TPU. If two models are loaded:
- Model 1: Uses gradient checkpointing (`EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`).
- Model 2: Does not use gradient checkpointing.

Both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.
What you expected to happen:
Each model should independently function correctly, and the activation checkpointing (`remat`) should not corrupt inference outputs when applied to a separate model instance.
Steps to reproduce:
A minimal reproducible example is given below. Changing `gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE` to `EasyDeLGradientCheckPointers.NONE` resolves the issue.
Code snippet to reproduce the issue:
```python
import typing as tp

import easydel as ed
import jax
import transformers
from easydel import EasyDeLGradientCheckPointers
from flax import nnx as nn  # assumed alias: EasyDeL's internals use flax.nnx as `nn`
from jax import numpy as jnp

M = tp.TypeVar("M", bound=nn.Module)


# EasyDeL's internal helper, reproduced here for context. Note that it rebinds
# `module.__call__` on the class itself, so the remat wrapper is shared by every
# instance of that class. `get_gradient_checkpoint_policy` and
# `extract_static_parameters` are EasyDeL internals and are not defined here.
def auto_remat(
    *modules: tp.Type[M],
    policy: tp.Union[
        EasyDeLGradientCheckPointers, str
    ] = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
    if policy == EasyDeLGradientCheckPointers.NONE:
        return modules
    if isinstance(policy, str):
        policy = get_gradient_checkpoint_policy(policy)
    outs = ()
    for module in modules:
        assert issubclass(module, nn.Module)
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()
        # Class-level patch: every instance of `module` now dispatches through
        # the remat-wrapped __call__, not only the checkpointed model.
        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        outs += (module,)
    return outs
def main():
    sharding_axis_dims = (1, 1, 1, -1)
    prefill_length = 512
    max_new_tokens = 128
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
    dtype = param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()

    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Switching NOTHING_SAVEABLE to NONE below makes generation correct again.
    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
            attn_dtype=jnp.bfloat16,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
        )
    )

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "write 10 lines story about why you love EasyDeL"},
    ]
    ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="jax",
        return_dict=True,
        add_generation_prompt=True,
    )

    print("Start Generation Process.")
    for response in inference.generate(**ids):
        ...  # consume the stream; `response` holds the final state after the loop

    print(
        tokenizer.batch_decode(
            response.sequences[..., response.padded_length :],
            skip_special_tokens=True,
        )
    )
    print(response.tokens_pre_second)


if __name__ == "__main__":
    main()
```
Workarounds Tried:
- Setting `gradient_checkpointing=EasyDeLGradientCheckPointers.NONE` fixes the issue.
- Ensuring that only one model (either with or without `remat`) is instantiated at a time prevents the corruption.
- The issue only occurs when both models exist in memory simultaneously.
Possible Cause:
- `nn.remat` might be affecting global state shared across models; the class-level `module.__call__` patch in `auto_remat` above would be shared by every instance of the same model class (see the sketch below).
- Memory corruption or state retention in `flax.nnx` affecting subsequent inference.
Additional Notes:
- Would need further debugging into `nn.remat` handling in `flax.nnx`.
- Possible scope leakage between checkpointed and non-checkpointed models.
Expected Fix:
Ensure that gradient checkpointing via `nnx.remat` does not interfere with models that do not use checkpointing in the same session.
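If the class-level rebinding is indeed the problem, one possible direction (an untested sketch; `auto_remat_non_mutating` is a hypothetical name that mirrors the `auto_remat` snippet above) would be to return remat-wrapped subclasses so the original class, and any non-checkpointed model built from it, stays untouched:

```python
import typing as tp

from flax import nnx as nn  # assumed alias, matching the snippet above

M = tp.TypeVar("M", bound=nn.Module)


def auto_remat_non_mutating(
    *modules: tp.Type[M],
    prevent_cse: bool = True,
    static_argnums: tp.Tuple[int, ...] = (),
    policy=None,
) -> tp.Tuple[tp.Type[M], ...]:
    """Wrap __call__ on a fresh subclass instead of mutating the shared class."""
    outs = ()
    for module in modules:
        wrapped_call = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        # Instances of `module` created without checkpointing keep the original
        # __call__; only the returned subclass dispatches through remat.
        remat_cls = type(f"Remat{module.__name__}", (module,), {"__call__": wrapped_call})
        outs += (remat_cls,)
    return outs
```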