Commit a0622b0

Merge pull request #4141 from google:nnx-landing-page
author: Flax Authors (committed)
PiperOrigin-RevId: 668020773
2 parents: 3a9d833 + 4e17aa1

File tree: 8 files changed, +31 -31 lines changed

docs/nnx/index.rst
Lines changed: 10 additions & 5 deletions

@@ -1,13 +1,18 @@
 
 NNX
 ========
+.. div:: sd-text-left sd-font-italic
 
+   **N**\ eural **N**\ etworks for JA\ **X**
 
-NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best
-development experience, so building and experimenting with neural networks is easy and
-intuitive. It achieves this by embracing Python's object-oriented model and making it
-compatible with JAX transforms, resulting in code that is easy to inspect, debug, and
-analyze.
+
+----
+
+NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
+and analyze neural networks in JAX. It achieves this by adding first class support
+for Python reference semantics, allowing users to express their models using regular
+Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler
+and more user-friendly experience.
 
 Features
 ^^^^^^^^^

docs/nnx/nnx_basics.ipynb
Lines changed: 7 additions & 10 deletions

@@ -6,15 +6,12 @@
   "source": [
    "# NNX Basics\n",
    "\n",
-   "NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best \n",
-   "development experience, so building and experimenting with neural networks is easy and\n",
-   "intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), \n",
-   "enabling reference sharing and mutability. This design allows your models to resemble \n",
-   "familiar Python object-oriented code, particularly appealing to users of frameworks\n",
-   "like PyTorch.\n",
-   "\n",
-   "Despite its simplified implementation, NNX supports the same powerful design patterns \n",
-   "that have allowed Linen to scale effectively to large codebases."
+   "NNX is a new Flax API that is designed to make it easier to create, inspect, debug,\n",
+   "and analyze neural networks in JAX. It achieves this by adding first class support\n",
+   "for Python reference semantics, allowing users to express their models using regular\n",
+   "Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n",
+   "sharing and mutability. This design should make PyTorch or Keras users feel at\n",
+   "home."
   ]
  },
 {
@@ -68,7 +65,7 @@
   }
  ],
  "source": [
-  "! pip install -U flax treescope"
+  "# ! pip install -U flax treescope"
  ]
 },
 {

docs/nnx/nnx_basics.md
Lines changed: 7 additions & 10 deletions

@@ -10,20 +10,17 @@ jupytext:
 
 # NNX Basics
 
-NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best
-development experience, so building and experimenting with neural networks is easy and
-intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees),
-enabling reference sharing and mutability. This design allows your models to resemble
-familiar Python object-oriented code, particularly appealing to users of frameworks
-like PyTorch.
-
-Despite its simplified implementation, NNX supports the same powerful design patterns
-that have allowed Linen to scale effectively to large codebases.
+NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
+and analyze neural networks in JAX. It achieves this by adding first class support
+for Python reference semantics, allowing users to express their models using regular
+Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference
+sharing and mutability. This design should make PyTorch or Keras users feel at
+home.
 
 ```{code-cell} ipython3
 :tags: [skip-execution]
 
-! pip install -U flax treescope
+# ! pip install -U flax treescope
 ```
 
 ```{code-cell} ipython3

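The "Python reference semantics" that the rewritten copy leans on can be shown with a short plain-Python sketch (the `Block` class and attribute names here are hypothetical, not part of NNX):

```python
# Minimal sketch of Python reference semantics: two names bound to the
# same mutable object see each other's updates. This reference sharing
# is what NNX's PyGraph representation preserves for model objects.
class Block:
    def __init__(self):
        self.scale = 1.0

shared = Block()
encoder_norm = shared   # same object, not a copy
decoder_norm = shared

decoder_norm.scale = 2.0
print(encoder_norm.scale)  # → 2.0, the change is visible via both names
```

A PyTree-based representation would instead copy leaves across the tree, which is why reference sharing and mutability require the PyGraph model described above.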
flax/nnx/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@
 from .nnx.variables import (
   Param as Param,
   register_variable_name_type_pair as register_variable_name_type_pair,
-)
+)
 # this needs to be imported before optimizer to prevent circular import
 from .nnx.training import optimizer as optimizer
 from .nnx.training.metrics import Metric as Metric

flax/nnx/nnx/rnglib.py
Lines changed: 4 additions & 1 deletion

@@ -437,7 +437,10 @@ def split_rngs_wrapper(*args, **kwargs):
     key = stream()
     backups.append((stream, stream.key.value, stream.count.value))
     stream.key.value = jax.random.split(key, splits)
-    counts_shape = (splits, *stream.count.shape)
+    if isinstance(splits, int):
+      counts_shape = (splits, *stream.count.shape)
+    else:
+      counts_shape = (*splits, *stream.count.shape)
     stream.count.value = jnp.zeros(counts_shape, dtype=jnp.uint32)
 
   return SplitBackups(backups)

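The branch added to `split_rngs` above distinguishes an integer `splits` from a tuple of ints when building the new count shape. A standalone sketch of just that shape computation (a hedged extraction, not the real function, which also resets the stream's key and count state):

```python
# Sketch of the counts_shape logic from the diff above: an int `splits`
# adds one leading axis, while a tuple of ints prepends every axis,
# mirroring how jax.random.split accepts an int or a tuple of ints.
def counts_shape(splits, count_shape=()):
    if isinstance(splits, int):
        return (splits, *count_shape)
    return (*splits, *count_shape)

print(counts_shape(5))        # → (5,)
print(counts_shape((2, 3)))   # → (2, 3)
```

Before this change, a tuple `splits` would have produced a nested shape like `((2, 3),)` instead of `(2, 3)`, which is the bug the commit fixes.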
flax/nnx/nnx/transforms/iteration.py
Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@
 from flax.nnx.nnx.transforms.transforms import resolve_kwargs
 from flax.typing import Leaf, MISSING, Missing, PytreeDeque
 import jax
-from jax._src.tree_util import broadcast_prefix
 import jax.core
 import jax.numpy as jnp
 import jax.stages

flax/nnx/tests/bridge/wrappers_test.py
Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
 
 from absl.testing import absltest
 import flax

flax/nnx/tests/graph_utils_test.py
Lines changed: 2 additions & 2 deletions

@@ -499,8 +499,8 @@ def __init__(self, dout: int, rngs: nnx.Rngs):
     self.rngs = rngs
 
   def __call__(self, x):
-
-    @partial(nnx.vmap, in_axes=(0, None), axis_size=5)
+    @nnx.split_rngs(splits=5)
+    @nnx.vmap(in_axes=(0, None), axis_size=5)
     def vmap_fn(inner, x):
       return inner(x)
 
0 commit comments
