Skip to content

Commit 8292d9c

Browse files
author
Flax Authors
committed
Merge pull request #4346 from google:update-state-docstrings
PiperOrigin-RevId: 691845679
2 parents b8bdafb + 13b4077 commit 8292d9c

File tree

4 files changed

+25
-19
lines changed

4 files changed

+25
-19
lines changed

docs_nnx/api_reference/flax.nnx/graph.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ graph
1010
.. autofunction:: update
1111
.. autofunction:: pop
1212
.. autofunction:: state
13+
.. autofunction:: variables
1314
.. autofunction:: graph
1415
.. autofunction:: graphdef
1516
.. autofunction:: iter_graph

flax/nnx/graph.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,25 @@ def variables(
13511351
node,
13521352
*filters: filterlib.Filter,
13531353
) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]:
1354+
"""Similar to :func:`state` but returns the current :class:`Variable` objects instead
1355+
of new :class:`VariableState` instances.
1356+
1357+
Example::
1358+
1359+
>>> from flax import nnx
1360+
...
1361+
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
1362+
>>> params = nnx.variables(model, nnx.Param)
1363+
...
1364+
>>> assert params['kernel'] is model.kernel
1365+
>>> assert params['bias'] is model.bias
1366+
1367+
Args:
1368+
node: A graph node object.
1369+
*filters: One or more :class:`Variable` objects to filter by.
1370+
Returns:
1371+
One or more :class:`State` mappings containing the :class:`Variable` objects.
1372+
"""
13541373
num_filters = len(filters)
13551374
if num_filters == 0:
13561375
filters = (..., ...)

flax/nnx/statelib.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,10 @@ def __treescope_repr__(self, path, subtree_renderer):
5656

5757

5858
class State(MutableMapping[K, V], reprlib.Representable):
59-
"""A pytree-like structure that contains a ``Mapping`` from strings or
60-
integers to leaves. A valid leaf type is either :class:`Variable`,
61-
``jax.Array``, ``numpy.ndarray`` or nested ``State``'s. A ``State``
62-
can be generated by either calling :func:`split` or :func:`state` on
63-
the :class:`Module`."""
59+
"""A pytree-like structure that contains a ``Mapping`` from hashable and
60+
comparable keys to leaves. Leaves can be of any type but :class:`VariableState`
61+
and :class:`Variable` are the most common.
62+
"""
6463

6564
def __init__(
6665
self,

uv.lock

Lines changed: 1 addition & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)