Problem training model with jax.jvp in forward pass #4636

Open
Thomas-Christie opened this issue Mar 18, 2025 · 2 comments

@Thomas-Christie

Hi - thanks for the library! I'm currently running into some issues when trying to train a "linearised" NN model. In my setup I take a standard NN with weights $\theta^*$, denoted $f(\cdot;\theta^*)$, and form a "linearised" version of it by taking a first-order Taylor expansion about $\theta^*$. Mathematically:

$f^\text{lin}_{\theta^*}(x;\theta) = f(x;\theta^*) + J_{\theta} f(x;\theta^*) (\theta - \theta^*) $
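
The (primal, tangent) pair returned by jax.jvp gives exactly these two terms, which is why I'm computing the forward pass with it. As a minimal pure-JAX sketch of the idea (the toy function f and the names here are just for illustration, not my actual model):

import jax
import jax.numpy as jnp

# Toy stand-in for the network: f(x; theta) with theta a pytree of arrays.
def f(theta, x):
    return jnp.tanh(x @ theta["w"]) @ theta["v"]

theta_star = {"w": jnp.ones((1, 3)), "v": jnp.ones((3, 1))}  # fixed weights
tau = jax.tree.map(jnp.zeros_like, theta_star)               # theta - theta_star
x = jnp.ones((1, 1))

# jax.jvp returns (f(x; theta_star), J_theta f(x; theta_star) @ tau),
# i.e. the two terms of the first-order Taylor expansion above.
primal, tangent = jax.jvp(lambda th: f(th, x), (theta_star,), (tau,))
f_lin = primal + tangent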

I then want to train this "linearised" model, keeping the original weights $\theta^*$ fixed and only updating the weights $\theta$ (or equivalently the "weight update" $\tau = \theta - \theta^*$). To freeze the weights $\theta^*$ I have adopted the approach mentioned in #4167. I'm able to perform a forward pass with the model fine, but as soon as I try to train it, I run into issues. An example is as follows:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Original model
class RegressionModel(nnx.Module):

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear1 = nnx.Linear(1, 20, rngs=rngs)
        self.linear2 = nnx.Linear(20, 20, rngs=rngs)
        self.linear3 = nnx.Linear(20, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.tanh(self.linear1(x))
        x = nnx.tanh(self.linear2(x))
        x = self.linear3(x)
        return x


class LinearisedModel(nnx.Module):

    def __init__(self, original_model: nnx.Module) -> None:
        self.graph_def, self.original_weights = nnx.split(original_model)
        self.weight_update = jax.tree.map(
            nnx.Param, jax.tree.map(jnp.zeros_like, self.original_weights)
        )  # Want to be trainable, corresponds to "tau" above
        self.original_weights = jax.tree.map(
            nnx.Param, self.original_weights
        )  # Want to be not trainable

    def __call__(self, x):
        def _model_fn(weights):
            return nnx.call((self.graph_def, weights))(x)[0]

        original_pred, upd = jax.jvp(
            _model_fn,
            (self.original_weights,),
            (self.weight_update,),
        )
        return original_pred + upd


original_model = RegressionModel(rngs=nnx.Rngs(42))
linearised_model = LinearisedModel(original_model)

trainable_params = nnx.All(nnx.Param, nnx.PathContains("weight_update"))
optimizer = nnx.Optimizer(
    linearised_model,
    tx=optax.adamw(3e-4),
    wrt=trainable_params,
)


def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return optax.squared_error(pred, y).mean()

    diff_state = nnx.DiffState(0, trainable_params)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(grads)


x = jnp.ones((1, 1))
y = jnp.ones((1, 1))
train_step(linearised_model, optimizer, x, y)

The error I'm getting is:

TypeError: Argument 'Param(
  value=Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
    primal = Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.]], dtype=float32)
    tangent = Traced<ShapedArray(float32[1,20])>with<JaxprTrace(level=1/0)> with
      pval = (ShapedArray(float32[1,20]), None)
      recipe = LambdaBinding()
)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.

From the stack trace, the error seems to stem from the jax.jvp() call.

Can you see what I'm doing wrong? I'm happy to provide more details if necessary. A PyTorch version of what I'm trying to achieve can be found here. Thanks in advance!

@cgarciae
Collaborator

Hey! The issue is that jax.jvp doesn't handle nnx.Variable instances (like Param) correctly. To fix it, call nnx.state on the weights before passing them in; I tested that this runs on your code:

original_pred, upd = jax.jvp(
  _model_fn,
  (nnx.state(self.original_weights),),
  (nnx.state(self.weight_update),),
)
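
In context, the __call__ from the original example then looks roughly like this (unchanged apart from wrapping the two arguments in nnx.state):

def __call__(self, x):
    def _model_fn(weights):
        return nnx.call((self.graph_def, weights))(x)[0]

    # Extract plain states from the Param-wrapped pytrees before handing
    # them to jax.jvp, which only accepts valid JAX types as primals/tangents.
    original_pred, upd = jax.jvp(
        _model_fn,
        (nnx.state(self.original_weights),),
        (nnx.state(self.weight_update),),
    )
    return original_pred + upd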

I think we just need to add nnx.jvp to make this cleaner; I've seen a couple of users who need it already.

@Thomas-Christie
Author

Great, thanks, that worked! And yeah, I think having nnx.jvp would probably be a bit more intuitive - I actually searched to see if this existed whilst debugging.
