Unable to use jax.debug.breakpoint in nnx jitted scope #4735

Open
puct9 opened this issue Apr 29, 2025 · 3 comments

@puct9

puct9 commented Apr 29, 2025

System information

  • Rocky Linux 9.4
  • flax==0.10.6 jax==0.6.0 jaxlib==0.6.0
  • Python 3.11
  • H200 GPU, 140 GB memory
  • CUDA 12.8

Problem you have encountered:

I tried to run the following snippet:

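import jax
from flax import nnx

@nnx.jit  # Flax NNX wrapper around jax.jit that accepts graph objects such as nnx.Rngs
def fn(rngs):
    jax.debug.breakpoint()  # fails to lower with jax==0.6.0 (see the traceback below)
    return 4

rngs = nnx.Rngs(0)  # RNG state seeded with 0, passed into the jitted function
fn(rngs)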

What you expected to happen:

I expected execution to pause at the breakpoint and drop into the interactive debugger.

Error

Traceback (most recent call last):
  File "/home/miniforge3/envs/rl/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 1244, in lower_jaxpr_to_module
    if not ctx.module.operation.verify():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: "jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))): 'stablehlo.custom_call' op incorrect layout dense<> : tensor<0xindex> for type 'tensor<2xui32>', layout must be a permutation of [0, 1)
 note: "jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))): see current operation: "stablehlo.custom_call"(%arg0, %0) <{api_version = 1 : i32, backend_config = "", call_target_name = "xla_ffi_python_gpu_callback", called_computations = [], has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []}> {mhlo.backend_config = {index = 0 : ui64}, mhlo.sharding = "{maximal device=0}"} : (tensor<ui32>, tensor<2xui32>) -> ()

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/anon/issue.py", line 16, in <module>
    fn(rngs)
  File "/home/miniforge3/envs/rl/lib/python3.11/site-packages/flax/nnx/transforms/compilation.py", line 350, in jit_wrapper
    pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
                                               ^^^^^^^^^^
ValueError: Cannot lower jaxpr with verifier errors:
        'stablehlo.custom_call' op incorrect layout dense<> : tensor<0xindex> for type 'tensor<2xui32>', layout must be a permutation of [0, 1)
                at loc("jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))))
        see current operation: "stablehlo.custom_call"(%arg0, %0) <{api_version = 1 : i32, backend_config = "", call_target_name = "xla_ffi_python_gpu_callback", called_computations = [], has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []}> {mhlo.backend_config = {index = 0 : ui64}, mhlo.sharding = "{maximal device=0}"} : (tensor<ui32>, tensor<2xui32>) -> ()
                at loc("jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))))

Steps to reproduce:

Run the code snippet.

@puct9
Author

puct9 commented Apr 29, 2025

After some testing, I suspect the issue is caused by the JAX version. Downgrading to the previous release, jax==0.5.3, makes the issue go away.

This is not comprehensive, but here are the combinations I tested:

| Compute | JAX version | Flax version | Python version | Result |
|---------|-------------|--------------|----------------|--------|
| CPU     | 0.5.3 | 0.10.6 | 3.10, 3.11 | ok |
| CUDA 12 | 0.5.3 | 0.10.6 | 3.11 | ok |
| CPU     | 0.6.0 | 0.10.6 | 3.11 | bad |
| CUDA 12 | 0.6.0 | 0.10.6 | 3.11 | bad (tested on 2 sets of hardware) |
| CUDA 12 | 0.6.0 (only option for this Flax version) | commit 63688067de85261824a0c31f5ff5285f6231c864 | 3.11 | bad |
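Until this is fixed, one option is a fail-fast check on the installed JAX version before relying on jax.debug.breakpoint under nnx.jit. The sketch below is a hypothetical helper (`breakpoint_status` is not part of JAX or Flax) that encodes only the two releases tested in the table above; any other version is reported as "untested" rather than assumed safe.

```python
import re

# Hypothetical helper, not part of JAX or Flax.
# Only these two JAX releases were tested in the table above.
KNOWN_OK = {(0, 5, 3)}
KNOWN_BAD = {(0, 6, 0)}


def breakpoint_status(version: str) -> str:
    """Classify a JAX version string as 'ok', 'bad', or 'untested'.

    Only the leading MAJOR.MINOR.PATCH is compared, so a suffix such as
    '.dev20250429' is ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    key = tuple(int(part) for part in match.groups())
    if key in KNOWN_BAD:
        return "bad"
    if key in KNOWN_OK:
        return "ok"
    return "untested"
```

In a project this could be called as `breakpoint_status(jax.__version__)` before enabling debugging; the workaround reported here was simply to pin jax==0.5.3.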

@cgarciae
Collaborator

Thanks for reporting this! It does sound like a JAX bug.

@puct9
Author

puct9 commented May 4, 2025

Should I raise this in the JAX channel?
