How and when to use nnx.fori_loop and nnx.scan to stack layers? #4726
Unanswered
lbergmann1 asked this question in Q&A
Replies: 1 comment
Hi @lbergmann1, The main use case for |
Hi everyone,
I am currently trying to optimize my JAX/NNX code and I am wondering how to correctly use `nnx.scan` and `nnx.fori_loop` to stack layers.
Here is some example code to create a simple `MLP`:
Here are my questions:
1. When should one use `nnx.fori_loop` or `nnx.scan`, especially depending on the number of iterations? If I understand the JAX documentation correctly, these transforms should reduce compilation times for JIT-compiled functions compared to a Python for-loop, especially for loops with many iterations. But are there other advantages (performance, memory usage, ...)?
2. `nnx.fori_loop` crashes, because the index variable is a tracer. The NNX documentation says that `nnx.Module`s can be composed using `list`s. Is it not intended to use `nnx.fori_loop` to iterate over `nnx.Module`s, or am I missing something?
3. `nnx.scan` works, but I would expect to get two linear layers (see `self.model`). However, `self.model` is just a single linear layer. So, I assume that the code is not correct? If so, what would be the correct way to use `nnx.scan` in this example?
Many thanks in advance ☺️
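
For readers following the three questions, here is a minimal sketch of layer stacking written with the plain `jax.lax.scan` and `jax.lax.fori_loop` primitives rather than the NNX wrappers. It is not the asker's MLP, whose code is not shown in the post. The names `NUM_LAYERS`, `DIM`, `init_stacked_params`, `layer`, and the three `forward_*` functions are hypothetical, and the sketch assumes every layer has the same shape so that the parameters can be stacked along a leading layer axis.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, DIM = 4, 8  # hypothetical sizes, chosen only for illustration


def init_stacked_params(key):
    # One parameter tree for all layers; axis 0 indexes the layer.
    kernel = jax.random.normal(key, (NUM_LAYERS, DIM, DIM)) / jnp.sqrt(DIM)
    bias = jnp.zeros((NUM_LAYERS, DIM))
    return {"kernel": kernel, "bias": bias}


def layer(params, x):
    # A single dense layer with a tanh activation.
    return jnp.tanh(x @ params["kernel"] + params["bias"])


def forward_python_loop(params, x):
    # The Python loop is unrolled while tracing: NUM_LAYERS copies of the body.
    for i in range(NUM_LAYERS):
        x = layer(jax.tree_util.tree_map(lambda p: p[i], params), x)
    return x


def forward_scan(params, x):
    # scan slices the leading axis of `params` and traces the body once.
    def step(carry, layer_params):
        return layer(layer_params, carry), None

    out, _ = jax.lax.scan(step, x, params)
    return out


def forward_fori(params, x):
    # `i` is a tracer here. Indexing a stacked array with it is allowed;
    # indexing a Python list with it is not.
    def body(i, carry):
        layer_params = jax.tree_util.tree_map(lambda p: p[i], params)
        return layer(layer_params, carry)

    return jax.lax.fori_loop(0, NUM_LAYERS, body, x)


params = init_stacked_params(jax.random.PRNGKey(0))
x = jnp.ones((2, DIM))
y_loop = jax.jit(forward_python_loop)(params, x)
y_scan = jax.jit(forward_scan)(params, x)
y_fori = jax.jit(forward_fori)(params, x)
```

All three functions compute the same result, but `forward_scan` and `forward_fori` trace the layer body once instead of once per layer, which is where the compile-time saving comes from. A Python `list` of layers cannot be indexed with a traced integer, which is consistent with the crash described in the second question. With stacked parameters there is a single layer definition whose arrays carry a leading layer axis, which may be why `self.model` shows up as one linear layer in the third question; that reading is an assumption, since the original code is not available here.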