
Commit e3bcc44

Committed by: Flax Authors
Merge pull request #4499 from google:nnx-improve-module-docs

PiperOrigin-RevId: 718848543
2 parents: b5d4ed8 + bdcc33a

19 files changed: +323 additions, -306 deletions

docs/guides/converting_and_upgrading/haiku_migration_guide.rst

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ and highlight the differences between the two libraries.
 from jax import random
 import optax
 import flax.linen as nn
+import haiku as hk

 Basic Example
 -----------------
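
The new import completes the guide's setup block so that later Haiku snippets can refer to ``hk``. As a quick illustration of why both libraries appear side by side, here is a minimal sketch of the same one-layer model written with each API; it is not the guide's actual example, and the model and shape choices below are assumptions.

```python
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import haiku as hk

# Haiku: module calls must live inside a transformed function.
def haiku_forward(x):
  return hk.Linear(4)(x)

haiku_model = hk.transform(haiku_forward)
x = jnp.ones((1, 3))
hk_params = haiku_model.init(random.PRNGKey(0), x)
y_hk = haiku_model.apply(hk_params, None, x)   # rng=None: no stochastic ops here

# Linen: modules are dataclasses with explicit init/apply.
class FlaxModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)

flax_model = FlaxModel()
variables = flax_model.init(jax.random.key(0), x)
y_fl = flax_model.apply(variables, x)
```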

examples/gemma/transformer_test.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def test_logit_softcap(
     all_outputs = []
     for config in [config_soft_cap, config_no_soft_cap]:
       transformer = transformer_lib.Transformer(
-          config=config, rngs=nnx.Rngs(params=0)
+          config=config, rngs=nnx.Rngs(params=1)
       )
       cache = transformer.init_cache(
           cache_size=cache_size,
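
The only change in this test is the seed of the ``params`` stream. For readers unfamiliar with ``nnx.Rngs``: each named stream seeds a family of PRNG keys, and ``params`` is the one consumed during parameter initialization, so reseeding it changes the transformer's initial weights. A tiny sketch (the ``nnx.Linear`` layer is an illustration, not part of the Gemma test):

```python
import jax.numpy as jnp
from flax import nnx

# Two layers built from different `params` seeds start with different weights.
layer_a = nnx.Linear(3, 4, rngs=nnx.Rngs(params=0))
layer_b = nnx.Linear(3, 4, rngs=nnx.Rngs(params=1))
print(jnp.allclose(layer_a.kernel.value, layer_b.kernel.value))  # False
```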

flax/linen/linear.py

Lines changed: 11 additions & 11 deletions

@@ -1068,20 +1068,20 @@ class Embed(Module):
     >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
     >>> variables = layer.init(jax.random.key(0), indices_input)
     >>> variables
-    {'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ],
-           [-0.11768015, -0.54618824, -0.3789283 ],
-           [ 0.30428642, 0.49511626, 0.01706631],
-           [-0.0982546 , -0.43055868, 0.20654906],
-           [-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}}
+    {'params': {'embedding': Array([[ 0.04396089, -0.9328513 , -0.97328115],
+           [ 0.41147125, 0.66334754, 0.49469155],
+           [ 0.09719624, 0.49861377, 0.49519277],
+           [-0.13316602, 0.6697022 , 0.3710195 ],
+           [-0.5039532 , 0.287319 , 1.4369922 ]], dtype=float32)}}
     >>> # get the first three and last three embeddings
     >>> layer.apply(variables, indices_input)
-    Array([[[-0.28884724, 0.19018005, -0.414205 ],
-            [-0.11768015, -0.54618824, -0.3789283 ],
-            [ 0.30428642, 0.49511626, 0.01706631]],
+    Array([[[ 0.04396089, -0.9328513 , -0.97328115],
+            [ 0.41147125, 0.66334754, 0.49469155],
+            [ 0.09719624, 0.49861377, 0.49519277]],
     <BLANKLINE>
-           [[-0.688412 , -0.46882293, 0.26723292],
-            [-0.0982546 , -0.43055868, 0.20654906],
-            [ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32)
+           [[-0.5039532 , 0.287319 , 1.4369922 ],
+            [-0.13316602, 0.6697022 , 0.3710195 ],
+            [ 0.09719624, 0.49861377, 0.49519277]]], dtype=float32)

   Attributes:
     num_embeddings: number of embeddings / vocab size.
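
Only the expected doctest output changes here; the example code itself is untouched. A runnable reduction of that example, printing shapes rather than exact values since the numbers depend on the default embedding initializer and the PRNG:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Same setup as the docstring: a 5-entry, 3-feature embedding table,
# indexed from the front (0, 1, 2) and the back (-1, -2, -3).
layer = nn.Embed(num_embeddings=5, features=3)
indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
variables = layer.init(jax.random.key(0), indices_input)
print(variables['params']['embedding'].shape)        # (5, 3)
print(layer.apply(variables, indices_input).shape)   # (2, 3, 3)
```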

flax/linen/module.py

Lines changed: 6 additions & 6 deletions

@@ -2684,18 +2684,18 @@ def perturb(
     >>> variables = model.init(jax.random.key(0), x)
     >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
     >>> print(intm_grads['perturbations']['dense3'])
-    [[-1.456924 -0.44332537 0.02422847]
-     [-1.456924 -0.44332537 0.02422847]]
+    [[-0.04684732 0.06573904 -0.3194327 ]
+     [-0.04684732 0.06573904 -0.3194327 ]]

   If perturbations are not passed to ``apply``, ``perturb`` behaves like a no-op
   so you can easily disable the behavior when not needed::

     >>> model.apply(variables, x) # works as expected
-    Array([[-1.0980128 , -0.67961735],
-           [-1.0980128 , -0.67961735]], dtype=float32)
+    Array([[-0.04579116, 0.50412744],
+           [-0.04579116, 0.50412744]], dtype=float32)
     >>> model.apply({'params': variables['params']}, x) # behaves like a no-op
-    Array([[-1.0980128 , -0.67961735],
-           [-1.0980128 , -0.67961735]], dtype=float32)
+    Array([[-0.04579116, 0.50412744],
+           [-0.04579116, 0.50412744]], dtype=float32)
     >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
     >>> 'perturbations' not in intm_grads
     True
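
For context, ``perturb`` stores a zero-valued variable in the ``'perturbations'`` collection and adds it to the intermediate value, so the gradient with respect to that variable equals the gradient of the intermediate itself. A self-contained sketch of the pattern the doctest exercises; the ``Model`` and ``loss`` definitions below are assumptions standing in for the ones defined earlier in the real docstring:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    x = self.perturb('dense3', x)   # zero-valued hook; its grad is the grad of x here
    return nn.Dense(2)(x)

def loss(variables, x, y):
  preds = Model().apply(variables, x)
  return jnp.square(preds - y).mean()

x = jnp.ones((2, 4))
y = jnp.ones((2, 2))
variables = Model().init(jax.random.key(0), x)
intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
print(intm_grads['perturbations']['dense3'].shape)               # (2, 3)

# Without the 'perturbations' collection, perturb is a no-op inside apply:
print(Model().apply({'params': variables['params']}, x).shape)   # (2, 2)
```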

flax/linen/stochastic.py

Lines changed: 2 additions & 2 deletions

@@ -47,9 +47,9 @@ class Dropout(Module):
     >>> x = jnp.ones((1, 3))
     >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
     >>> model.apply(variables, x, train=False) # don't use dropout
-    Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)
+    Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
     >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
-    Array([[ 0. , -1.1856356, -1.0369378, 0. ]], dtype=float32)
+    Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)

   Attributes:
     rate: the dropout probability. (_not_ the keep rate!)
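
Again only the expected arrays change. A minimal sketch of the pattern this doctest exercises; the wrapper module below is an assumption (the real docstring defines its own ``model``), but it shows the two relevant knobs: ``deterministic`` to switch dropout off, and an rng under the ``'dropout'`` key when it is on.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(4)(x)
    return nn.Dropout(rate=0.5, deterministic=not train)(x)

x = jnp.ones((1, 3))
model = Block()
variables = model.init(jax.random.key(0), x, train=False)
print(model.apply(variables, x, train=False))                # deterministic path
print(model.apply(variables, x, train=True,
                  rngs={'dropout': jax.random.key(1)}))       # stochastic path
```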

flax/nnx/nn/attention.py

Lines changed: 1 addition & 1 deletion

@@ -244,7 +244,7 @@ class MultiHeadAttention(Module):
     >>> assert (layer(q) == layer(q, q)).all()
     >>> assert (layer(q) == layer(q, q, q)).all()

-  Attributes:
+  Args:
     num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
       should be divisible by the number of heads.
     in_features: int or tuple with number of input features.
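
The rename from ``Attributes:`` to ``Args:`` makes these entries render as constructor arguments in the generated docs. A small usage sketch consistent with the example above; the feature sizes and ``decode=False`` are assumptions:

```python
import jax.numpy as jnp
from flax import nnx

# num_heads must divide the feature sizes, as the Args entry notes.
layer = nnx.MultiHeadAttention(
    num_heads=4, in_features=8, qkv_features=16, decode=False, rngs=nnx.Rngs(0))
q = jnp.ones((2, 3, 8))
y = layer(q)                         # self-attention: q is reused as k and v
assert (y == layer(q, q, q)).all()
```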

flax/nnx/nn/linear.py

Lines changed: 17 additions & 17 deletions

@@ -119,7 +119,7 @@ class LinearGeneral(Module):
     >>> y.shape
     (16, 4, 5)

-  Attributes:
+  Args:
     in_features: int or tuple with number of input features.
     out_features: int or tuple with number of output features.
     axis: int or tuple with axes to apply the transformation on. For instance,

@@ -301,7 +301,7 @@ class Linear(Module):
     )
     })

-  Attributes:
+  Args:
     in_features: the number of input features.
     out_features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).

@@ -393,7 +393,7 @@ class Einsum(Module):
     >>> y.shape
     (16, 11, 8, 4)

-  Attributes:
+  Args:
     einsum_str: a string to denote the einsum equation. The equation must
       have exactly two operands, the lhs being the input passed in, and
       the rhs being the learnable kernel. Exactly one of ``einsum_str``

@@ -572,7 +572,7 @@ class Conv(Module):
     ...   mask=mask, padding='VALID', rngs=rngs)
     >>> out = layer(x)

-  Attributes:
+  Args:
     in_features: int or tuple with number of input features.
     out_features: int or tuple with number of output features.
     kernel_size: shape of the convolutional kernel. For 1D convolution,

@@ -823,7 +823,7 @@ class ConvTranspose(Module):
     ...   mask=mask, padding='VALID', rngs=rngs)
     >>> out = layer(x)

-  Attributes:
+  Args:
     in_features: int or tuple with number of input features.
     out_features: int or tuple with number of output features.
     kernel_size: shape of the convolutional kernel. For 1D convolution,

@@ -1065,23 +1065,23 @@ class Embed(Module):
     State({
       'embedding': VariableState( # 15 (60 B)
         type=Param,
-        value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
-               [ 0.01070483, 0.27923733, 1.7487359 ],
-               [ 0.59161806, 0.8660184 , 1.2838588 ],
-               [-0.748139 , -0.15856352, 0.06061118],
-               [-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32)
+        value=Array([[ 0.57966787, -0.523274 , -0.43195742],
+               [-0.676289 , -0.50300646, 0.33996582],
+               [ 0.41796115, -0.59212935, 0.95934135],
+               [-1.0917838 , -0.7441663 , 0.07713798],
+               [-0.66570747, 0.13815777, 1.007365 ]], dtype=float32)
       )
     })
     >>> # get the first three and last three embeddings
     >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
     >>> layer(indices_input)
-    Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
-            [ 0.01070483, 0.27923733, 1.7487359 ],
-            [ 0.59161806, 0.8660184 , 1.2838588 ]],
+    Array([[[ 0.57966787, -0.523274 , -0.43195742],
+            [-0.676289 , -0.50300646, 0.33996582],
+            [ 0.41796115, -0.59212935, 0.95934135]],
     <BLANKLINE>
-           [[-0.4769059 , -0.6607095 , 0.46697947],
-            [-0.748139 , -0.15856352, 0.06061118],
-            [ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32)
+           [[-0.66570747, 0.13815777, 1.007365 ],
+            [-1.0917838 , -0.7441663 , 0.07713798],
+            [ 0.41796115, -0.59212935, 0.95934135]]], dtype=float32)

   A parameterized function from integers [0, ``num_embeddings``) to
   ``features``-dimensional vectors. This ``Module`` will create an ``embedding``

@@ -1092,7 +1092,7 @@ class Embed(Module):
     broadcast the ``embedding`` matrix to input shape with ``features``
     dimension appended.

-  Attributes:
+  Args:
     num_embeddings: number of embeddings / vocab size.
     features: number of feature dimensions for each embedding.
     dtype: the dtype of the embedding vectors (default: same as embedding).
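
Every hunk in this file either renames ``Attributes:`` to ``Args:`` or refreshes the ``Embed`` doctest values. For reference, a short sketch of two of the documented constructors; the sizes below are assumptions:

```python
import jax.numpy as jnp
from flax import nnx

linear = nnx.Linear(in_features=3, out_features=4, use_bias=True, rngs=nnx.Rngs(0))
embed = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))

print(linear(jnp.ones((16, 3))).shape)        # (16, 4)
print(embed(jnp.array([[0, 1, 2]])).shape)    # (1, 3, 3)
```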

flax/nnx/nn/lora.py

Lines changed: 2 additions & 2 deletions

@@ -61,7 +61,7 @@ class LoRA(Module):
     >>> y.shape
     (16, 4)

-  Attributes:
+  Args:
     in_features: the number of input features.
     lora_rank: the rank of the LoRA dimension.
     out_features: the number of output features.

@@ -133,7 +133,7 @@ class LoRALinear(Linear):
     >>> y.shape
     (16, 4)

-  Attributes:
+  Args:
     in_features: the number of input features.
     out_features: the number of output features.
     lora_rank: the rank of the LoRA dimension.
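
The three entries now listed under ``Args:`` map directly onto the ``nnx.LoRA`` constructor. A minimal sketch matching the shapes shown in the example above:

```python
import jax.numpy as jnp
from flax import nnx

# in_features=3, lora_rank=2, out_features=4, as documented under Args.
layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
y = layer(jnp.ones((16, 3)))
print(y.shape)   # (16, 4)
```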

flax/nnx/nn/normalization.py

Lines changed: 4 additions & 4 deletions

@@ -236,7 +236,7 @@ class BatchNorm(Module):
     >>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
     >>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()

-  Attributes:
+  Args:
     num_features: the number of input features.
     use_running_average: if True, the stored batch statistics will be
       used instead of computing the batch statistics on the input.

@@ -407,7 +407,7 @@ class LayerNorm(Module):

     >>> y = layer(x)

-  Attributes:
+  Args:
     num_features: the number of input features.
     epsilon: A small float added to variance to avoid dividing by zero.
     dtype: the dtype of the result (default: infer from input and params).

@@ -539,7 +539,7 @@ class RMSNorm(Module):

     >>> y = layer(x)

-  Attributes:
+  Args:
     num_features: the number of input features.
     epsilon: A small float added to variance to avoid dividing by zero.
     dtype: the dtype of the result (default: infer from input and params).

@@ -670,7 +670,7 @@ class GroupNorm(Module):
     >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
     >>> np.testing.assert_allclose(y, y2)

-  Attributes:
+  Args:
     num_features: the number of input features/channels.
     num_groups: the total number of channel groups. The default value of 32 is
       proposed by the original group normalization paper.
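
The same rename is applied to the four normalization layers. The GroupNorm hunk already shows the key relationship between ``num_groups`` and ``reduction_axes``; a self-contained version of that check (the random input and tolerance below are assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))

# GroupNorm with a single group normalizes over the same axes as a LayerNorm
# reducing over all non-batch dimensions, which is what the doctest asserts.
y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
np.testing.assert_allclose(y, y2, atol=1e-5)
```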
