
Commit 637c82f

Elnifio and nathanw-mlc authored
[Llama 3.1] Updates dataset, logging, and checkpoint resume. (#787)
* adds all changes
* updates MoE download as well
* updates the logging name
* Remove mention of fixed typo from README.md
* updates the path
* addresses comments
* uses sequences here, instead of tokens
* revert +1 in steps logging

Co-authored-by: Nathan Wasson <[email protected]>
1 parent a70765e commit 637c82f

File tree

7 files changed: +97 -56 lines changed


large_language_model_pretraining/nemo/README.md

Lines changed: 37 additions & 15 deletions
@@ -35,9 +35,6 @@ Note: it's recommended to map your `.ssh` folder to inside the container, so tha

 The current codebase is using C4 dataset for train and evaluation. Please refer to [Section 3](#preprocessed-data-download) for downloading the preprocessed dataset and [Section 6](#data-preprocessing) if you would like to perform manual tokenization.

-
-### Steps to download the checkpoint
-
 ### Steps to run and time

 To train Llama 3.1 405B, we need to fill out all fields in [config.sh](./config.sh). This file contains all configurations for Slurm cluster access and job submission configurations, directory mappings, containers, and model configurations.
@@ -79,12 +76,12 @@ You can then navigate in the terminal to your desired download directory and run
 ```
 # Replace this path with your desired path on the machine
 export PREPROCESSED_PATH="./"
-rclone copy mlc-training:mlcommons-training-wg-public/llama3_1/datasets/preprocessed_c4 $PREPROCESSED_PATH -P
+rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/mixtral_8x22b_preprocessed $PREPROCESSED_PATH -P
 ```

 After the download is complete, you should see files with the following naming conventions under `PREPROCESSED_PATH`, ending with both `.idx` and `.bin`:
 - Training partitions: `c4-train.en_<number>_text_document`
-- Validation partitions: `c4-validation.en_text_document`
+- Validation partitions: `c4-validation-91205-samples.en_text_document`

 #### Tokenizer

@@ -103,22 +100,23 @@ After the download is complete, you should see five files under `TOKENIZER_PATH`

 ### Training and test data separation

-We use the default split from the C4 dataset. This means that we use `c4-train.<x>-of-01024.json.gz` files for training and `c4-validation.<x>-of-00008.json.gz` files for evaluation.
+We use the default split from the C4 dataset. This means that we use `c4-train.<x>-of-01024.json.gz` files (where `768 <= x <= 1023`) for training, and we use our customized `c4-validation-91205-samples.en.json.gz`, which contains the first 91205 samples from the unshuffled C4 validation dataset, for evaluation.
+
+Notice here that we are using the first 5760 sequences (47,185,920 tokens) from the validation dataset to perform the validation. According to our experiments, the first 91205 samples from the unshuffled C4 dataset yields 47,186,855 tokens, which is the smallest amount of samples needed to yield 47,185,920 tokens. Thus, we have chosen the first 91205 samples as our validation dataset.

 ### Training data order

 We randomly shuffle the **last 256 of 1024 shards** for the benchmarking area.

 ### Test data order

-We use the first 47M tokens in the validation dataset for validation. We **do not shuffle** the validation dataset.
+We use the first 5,760 sequences (91,205 untokenized samples) in the validation dataset for validation. We **do not shuffle** the validation dataset.

 # 4. Model
 ### Publication/Attribution

-The model largely follows the Llama 3.1 405B [paper](https://arxiv.org/abs/2407.21783). Two noticeable differences are:
-1. We replace the paper's TikTokenizer with the **Mixtral 8x22b tokenizer** in this benchmark. Please refer to the [Tokenizer](#tokenizer) section for more details.
-1. We replace the paper's AdamW with the **Adam optimizer** in this benchmark. Please refer to the [Optimizer](#optimizer-spec) section for more details.
+The model largely follows the Llama 3.1 405B [paper](https://arxiv.org/abs/2407.21783). The only difference is:
+- We replace the paper's TikTokenizer with the **Mixtral 8x22b tokenizer** in this benchmark. Please refer to the [Tokenizer](#tokenizer) section for more details.

 ### Model details
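The 91,205-sample cutoff described above can be checked with a short script. Below is a minimal sketch (the tokenizer directory, input file name, and use of Hugging Face `AutoTokenizer` are assumptions for illustration): it streams the unshuffled C4 validation JSONL records in order and stops once the cumulative token count reaches 5,760 × 8,192 = 47,185,920 tokens.

```python
# Minimal sketch: find the smallest prefix of the unshuffled C4 validation set
# whose cumulative token count reaches 5760 * 8192 = 47,185,920 tokens.
# The tokenizer directory and input file name below are placeholders.
import gzip
import json

from transformers import AutoTokenizer

TARGET_TOKENS = 5760 * 8192  # 47,185,920

tokenizer = AutoTokenizer.from_pretrained("/path/to/mixtral-8x22b-tokenizer")

total_tokens = 0
n_samples = 0
with gzip.open("c4-validation.en.json.gz", "rt") as f:
    for line in f:
        text = json.loads(line)["text"]
        total_tokens += len(tokenizer(text)["input_ids"])
        n_samples += 1
        if total_tokens >= TARGET_TOKENS:
            break

# Per the README text above, this is expected to stop at 91205 samples
# and 47,186,855 tokens.
print(n_samples, total_tokens)
```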

@@ -148,7 +146,7 @@ Large runs might need to span across multiple Slurm jobs, and we need to save an

 ### Optimizer spec

-1. Optimizer type: **Adam**
+1. Optimizer type: **AdamW**
 2. Warmup steps computed as $8000 \times \lceil {1152 \over GBS} \rceil$.
 3. LR Scheduler's maximum number of steps computed as $1,200,000 \times \lceil {1152 \over GBS} \rceil$
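For concreteness, the two scaling formulas above can be evaluated directly. A small sketch (the helper function is illustrative, not part of the repo):

```python
# Evaluate the optimizer-spec formulas above as written:
#   warmup_steps = 8000 * ceil(1152 / GBS)
#   max_steps    = 1,200,000 * ceil(1152 / GBS)
import math

def schedule_steps(gbs: int) -> tuple[int, int]:
    scale = math.ceil(1152 / gbs)
    return 8000 * scale, 1_200_000 * scale

print(schedule_steps(1152))  # (8000, 1200000) -- reference global batch size
print(schedule_steps(576))   # (16000, 2400000)
```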

@@ -163,11 +161,11 @@ Validation log perplexity = 5.6

 ### Evaluation frequency

-We perform evaluation every **377,487,360** tokens.
+We perform evaluation every **46,080** sequences.

 ### Evaluation thoroughness

-We evaluate using **47,185,920** tokens from the validation dataset.
+We evaluate using **5,760** sequences from our customized validation dataset.


 # 6. Other
@@ -176,17 +174,41 @@ We evaluate using **47,185,920** tokens from the validation dataset.

 Here are the instructions to prepare the preprocessed dataset from scratch. Data preprocessing is already done and the final dataset can be accessed by following instructions in the [Preprocessed data download](#preprocessed-data-download) section.

+#### Raw data downloading
+
+We use [AllenAI C4](https://huggingface.co/datasets/allenai/c4) dataset for this benchmark. The original zipped **`json.gz`** files can be downloaded by following AllenAI C4's instruction, and you can download our zipped customized validation dataset from the MLCommons S3 bucket by running the following command:
+
+```bash
+export ORIGINAL_C4_PATH=""
+
+# download the customized zipped validation dataset
+rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/original/c4-validation-91205-samples.en.json.gz $ORIGINAL_C4_PATH -P
+```
+
+Alternatively, we have also hosted the **unzipped C4 `json`** files on MLCommons S3 bucket. You can download them using the following commands:
+
+```bash
+export ORIGINAL_C4_PATH=""
+
+# download the full C4 files, including all raw train and validations
+rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/original/en_json/3.0.1 $ORIGINAL_C4_PATH -P
+```
+
+Note that for unzipped JSON files, it is recommended to zip them into `.gz` format before running the data preprocessing.
+
 #### Prepare tokenizer

 We use Mixtral 8x22B tokenizer in this benchmark. Tokenizer files can be downloaded [here](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/tree/main). Only the five files containing tokenizer-related contents (`special_tokens_map.json`, `tokenizer.json`, `tokenizer.model`, `tokenizer.model.v1`, `tokenizer_config.json`) are needed.

 #### Run data preprocessing

-Run the following commands to merge all 1024 training files into 8 `json.gz` files and all 8 validation files into a single `json.gz` file. Each of the `json.gz` files will be preprocessed into a pair of megatron dataset files (`.bin` and `.idx`).
+Run the following commands to merge all 1024 training files into 8 `json.gz` files, all 8 validation files into a single `json.gz` file, as well as generate our customized validation dataset. Each of the `json.gz` files will subsequently be preprocessed into a pair of megatron dataset files (`.bin` and `.idx`) by our preprocess.sh script.

 ```bash
 export C4_PATH=""
 export MERGED_C4_PATH=""
+# more information about this knob can be found in consolidate_data.sh
+export N_VALIDATION_SAMPLES=91205

 bash consolidate_data.sh
 ```
@@ -234,4 +256,4 @@ export DST_PATH=""
 sbatch launch_nemo_convert.sh
 ```

-After the model conversion is done, we can then set `MODEL_CKPT=$DST_PATH` together with `FROM_HF=1` when launching our job, so that we can resume training from the converted HF checkpoint.
+After the model conversion is done, we can then set `MODEL_CKPT=$DST_PATH` together with `FROM_HF=1` when launching our job, so that we can resume training from the converted HF checkpoint.

large_language_model_pretraining/nemo/callbacks.py

Lines changed: 12 additions & 4 deletions
@@ -97,7 +97,9 @@ def __init__(
         init_global_step, global_batch_size, seq_length,
         target_log_ppl,
         train_loss_key = "reduced_train_loss",
-        val_loss_key = "val_loss"
+        val_loss_key = "val_loss",
+        train_step_time_in_s = "train_step_timing in s",
+        train_step_time_atol=7200,
     ):
         super().__init__()

@@ -110,10 +112,17 @@ def __init__(
        self.val_loss_key = val_loss_key
        self.is_target_reached = False

+        self.train_step_time_in_s = train_step_time_in_s
+        self.train_step_time_atol = train_step_time_atol
+
    def log_metrics(self, metrics, step):
        if self.val_loss_key in metrics:
            self.log_validation_loss(metrics, step)

+        if self.train_step_time_in_s in metrics:
+            step_time = metrics[self.train_step_time_in_s]
+            assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "
+
    def log_validation_loss(self, metrics, step):
        consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len

@@ -138,11 +147,10 @@ def version(self):

 ### MLPerf callbacks
 def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
-    steps_since_resume = trainer.global_step + 1 - init_global_step # global steps are 0-indexed
     consumed_samples = (
-        steps_since_resume * global_batch_size
+        trainer.global_step * global_batch_size
     )
-    return int(consumed_samples) * seq_length
+    return int(consumed_samples) # we log the epoch numbers in sequences, not tokens

 class MLPerfCallback(pl.Callback):
     def __init__(
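The net effect of this change is that the logged epoch counter now counts sequences consumed since global step 0, rather than tokens consumed since the resume step. An illustration with hypothetical values (not repo code):

```python
# Hypothetical values: global step 100, GBS 1152, 8192-token sequences.
global_step, init_global_step, gbs, seq_len = 100, 0, 1152, 8192

old_metric = (global_step + 1 - init_global_step) * gbs * seq_len  # tokens, resume-relative
new_metric = global_step * gbs                                     # sequences since step 0

print(old_metric)  # 953155584 tokens under the old accounting
print(new_metric)  # 115200 sequences under the new accounting
```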

large_language_model_pretraining/nemo/pretrain_llama31.py

Lines changed: 26 additions & 19 deletions
@@ -71,16 +71,14 @@ def slurm_executor(
        ),
        nodes=nodes,
        ntasks_per_node=devices,
+        gpus_per_node=devices,
        mem="0",
        exclusive=True,
+        gres="gpu:8",
        packager=run.GitArchivePackager(),
        dependencies=dependencies,
    )

-    if devices != 0:
-        executor.gpus_per_node=devices
-        executor.gres = "gpu:8"
-
    executor.launcher = None
    executor.container_image = container_image
    executor.container_mounts = mounts
@@ -151,7 +149,9 @@ def get_pretrain(
    warmup_tokens = 8000 * base_gbs * 8192

    max_lr = (gbs / base_gbs) * base_lr
+    max_lr = round(max_lr, 8) # rounds to the nearest 8th digit.

+    # Code tracing shows that this is AdamW
    pretrain.optim = distributed_fused_adam_with_cosine_annealing(
        max_lr = max_lr,
        warmup_steps = math.ceil(warmup_tokens / 8192 / gbs),
@@ -217,10 +217,10 @@ def get_data(
    data_paths = {
        "train": train_datasets,
        "validation": [
-            "/preproc_data/c4-validation.en_text_document"
+            "/preproc_data/c4-validation-91205-samples.en_text_document"
        ],
        "test": [
-            "/preproc_data/c4-validation.en_text_document"
+            "/preproc_data/c4-validation-91205-samples.en_text_document"
        ],
    }

@@ -300,8 +300,8 @@ def get_parser() -> argparse.ArgumentParser:

    data_group.add_argument("--gbs", type=int, default=1152, help="Global batch size, should be divisible by PP")
    data_group.add_argument("--mbs", type=int, default=1, help="Micro batch size")
-    data_group.add_argument("--eval_every", type=int, default=377_487_360, help="Evaluate at least every N training tokens")
-    data_group.add_argument("--eval_tokens", type=int, default=47_185_920, help="Evaluate using at least N evaluation tokens")
+    data_group.add_argument("--eval_every", type=int, default=46080, help="Evaluate at least every N training sequences")
+    data_group.add_argument("--eval_tokens", type=int, default=5760, help="Evaluate using at least N evaluation sequences")
    data_group.add_argument('--max_steps', type=int, default=None, help="Maximum number of steps that each experiment partition will train on. None means no restriction on max steps. ")
    data_group.add_argument("--use_full_dataset", action="store_true", help="If set, then we use the full dataset, instead of the last 256/1024 shards")
    data_group.add_argument("--tokenizer_path", type=str, help="Tokenizer path that's used to tokenize the dataset")
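At the fixed sequence length of 8,192 tokens, the new sequence-based defaults are equivalent to the old token-based ones. A quick standalone check:

```python
# Sequence-based defaults times the 8192-token sequence length reproduce
# the previous token-based defaults.
SEQ_LEN = 8192

assert 46_080 * SEQ_LEN == 377_487_360  # --eval_every, sequences -> tokens
assert 5_760 * SEQ_LEN == 47_185_920    # --eval_tokens, sequences -> tokens
print("defaults are consistent")
```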
@@ -312,6 +312,7 @@ def get_parser() -> argparse.ArgumentParser:
    experiment_group.add_argument("--num_exps", type=int, default=1)
    experiment_group.add_argument("--num_pars", type=int, default=1)
    experiment_group.add_argument("--target_log_ppl", type=float, default=5.6)
+    experiment_group.add_argument("--step_time_atol", type=int, default=1600, help="train step time atol")

    return parser

@@ -351,8 +352,8 @@ def get_parser() -> argparse.ArgumentParser:
        use_full_dataset=args.use_full_dataset,
    )

-    eval_every_n_batches = math.ceil(args.eval_every / (args.gbs * 8192))
-    eval_batches = math.ceil(args.eval_tokens / (args.gbs * 8192))
+    eval_every_n_batches = math.ceil(args.eval_every / (args.gbs))
+    eval_batches = math.ceil(args.eval_tokens / (args.gbs))

    exp_prefix, pretrain = get_pretrain(
        size=args.size,
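With the default `--gbs` of 1152 and the new sequence-based defaults, the derived evaluation cadence works out to whole batches. A quick standalone check:

```python
# Derived evaluation cadence for --gbs 1152 with the new defaults.
import math

gbs = 1152
eval_every_n_batches = math.ceil(46_080 / gbs)  # training batches between evals
eval_batches = math.ceil(5_760 / gbs)           # validation batches per eval

print(eval_every_n_batches, eval_batches)  # 40 5
```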
@@ -383,12 +384,12 @@ def get_parser() -> argparse.ArgumentParser:
        constants.EVAL_SAMPLES: args.eval_tokens,

        # Optimizers
-        constants.OPT_NAME: pretrain.optim.config.optimizer,
+        constants.OPT_NAME: "adamw",
        constants.OPT_BASE_LR: pretrain.optim.config.lr,
-        constants.OPT_ADAM_BETA_1: pretrain.optim.config.adam_beta1,
-        constants.OPT_ADAM_BETA_2: pretrain.optim.config.adam_beta2,
-        constants.OPT_ADAM_EPSILON: pretrain.optim.config.adam_eps,
-        constants.OPT_WEIGHT_DECAY: pretrain.optim.config.weight_decay,
+        constants.OPT_ADAMW_BETA_1: pretrain.optim.config.adam_beta1,
+        constants.OPT_ADAMW_BETA_2: pretrain.optim.config.adam_beta2,
+        constants.OPT_ADAMW_EPSILON: pretrain.optim.config.adam_eps,
+        constants.OPT_ADAMW_WEIGHT_DECAY: pretrain.optim.config.weight_decay,
        constants.OPT_GRADIENT_CLIP_NORM: pretrain.optim.config.clip_grad,

        # Schedulers
@@ -467,9 +468,14 @@ def get_parser() -> argparse.ArgumentParser:
            checkpoint_name = "checkpoint" + f"-seed-{seed}-par-{j}{ending_steps}"
            experiment_write_to_path = static_write_to_path + "/" + checkpoint_name

-            if not args.resume_from_hf:
-                pretrain.resume.resume_from_directory = experiment_read_from_path
-                pretrain.resume.resume_from_path = experiment_read_from_path
+            if not (args.resume_from_hf and j == 0):
+                pretrain.resume = run.Config(
+                    nl.AutoResume,
+                    resume_if_exists=True,
+                    resume_ignore_no_checkpoint=True,
+                    resume_from_path = experiment_read_from_path,
+                    resume_from_directory = experiment_read_from_path,
+                )
            else:
                pretrain.resume = run.Config(nl.AutoResume, restore_config = run.Config(nl.RestoreConfig, path=experiment_read_from_path))
            pretrain.log.ckpt.train_time_interval = None
@@ -502,7 +508,8 @@ def get_parser() -> argparse.ArgumentParser:
                init_global_step=start_step,
                global_batch_size=args.gbs,
                seq_length=8192,
-                target_log_ppl=args.target_log_ppl
+                target_log_ppl=args.target_log_ppl,
+                train_step_time_atol=args.step_time_atol,
            ),
        ]

large_language_model_pretraining/nemo/run_llama31.sh

Lines changed: 3 additions & 11 deletions
@@ -59,18 +59,17 @@ git config --global --add safe.directory /workspace/llama31
 : "${START_STEPS:=0}"

 # Dataloader settings
-: "${EVAL_EVERY:=""}"
-: "${EVAL_TOKENS:=""}"
 : "${MAX_STEPS:=""}"

 # Experiment settings
 : "${SEEDS:=""}"
 IFS=" " read -ra seeds <<< $SEEDS
 : "${NEXP:=1}"
 : "${NPAR:=1}"
-: "${SAVE_CKPT:=1}"
+: "${SAVE_CKPT:=0}"
 : "${TAG:=""}"
 : "${TARGET:="5.6"}"
+: "${STEP_TIME_ATOL:="7200"}" # maximum tolerable step time, setting to 2hr by default

 # Run

@@ -107,14 +106,6 @@ if [ ! $DEPENDENCIES = "" ]; then
     CMD_SUFFIX="${CMD_SUFFIX} --dependencies ${DEPENDENCIES}"
 fi

-if [ ! $EVAL_EVERY = "" ]; then
-    CMD_SUFFIX="${CMD_SUFFIX} --eval_every ${EVAL_EVERY}"
-fi
-
-if [ ! $EVAL_TOKENS = "" ]; then
-    CMD_SUFFIX="${CMD_SUFFIX} --eval_tokens ${EVAL_TOKENS}"
-fi
-
 if [ ! $MAX_STEPS = "" ]; then
     CMD_SUFFIX="${CMD_SUFFIX} --max_steps ${MAX_STEPS}"
 fi
@@ -145,6 +136,7 @@ python3 pretrain_llama31.py \
    --continual_ckpt_path /continual \
    --tokenizer_path /tokenizer \
    --target_log_ppl $TARGET \
+    --step_time_atol $STEP_TIME_ATOL \
    --ckpt_start_step $START_STEPS \
    --max_retries $MAX_RETRIES \
    $CMD_SUFFIX

large_language_model_pretraining/nemo/utils/consolidate_data.sh

Lines changed: 9 additions & 1 deletion
@@ -2,6 +2,11 @@ set -e

 : "${C4_PATH:?C4_PATH not set}"
 : "${MERGED_C4_PATH:?MERGED_C4_PATH not set}"
+: "${N_VALIDATION_SAMPLES:=91205}"
+# defaults the N_VALIDATION_SAMPLES to 91205
+# C4 validation dataset: each sample on average tokenizes to 518 tokens
+# thus, to reach 47,185,920 validation tokens, we need to use at least 91205 samples,
+# which, after tokenization, will yield 47,186,855 tokens.

 # create softlinks to store each shard before merging
 mkdir -p softlinks
@@ -26,4 +31,7 @@ for shard in {0..7}; do
    cat softlinks/en_${shard}/*gz > ${MERGED_C4_PATH}/c4-train.en_${shard}.json.gz
done

-cat softlinks/en_validation/*gz > ${MERGED_C4_PATH}/c4-validation.en.json.gz
+cat softlinks/en_validation/*gz > ${MERGED_C4_PATH}/c4-validation.en.json.gz
+
+# select the first N_VALIDATION_SAMPLES number of samples
+zcat ${MERGED_C4_PATH}/c4-validation.en.json.gz | head -n $N_VALIDATION_SAMPLES | gzip > ${MERGED_C4_PATH}/c4-validation-${N_VALIDATION_SAMPLES}-samples.en.json.gz

large_language_model_pretraining/nemo/utils/preprocess.sh

Lines changed: 2 additions & 2 deletions
@@ -27,8 +27,8 @@ srun --nodes=1 --ntasks-per-node=1 \
    --container-image=$CONT_IMAGE_URL --container-mounts $container_maps --no-container-entrypoint \
    --output preprocess_outputs/dataset_preprocess_validation.out \
    python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \
-    --input "/dataset/c4-validation.en.json.gz" \
-    --output-prefix "/outputs/c4-validation.en" \
+    --input "/dataset/c4-validation-91205-samples.en.json.gz" \
+    --output-prefix "/outputs/c4-validation-91205-samples.en" \
    --tokenizer-library huggingface --tokenizer-type /tokenizer \
    --dataset-impl mmap --workers 128 &
wait

mixture_of_experts_pretraining/README.md

Lines changed: 8 additions & 4 deletions
@@ -492,12 +492,16 @@ You can then navigate in the terminal to your desired download directory and run

 ## Text Datasets
 **Dataset**
-* Train Dataset`c4/en_json/3.0.1`
-* Eval Dataset `c4/en_val_subset_json`
-* Preprocessed GPU dataset `preprocessed_c4`
+* Train Dataset`original/en_json/3.0.1`
+* Eval Dataset `original/en_val_subset_json`
+* Preprocessed GPU dataset `mixtral_8x22b_preprocessed`
 ```
 mkdir -p datasets
-rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/datasets ./datasets -P
+rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4 ./datasets -P
+
+# moving them to the original naming convention so that it won't break the code
+mv ./datasets/original ./datasets/c4
+mv ./datasets/mixtral_8x22b_preprocessed ./datasets/preprocessed_c4
 ```
 ## Checkpoints
 * Mixtral-8x22B-v0.1-fsdp: use for `tensor_parallelism=1`
