Traceback (most recent call last):
  File "/home/miniforge3/envs/rl/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 1244, in lower_jaxpr_to_module
    if not ctx.module.operation.verify():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: "jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))): 'stablehlo.custom_call' op incorrect layout dense<> : tensor<0xindex> for type 'tensor<2xui32>', layout must be a permutation of [0, 1)
note: "jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))): see current operation: "stablehlo.custom_call"(%arg0, %0) <{api_version = 1 : i32, backend_config = "", call_target_name = "xla_ffi_python_gpu_callback", called_computations = [], has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []}> {mhlo.backend_config = {index = 0 : ui64}, mhlo.sharding = "{maximal device=0}"} : (tensor<ui32>, tensor<2xui32>) -> ()
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/anon/issue.py", line 16, in <module>
    fn(rngs)
  File "/home/miniforge3/envs/rl/lib/python3.11/site-packages/flax/nnx/transforms/compilation.py", line 350, in jit_wrapper
    pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
                                               ^^^^^^^^^^
ValueError: Cannot lower jaxpr with verifier errors:
'stablehlo.custom_call' op incorrect layout dense<> : tensor<0xindex> for type 'tensor<2xui32>', layout must be a permutation of [0, 1)
at loc("jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))))
see current operation: "stablehlo.custom_call"(%arg0, %0) <{api_version = 1 : i32, backend_config = "", call_target_name = "xla_ffi_python_gpu_callback", called_computations = [], has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []}> {mhlo.backend_config = {index = 0 : ui64}, mhlo.sharding = "{maximal device=0}"} : (tensor<ui32>, tensor<2xui32>) -> ()
at loc("jit(fn)/jit(main)/debug_callback"(callsite("fn"("/anon/issue.py":11:4 to :26) at "<module>"("/anon/issue.py":16:0 to :8))))
Steps to reproduce:
Run the code snippet.
After some testing, I suspect the issue is caused by the JAX version: downgrading to the next-lower release, jax==0.5.3, makes the issue go away.
Not comprehensive, but I tested a number of combinations:
System information
flax==0.10.6
jax==0.6.0
jaxlib==0.6.0
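The comment in this thread reports that jax==0.6.0 triggers the error while jax==0.5.3 does not. Below is a minimal sketch, not part of the original report, of a pre-flight check that flags which of those two versions is installed before running the reproduction. The function and constant names are hypothetical, it relies only on the standard library's `importlib.metadata`, and any version other than the two named in this thread is reported as untested.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: only the two versions named in this thread are classified.
# 0.6.0 reproduced the verifier error; 0.5.3 did not. Everything else is unknown.
REPORTED_FAILING = {"0.6.0"}
REPORTED_WORKING = {"0.5.3"}


def classify_jax_version(jax_version: str) -> str:
    """Classify a JAX version string against the versions reported in this issue."""
    if jax_version in REPORTED_FAILING:
        return "reported-failing"
    if jax_version in REPORTED_WORKING:
        return "reported-working"
    return "untested"


def installed_jax_status() -> str:
    """Look up the installed 'jax' distribution (if any) and classify it."""
    try:
        return classify_jax_version(version("jax"))
    except PackageNotFoundError:
        return "not-installed"


if __name__ == "__main__":
    status = installed_jax_status()
    print(f"jax status: {status}")
    if status == "reported-failing":
        # Workaround from the comment in this thread: pin jax==0.5.3.
        print("This JAX version was reported to fail; consider pinning jax==0.5.3.")
```

If the check reports `reported-failing`, pinning `jax==0.5.3` is the workaround described in the comment; whether later JAX releases fix the bug is not established by this report.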
Problem you have encountered:
I tried to execute this snippet of code:
What you expected to happen:
I expected execution to stop at the breakpoint.
Error: the full traceback is reproduced at the top of this report.