Skip to content

Commit c84510d

Browse files
author
Flax Authors
committed
Merge pull request #3936 from google:nnx-stabilize
PiperOrigin-RevId: 636951986
2 parents b1cb952 + 67fa051 commit c84510d

File tree

129 files changed

+451
-385
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

129 files changed

+451
-385
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ vNext
8282

8383
0.8.0
8484
-----
85-
- Added [NNX](https://github.com/google/flax/tree/main/flax/experimental/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
85+
- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
8686
- Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier.
8787
- Add copy() method to Module. This is a user-friendly version of the internal clone() method with better
8888
defaults for common use cases.

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
| [**What does Flax look like?**](#what-does-flax-look-like)
1313
| [**Documentation**](https://flax.readthedocs.io/)
1414

15+
**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API!
16+
1517
This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).**
1618

1719
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

docs/api_reference/flax.experimental.nnx/nn/stochastic.rst

Lines changed: 0 additions & 8 deletions
This file was deleted.

docs/api_reference/flax.experimental.nnx/training/optimizer.rst

Lines changed: 0 additions & 8 deletions
This file was deleted.

docs/api_reference/flax.experimental.nnx/visualization.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

docs/api_reference/flax.experimental.nnx/graph.rst renamed to docs/api_reference/flax.nnx/graph.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
graph
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77

88
.. autofunction:: split

docs/api_reference/flax.experimental.nnx/helpers.rst renamed to docs/api_reference/flax.nnx/helpers.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
helpers
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: Dict
88
:members:

docs/api_reference/flax.experimental.nnx/index.rst renamed to docs/api_reference/flax.nnx/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
flax.experimental.nnx
1+
flax.nnx
22
------------------------
33

4-
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
4+
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.
55

66
.. toctree::
77
:maxdepth: 3

docs/api_reference/flax.experimental.nnx/module.rst renamed to docs/api_reference/flax.nnx/module.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: Module
88
:members:

docs/api_reference/flax.experimental.nnx/nn/activations.rst renamed to docs/api_reference/flax.nnx/nn/activations.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Activation functions
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autofunction:: celu
88
.. autofunction:: elu

docs/api_reference/flax.experimental.nnx/nn/attention.rst renamed to docs/api_reference/flax.nnx/nn/attention.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Attention
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: MultiHeadAttention
88
:members:

docs/api_reference/flax.experimental.nnx/nn/index.rst renamed to docs/api_reference/flax.nnx/nn/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
nn
22
----------------------------
33

4-
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
4+
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.
55

66
.. toctree::
77
:maxdepth: 3

docs/api_reference/flax.experimental.nnx/nn/initializers.rst renamed to docs/api_reference/flax.nnx/nn/initializers.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Initializers
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx.initializers
5-
.. currentmodule:: flax.experimental.nnx.initializers
4+
.. automodule:: flax.nnx.initializers
5+
.. currentmodule:: flax.nnx.initializers
66

77
.. autofunction:: constant
88
.. autofunction:: delta_orthogonal

docs/api_reference/flax.experimental.nnx/nn/linear.rst renamed to docs/api_reference/flax.nnx/nn/linear.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ Linear
33

44
NNX linear layer classes.
55

6-
.. automodule:: flax.experimental.nnx
7-
.. currentmodule:: flax.experimental.nnx
6+
.. automodule:: flax.nnx
7+
.. currentmodule:: flax.nnx
88

99
.. autoclass:: Conv
1010
:members:

docs/api_reference/flax.experimental.nnx/nn/normalization.rst renamed to docs/api_reference/flax.nnx/nn/normalization.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Normalization
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: BatchNorm
88
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Stochastic
2+
------------------------
3+
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
6+
7+
.. autoclass:: Dropout
8+
:members:

docs/api_reference/flax.experimental.nnx/rnglib.rst renamed to docs/api_reference/flax.nnx/rnglib.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
rnglib
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: Rngs
88
:members:

docs/api_reference/flax.experimental.nnx/spmd.rst renamed to docs/api_reference/flax.nnx/spmd.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
spmd
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autofunction:: get_partition_spec
88
.. autofunction:: get_named_sharding

docs/api_reference/flax.experimental.nnx/training/index.rst renamed to docs/api_reference/flax.nnx/training/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
training
22
----------------------------
33

4-
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
4+
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.
55

66
.. toctree::
77
:maxdepth: 3

docs/api_reference/flax.experimental.nnx/training/metrics.rst renamed to docs/api_reference/flax.nnx/training/metrics.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Metrics
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx.metrics
5-
.. currentmodule:: flax.experimental.nnx.metrics
4+
.. automodule:: flax.nnx.metrics
5+
.. currentmodule:: flax.nnx.metrics
66

77
.. autoclass:: Metric
88
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Optimizer
2+
------------------------
3+
4+
.. automodule:: flax.nnx.optimizer
5+
.. currentmodule:: flax.nnx.optimizer
6+
7+
.. autoclass:: Optimizer
8+
:members:

docs/api_reference/flax.experimental.nnx/transforms.rst renamed to docs/api_reference/flax.nnx/transforms.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
transforms
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: JIT
88
:members:

docs/api_reference/flax.experimental.nnx/variables.rst renamed to docs/api_reference/flax.nnx/variables.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
variables
22
------------------------
33

4-
.. automodule:: flax.experimental.nnx
5-
.. currentmodule:: flax.experimental.nnx
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
66

77
.. autoclass:: BatchStat
88
:members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
visualization
2+
------------------------
3+
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
6+
7+
.. autofunction:: display

docs/api_reference/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ API Reference
88
flax.core.frozen_dict
99
flax.cursor
1010
flax.errors
11-
flax.experimental.nnx/index
11+
flax.nnx/index
1212
flax.jax_utils
1313
flax.linen/index
1414
flax.serialization

docs/conf.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@
110110

111111
html_extra_path = ['robots.txt']
112112

113+
# href with no underline and white bold text color
114+
announcement = """
115+
<a
116+
href="https://flax.readthedocs.io/en/latest/nnx/index.html"
117+
style="text-decoration: none; color: white;"
118+
>
119+
📣 Check out the new <b>NNX</b> API!
120+
</a>
121+
"""
122+
113123
html_theme_options = {
114124
'repository_url': 'https://github.com/google/flax',
115125
'use_repository_button': True, # add a 'link to repository' button
@@ -122,6 +132,7 @@
122132
},
123133
'prev_next_buttons_location': None,
124134
'show_navbar_depth': 1,
135+
'announcement': announcement,
125136
}
126137

127138
# -- Options for myst ----------------------------------------------
@@ -135,7 +146,7 @@
135146
nb_execution_excludepatterns = [
136147
'quick_start.ipynb', # <-- times out
137148
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
138-
'flax/experimental/nnx', # exclude nnx
149+
'flax/nnx', # exclude nnx
139150
]
140151
# raise exceptions on execution so CI can catch errors
141152
nb_execution_allow_errors = False
@@ -151,7 +162,7 @@
151162
doctest_global_setup = """
152163
import jax
153164
import jax.numpy as jnp
154-
from flax.experimental import nnx
165+
from flax import nnx
155166
156167
import logging as slog
157168
from absl import logging as alog

docs/experimental/index.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

docs/guides/flax_fundamentals/flax_basics.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@
951951
"source": [
952952
"### Exporting to Tensorflow's SavedModel with jax2tf\n",
953953
"\n",
954-
"JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
954+
"JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
955955
]
956956
}
957957
],

docs/guides/flax_fundamentals/flax_basics.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C
469469

470470
### Exporting to Tensorflow's SavedModel with jax2tf
471471

472-
JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
472+
JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.

docs/index.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ both in the open source community
2828
(like `Hugging Face <https://huggingface.co/flax-community>`__)
2929
and at Google
3030
(like
31-
`PaLM <https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html>`__,
31+
`Gemini <https://deepmind.google/technologies/gemini>`__,
3232
`Imagen <https://imagen.research.google>`__,
3333
`Scenic <https://github.com/google-research/scenic/>`__,
3434
and `Big Vision <https://github.com/google-research/big_vision>`__).
@@ -309,6 +309,8 @@ Notable examples in Flax include:
309309

310310

311311

312+
.. role:: bold
313+
:class: bold
312314

313315
.. toctree::
314316
:hidden:
@@ -325,4 +327,4 @@ Notable examples in Flax include:
325327
contributing
326328
experimental
327329
api_reference/index
328-
experimental/index
330+
NNX <nnx/index>

0 commit comments

Comments
 (0)