
Commit a2fe3ca

Author: Flax Authors
Merge pull request #4745 from vfdev-5:fix-lm1b-nnx-example

PiperOrigin-RevId: 764936098
2 parents e4ab883 + 7def227

8 files changed: +23 −23 lines

examples/lm1b_nnx/README.md (5 additions, 6 deletions)

````diff
@@ -52,7 +52,7 @@ Then install Flax + the example dependencies:
 git clone --depth=1 --branch=main https://github.com/google/flax
 cd flax
 pip install -e .
-cd examples/lm1b
+cd examples/lm1b_nnx
 pip install -r requirements.txt
 ```
 
@@ -75,9 +75,9 @@ tensorboard --logdir=$HOME/logs
 You should expect to get numbers similar to these:
 
 
-Hardware | config | Training time | Loss | TensorBoard.dev | Workdir
--------- | ------- | ------------- | -------------- | ------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------
-TPU v3-8 | default | 13h18m | 3.127 | [2021-08-08](https://tensorboard.dev/experiment/n30WkNOZTJq3RHWD7wNslg/) | [gs://flax_public/examples/lm1b/default](https://console.cloud.google.com/storage/browser/flax_public/examples/lm1b/default)
+Hardware | config | Training time | Loss | Workdir
+-------- | ------- | ------------- | -------------- | --------------------------------------------------------------------------------------------------------------------------
+TPU v3-8 | default | 13h18m | 3.127 | [gs://flax_public/examples/lm1b/default](https://console.cloud.google.com/storage/browser/flax_public/examples/lm1b/default)
 
 ### Downloading the LM1B Datasets
 
@@ -87,6 +87,5 @@ data on a storage bucket, from where it can be loaded directly. Set the
 `TFDS_DATA_DIR` to your storage bucket path (`gs://<bucket name>`).
 
 You can download and prepare LM1B datasets using TFDS directly:
-`python -m tensorflow_datasets.scripts.download_and_prepare
---datasets=lm1b`
+`python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`
 
````
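The repaired one-liner uses the TFDS download script; an equivalent Python sketch, assuming `tensorflow-datasets` is installed and with the bucket path left as a placeholder to fill in:

```python
# Sketch: prepare LM1B from Python instead of the CLI script shown above.
# Replace the data_dir placeholder with a local path or your bucket;
# this mirrors pointing TFDS_DATA_DIR at gs://<bucket name>.
import tensorflow_datasets as tfds

builder = tfds.builder('lm1b', data_dir='gs://<bucket name>')
builder.download_and_prepare()  # downloads and writes the prepared shards
```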

examples/lm1b_nnx/input_pipeline_test.py (2 additions, 2 deletions)

```diff
@@ -48,9 +48,9 @@ def _get_datasets(self):
     vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model')
 
     # Go two directories up to the root of the flax directory.
-    flax_root_dir = pathlib.Path(__file__).parents[4]
+    # "/path/to/flax/examples/lm1b_nnx/models_test.py" -> "/path/to/flax"
+    flax_root_dir = pathlib.Path(__file__).absolute().parents[2]
     data_dir = str(flax_root_dir) + '/.tfds/metadata'  # pylint: disable=unused-variable
-
     with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
       train_ds, eval_ds, predict_ds, _ = input_pipeline.get_datasets(
         n_devices=2, config=config, vocab_path=vocab_path
```
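The `parents[4]` → `parents[2]` fix works because `pathlib.Path.parents` counts ancestors starting from the immediate parent; a quick standalone check, using the illustrative path from the diff's comment:

```python
# Sketch: Path.parents indexing, with the illustrative path from the diff.
from pathlib import Path

p = Path('/path/to/flax/examples/lm1b_nnx/input_pipeline_test.py')
print(p.parents[0])  # /path/to/flax/examples/lm1b_nnx
print(p.parents[1])  # /path/to/flax/examples
print(p.parents[2])  # /path/to/flax  (the repository root)
```

Calling `.absolute()` first also matters: if `__file__` is a relative path, the chain of parents can run out before reaching the repository root.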

examples/lm1b_nnx/main.py (1 addition, 1 deletion)

```diff
@@ -34,7 +34,7 @@
     'File path to the training hyperparameter configuration.',
     lock_config=True,
 )
-flags.mark_flags_as_required(['config', 'workdir'])
+flags.mark_flags_as_required(['workdir'])
 
 
 def main(argv):
```
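Only `workdir` remains in `mark_flags_as_required`, consistent with the `config` flag being defined with a default via `config_flags.DEFINE_config_file`. A minimal sketch of the pattern; the default path below is an assumption for illustration, not taken from this commit:

```python
# Sketch of the absl/ml_collections flag pattern; the default config
# path below is assumed for illustration, not read from the commit.
from absl import flags
from ml_collections import config_flags

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    'configs/default.py',  # a default makes --config optional
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)
flags.mark_flags_as_required(['workdir'])  # only --workdir must be passed
```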

examples/lm1b_nnx/models.py (1 addition, 0 deletions)

```diff
@@ -292,6 +292,7 @@ def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
       broadcast_dropout=False,
       dropout_rate=config.attention_dropout_rate,
       rngs=rngs,
+      keep_rngs=False,
     )
     self.mlp = MlpBlock(config=config, rngs=rngs)
     self.dropout = nnx.Dropout(rate=config.dropout_rate)
```
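`keep_rngs=False` asks the attention layer not to retain the `Rngs` object as module state after construction, presumably so that the later `nnx.split(model, nnx.Param)` in `utils.py` sees parameters only. A construction sketch, assuming the `nnx` API at this commit; the feature sizes are hypothetical:

```python
# Sketch: building the attention block as in this diff.
# num_heads/in_features are hypothetical stand-ins for config values;
# keep_rngs=False (added by this commit) avoids storing RNG state
# on the module after initialization.
from flax import nnx

attention = nnx.MultiHeadAttention(
    num_heads=8,             # hypothetical
    in_features=512,         # hypothetical
    broadcast_dropout=False,
    dropout_rate=0.1,        # stands in for config.attention_dropout_rate
    rngs=nnx.Rngs(0),
    keep_rngs=False,
)
```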

examples/lm1b_nnx/models_test.py (2 additions, 1 deletion)

```diff
@@ -34,7 +34,8 @@
 jax.config.update('jax_disable_most_optimizations', True)
 
 # add project_root to import lm1b Linen model
-project_root = str(Path(__file__).absolute().parents[4])
+# "/path/to/flax/examples/lm1b_nnx/models_test.py" -> "/path/to/flax"
+project_root = str(Path(__file__).absolute().parents[2])
 sys.path.append(project_root)
 from examples.lm1b.models import TransformerLM as TransformerLinen  # type: ignore[import-error]
 
```

examples/lm1b_nnx/train.py (7 additions, 11 deletions)

```diff
@@ -20,7 +20,6 @@
 # pytype: disable=wrong-arg-count
 # pytype: disable=attribute-error
 
-import collections
 import dataclasses
 import os
 
@@ -41,7 +40,6 @@
 from jax.sharding import PartitionSpec as P
 from utils import HasCache, TrainState
 
-from flax import linen as nn
 from flax import nnx
 from flax.training import checkpoints, common_utils
 
@@ -115,7 +113,7 @@ def compute_weighted_cross_entropy(
     targets, vocab_size, on_value=confidence, off_value=low_confidence
   )
 
-  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
+  loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1)
   loss = loss - normalizing_constant
 
   normalizing_factor = np.prod(targets.shape)
@@ -389,6 +387,7 @@ def train_and_evaluate(config: default.Config, workdir: str):
     workdir: Working directory for checkpoints and TF summaries. If this
       contains checkpoint training will be resumed from the latest checkpoint.
   """
+  workdir = os.path.abspath(workdir)
   tf.io.gfile.makedirs(workdir)
 
   vocab_path = config.vocab_path
@@ -440,18 +439,15 @@ def encode_strings(strs, max_len):
     max_len=max(config.max_target_length, config.max_eval_target_length),
     dropout_rate=config.dropout_rate,
     attention_dropout_rate=config.attention_dropout_rate,
-    kernel_init=nn.initializers.xavier_uniform(),
-    bias_init=nn.initializers.normal(stddev=1e-6),
+    kernel_init=nnx.initializers.xavier_uniform(),
+    bias_init=nnx.initializers.normal(stddev=1e-6),
     axis_rules=config.axis_rules,
   )
 
   # Mesh definition
   devices_array = utils.create_device_mesh(config)
   mesh = Mesh(devices_array, config.mesh_axes)
 
-  # print(mesh.shape)
-  # exit()
-
   start_step = 0
   rng = jax.random.PRNGKey(config.seed)
   rng, init_rng = jax.random.split(rng)
@@ -498,7 +494,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
       None,
     ),  # type: ignore
     out_shardings=(state_sharding, None),  # type: ignore
-    static_argnums=(2, 3),
+    static_argnames=("learning_rate_fn", "label_smoothing"),
     donate_argnums=0,
   )
 
@@ -509,7 +505,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
       data_sharding,
     ),  # type: ignore
     out_shardings=None,  # type: ignore
-    static_argnums=(2, 3),
+    static_argnames=("graphdef", "label_smoothing"),
   )
 
   # Since the inputs and rngkey args for predict_step will be batched,
@@ -575,7 +571,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
       h(step)
 
     # Periodic metric handling.
-    if step % config.eval_every_steps == 0 or is_last_step:
+    if (step > 0 and step % config.eval_every_steps == 0) or is_last_step:
       with report_progress.timed('training_metrics'):
         logging.info('Gathering training metrics.')
         train_metrics = common_utils.stack_forest(train_metrics)
```
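Replacing `static_argnums=(2, 3)` with `static_argnames` pins the static arguments by name rather than position, so the jit wrapper keeps working if the signature is reordered or the arguments are passed as keywords. A self-contained sketch; the function and argument names here are illustrative, not the example's own:

```python
# Sketch: static_argnames marks arguments as compile-time constants by
# name; each new static value triggers a fresh specialization.
import jax
import jax.numpy as jnp

def scale(x, factor, label_smoothing=0.0):
  return x * factor + label_smoothing

scale_jit = jax.jit(scale, static_argnames=('factor', 'label_smoothing'))

x = jnp.ones(3)
print(scale_jit(x, factor=2.0, label_smoothing=0.1))  # [2.1 2.1 2.1]
```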

examples/lm1b_nnx/utils.py (4 additions, 1 deletion)

```diff
@@ -159,7 +159,10 @@ def setup_initial_state(
   model = constructor(config, rng)
   graphdef, params = nnx.split(model, nnx.Param)
   state = TrainState.create(
-    apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef
+    apply_fn=graphdef.apply,
+    params=params,
+    tx=tx,
+    graphdef=graphdef,
   )
   state = jax.tree.map(_to_array, state)
   state_spec = nnx.get_partition_spec(state)
```
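The reflowed `TrainState.create` call sits right after `nnx.split`, which is what lets the parameters travel through jit and sharding machinery as a plain pytree. A minimal round-trip sketch; `nnx.Linear` stands in for the transformer built by `constructor(...)`:

```python
# Sketch: nnx.split / nnx.merge round trip, as used in setup_initial_state.
# nnx.Linear stands in for the transformer built by constructor(...).
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, params = nnx.split(model, nnx.Param)  # static structure + pytree
rebuilt = nnx.merge(graphdef, params)           # reassemble the module
```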

tests/download_dataset_metadata.sh (1 addition, 1 deletion)

```diff
@@ -8,7 +8,7 @@
 
 set -e
 
-# Download TFDS metadata to flax/.tdfs/metadata directory.
+# Download TFDS metadata to flax/.tfds/metadata directory.
 # This allows the tests to specify the `data_dir` when using tfds.testing.mock_data().
 cd "$( dirname "$0" )"
 
```
