
Cannot nnx.eval_shape Modules with BatchNorm #4734

LarsKue opened this issue Apr 29, 2025 · 1 comment

LarsKue commented Apr 29, 2025

For very complex models, it is often convenient to infer some feature dimensions automatically from output shapes. nnx.eval_shape facilitates this by letting users compute the output shape of model components without expensive FLOPs or memory allocation.

Unfortunately, any model component that uses BatchNorm cannot be evaluated this way. See the minimal example below.

It would be nice if this could be supported! The obvious work-around is to pass a concrete array, but this can be quite expensive, especially pre-jit.

System information

  • OS: Ubuntu 24.04
  • Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
Name: flax
Version: 0.10.6
Location: /home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, pyyaml, rich, tensorstore, treescope, typing-extensions
Required-by: bnn-sbi
---
Name: jax
Version: 0.6.0
Location: /home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by: bnn-sbi, chex, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.6.0
Location: /home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, jax, optax
  • Python version: 3.11.11
  • GPU/TPU model and memory: -
  • CUDA version (if applicable): -

Problem you have encountered:

See above.

What you expected to happen:

Evaluating the shape of BatchNorm should simply return the input shape. Any model that is (partially) composed of BatchNorm should also not encounter this error.

Logs, error messages, etc:

Full traceback
Traceback (most recent call last):
  File "/home/lars/Documents/code/python/bnn-sbi/playground.py", line 14, in <module>
    mock_output = nnx.eval_shape(model, mock_input)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/flax/nnx/transforms/transforms.py", line 143, in eval_shape
    out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 2840, in eval_shape
    return jit(fun).eval_shape(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 401, in jit_eval_shape
    p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 716, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 739, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 636, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py", line 477, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1441, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 334, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2223, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 288, in _argnums_partial
    return _fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py", line 402, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/flax/nnx/transforms/transforms.py", line 140, in _eval_shape_fn
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/flax/nnx/nn/normalization.py", line 362, in __call__
    self.mean.value = (
    ^^^^^^^^^^^^^^^
  File "/home/lars/Documents/code/python/bnn-sbi/.venv/lib/python3.11/site-packages/flax/nnx/variablelib.py", line 167, in __setattr__
    raise errors.TraceContextError(
flax.errors.TraceContextError: Cannot mutate BatchStat from a different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)

Steps to reproduce:

Minimal example:

import jax
import flax.nnx as nnx

rngs = nnx.Rngs(0)
model = nnx.BatchNorm(2, rngs=rngs)

mock_input = jax.ShapeDtypeStruct((32, 2), "float32")
mock_output = nnx.eval_shape(model, mock_input)
# Raises: flax.errors.TraceContextError: Cannot mutate BatchStat from a different trace level

cgarciae (Collaborator) commented


The issue is that you're passing model as a capture. The solution is to pass model in explicitly as an argument:

mock_output = nnx.eval_shape(lambda m, x: m(x), model, mock_input)
