[Llama 3.1] Updates dataset, logging, and checkpoint resume. #787

Merged: 8 commits, Mar 18, 2025
52 changes: 37 additions & 15 deletions large_language_model_pretraining/nemo/README.md
@@ -35,9 +35,6 @@ Note: it's recommended to map your `.ssh` folder to inside the container, so tha

The current codebase uses the C4 dataset for training and evaluation. Please refer to [Section 3](#preprocessed-data-download) for downloading the preprocessed dataset and [Section 6](#data-preprocessing) if you would like to perform manual tokenization.


### Steps to download the checkpoint

### Steps to run and time

To train Llama 3.1 405B, we need to fill out all fields in [config.sh](./config.sh). This file contains all configurations for Slurm cluster access and job submission, directory mappings, containers, and the model.
@@ -79,12 +76,12 @@ You can then navigate in the terminal to your desired download directory and run
```
# Replace this path with your desired path on the machine
export PREPROCESSED_PATH="./"
rclone copy mlc-training:mlcommons-training-wg-public/llama3_1/datasets/preprocessed_c4 $PREPROCESSED_PATH -P
rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/mixtral_8x22b_preprocessed $PREPROCESSED_PATH -P
```

After the download is complete, you should see files with the following naming conventions under `PREPROCESSED_PATH`, ending with both `.idx` and `.bin`:
- Training partitions: `c4-train.en_<number>_text_document`
- Validation partitions: `c4-validation.en_text_document`
- Validation partitions: `c4-validation-91205-samples.en_text_document`

#### Tokenizer

@@ -103,22 +100,23 @@ After the download is complete, you should see five files under `TOKENIZER_PATH`

### Training and test data separation

We use the default split from the C4 dataset. This means that we use `c4-train.<x>-of-01024.json.gz` files for training and `c4-validation.<x>-of-00008.json.gz` files for evaluation.
We use the default split from the C4 dataset. This means that we use `c4-train.<x>-of-01024.json.gz` files (where `768 <= x <= 1023`) for training, and we use our customized `c4-validation-91205-samples.en.json.gz`, which contains the first 91205 samples from the unshuffled C4 validation dataset, for evaluation.
Contributor: Can you explain why and how 91205 samples were chosen?

Contributor Author: This is discussed in `utils/consolidate_data.sh`; I can copy the lines here to make it clearer in the README.

Contributor Author: I have added the description in this commit.


Notice here that we are using the first 5,760 sequences (47,185,920 tokens) from the validation dataset to perform the validation. According to our experiments, the first 91,205 samples from the unshuffled C4 validation dataset yield 47,186,855 tokens, making 91,205 the smallest number of samples needed to reach 47,185,920 tokens. We have therefore chosen the first 91,205 samples as our validation dataset.
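
A sketch of the arithmetic behind this choice (the per-sample token counts come from tokenizing the unshuffled validation data; the helper below is illustrative, not part of the benchmark code):

```python
SEQ_LEN = 8192
VALIDATION_SEQUENCES = 5_760
TARGET_TOKENS = VALIDATION_SEQUENCES * SEQ_LEN   # 47,185,920 tokens

def smallest_prefix(per_sample_token_counts, target=TARGET_TOKENS):
    """Smallest number of leading samples whose tokenized length reaches `target`."""
    total = 0
    for count, n_tokens in enumerate(per_sample_token_counts, start=1):
        total += n_tokens
        if total >= target:
            return count, total
    raise ValueError("not enough validation samples to reach the target")

# On the real (unshuffled) C4 validation split this returns (91205, 47186855),
# i.e. 91,205 samples yielding 47,186,855 tokens.
```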

### Training data order

We randomly shuffle the **last 256 of 1024 shards** for the benchmarking area.

### Test data order

We use the first 47M tokens in the validation dataset for validation. We **do not shuffle** the validation dataset.
We use the first 5,760 sequences (91,205 untokenized samples) in the validation dataset for validation. We **do not shuffle** the validation dataset.
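
A compact sketch of the resulting data order, assuming the 1,024 raw training shards are indexed 0 through 1023 (the seed below is illustrative; the real one comes from the run configuration):

```python
import random

# Benchmarking area: the last 256 of the 1024 training shards, shuffled per run.
benchmark_shard_ids = list(range(768, 1024))
random.Random(1234).shuffle(benchmark_shard_ids)  # illustrative seed

# Validation: the first 5,760 sequences, consumed in their original order (no shuffle).
VALIDATION_SEQUENCES = 5760
```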

# 4. Model
### Publication/Attribution

The model largely follows the Llama 3.1 405B [paper](https://arxiv.org/abs/2407.21783). Two noticeable differences are:
1. We replace the paper's TikTokenizer with the **Mixtral 8x22b tokenizer** in this benchmark. Please refer to the [Tokenizer](#tokenizer) section for more details.
1. We replace the paper's AdamW with the **Adam optimizer** in this benchmark. Please refer to the [Optimizer](#optimizer-spec) section for more details.
The model largely follows the Llama 3.1 405B [paper](https://arxiv.org/abs/2407.21783). The only difference is:
- We replace the paper's TikTokenizer with the **Mixtral 8x22b tokenizer** in this benchmark. Please refer to the [Tokenizer](#tokenizer) section for more details.

### Model details

@@ -148,7 +146,7 @@ Large runs might need to span across multiple Slurm jobs, and we need to save an

### Optimizer spec

1. Optimizer type: **Adam**
1. Optimizer type: **AdamW**
2. Warmup steps computed as $8000 \times \lceil {1152 \over GBS} \rceil$.
3. LR Scheduler's maximum number of steps computed as $1,200,000 \times \lceil {1152 \over GBS} \rceil$ (both formulas are illustrated in the sketch below).
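
A small helper showing how these two quantities scale with the global batch size (a sketch of the formulas above; the function name is illustrative):

```python
import math

def llama31_schedule(gbs: int, base_gbs: int = 1152):
    """Warmup steps and LR-scheduler max steps for a given global batch size."""
    scale = math.ceil(base_gbs / gbs)
    warmup_steps = 8000 * scale
    scheduler_max_steps = 1_200_000 * scale
    return warmup_steps, scheduler_max_steps

# At the reference global batch size the formulas give the base values.
assert llama31_schedule(1152) == (8000, 1_200_000)
```
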

Expand All @@ -163,11 +161,11 @@ Validation log perplexity = 5.6

### Evaluation frequency

We perform evaluation every **377,487,360** tokens.
We perform evaluation every **46,080** sequences.

### Evaluation thoroughness

We evaluate using **47,185,920** tokens from the validation dataset.
We evaluate using **5,760** sequences from our customized validation dataset.
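
In batch terms, with the reference global batch size of 1,152 this works out to an evaluation every 40 global batches, each consuming 5 batches of validation data. A sketch of that conversion (mirroring how `pretrain_llama31.py` turns `--eval_every` and `--eval_tokens` into batch counts):

```python
import math

GBS = 1152                      # reference global batch size
EVAL_EVERY_SEQUENCES = 46_080   # evaluation cadence, in training sequences
EVAL_SEQUENCES = 5_760          # sequences consumed per evaluation

eval_every_n_batches = math.ceil(EVAL_EVERY_SEQUENCES / GBS)  # 40
eval_batches = math.ceil(EVAL_SEQUENCES / GBS)                # 5
```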


# 6. Other
@@ -176,17 +174,41 @@ We evaluate using **47,185,920** tokens from the validation dataset.

Here are the instructions to prepare the preprocessed dataset from scratch. Data preprocessing is already done, and the final dataset can be accessed by following the instructions in the [Preprocessed data download](#preprocessed-data-download) section.

#### Raw data downloading

We use the [AllenAI C4](https://huggingface.co/datasets/allenai/c4) dataset for this benchmark. The original zipped **`json.gz`** files can be downloaded by following AllenAI C4's instructions, and you can download our zipped customized validation dataset from the MLCommons S3 bucket by running the following command:

```bash
export ORIGINAL_C4_PATH=""

# download the customized zipped validation dataset
rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/original/c4-validation-91205-samples.en.json.gz $ORIGINAL_C4_PATH -P
```

Alternatively, we have also hosted the **unzipped C4 `json`** files on the MLCommons S3 bucket. You can download them using the following commands:

```bash
export ORIGINAL_C4_PATH=""

# download the full C4 dataset, including all raw train and validation files
rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/original/en_json/3.0.1 $ORIGINAL_C4_PATH -P
```

Note that if you download the unzipped JSON files, it is recommended to compress them into `.gz` format before running the data preprocessing.
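
If you did download the unzipped files, a minimal sketch of that compression step in Python (this assumes the `.json` shards sit directly under `$ORIGINAL_C4_PATH`; adjust the glob pattern to your layout):

```python
import glob
import gzip
import os
import shutil

original_c4_path = os.environ.get("ORIGINAL_C4_PATH", ".")

# Compress every plain .json shard into the .json.gz form expected by the
# consolidation and preprocessing scripts.
for json_file in glob.glob(os.path.join(original_c4_path, "*.json")):
    with open(json_file, "rb") as src, gzip.open(json_file + ".gz", "wb") as dst:
        shutil.copyfileobj(src, dst)
```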

#### Prepare tokenizer

We use the Mixtral 8x22B tokenizer in this benchmark. Tokenizer files can be downloaded [here](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/tree/main). Only the five files containing tokenizer-related contents (`special_tokens_map.json`, `tokenizer.json`, `tokenizer.model`, `tokenizer.model.v1`, `tokenizer_config.json`) are needed.
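
One way to fetch just those five files, sketched with `huggingface_hub` (this assumes the package is installed, that you have accepted the model's license on Hugging Face, and that a valid token is configured; the local directory is illustrative, point it at your `TOKENIZER_PATH`):

```python
from huggingface_hub import snapshot_download

TOKENIZER_FILES = [
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer.model",
    "tokenizer.model.v1",
    "tokenizer_config.json",
]

# Download only the tokenizer-related files.
snapshot_download(
    repo_id="mistralai/Mixtral-8x22B-v0.1",
    allow_patterns=TOKENIZER_FILES,
    local_dir="./tokenizer",  # illustrative; set this to TOKENIZER_PATH
)
```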

#### Run data preprocessing

Run the following commands to merge all 1024 training files into 8 `json.gz` files and all 8 validation files into a single `json.gz` file. Each of the `json.gz` files will be preprocessed into a pair of megatron dataset files (`.bin` and `.idx`).
Run the following commands to merge all 1024 training files into 8 `json.gz` files and all 8 validation files into a single `json.gz` file, as well as to generate our customized validation dataset. Each of the `json.gz` files will subsequently be preprocessed into a pair of Megatron dataset files (`.bin` and `.idx`) by our `preprocess.sh` script.

```bash
export C4_PATH=""
export MERGED_C4_PATH=""
# more information about this knob can be found in consolidate_data.sh
export N_VALIDATION_SAMPLES=91205

bash consolidate_data.sh
```
@@ -234,4 +256,4 @@ export DST_PATH=""
sbatch launch_nemo_convert.sh
```

After the model conversion is done, we can then set `MODEL_CKPT=$DST_PATH` together with `FROM_HF=1` when launching our job, so that we can resume training from the converted HF checkpoint.
16 changes: 12 additions & 4 deletions large_language_model_pretraining/nemo/callbacks.py
@@ -97,7 +97,9 @@ def __init__(
init_global_step, global_batch_size, seq_length,
target_log_ppl,
train_loss_key = "reduced_train_loss",
val_loss_key = "val_loss"
val_loss_key = "val_loss",
train_step_time_in_s = "train_step_timing in s",
train_step_time_atol=7200,
):
super().__init__()

@@ -110,10 +112,17 @@ def __init__(
self.val_loss_key = val_loss_key
self.is_target_reached = False

self.train_step_time_in_s = train_step_time_in_s
self.train_step_time_atol = train_step_time_atol

def log_metrics(self, metrics, step):
if self.val_loss_key in metrics:
self.log_validation_loss(metrics, step)

if self.train_step_time_in_s in metrics:
step_time = metrics[self.train_step_time_in_s]
assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "

def log_validation_loss(self, metrics, step):
consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len

@@ -138,11 +147,10 @@ def version(self):

### MLPerf callbacks
def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
steps_since_resume = trainer.global_step + 1 - init_global_step # global steps are 0-indexed
consumed_samples = (
steps_since_resume * global_batch_size
trainer.global_step * global_batch_size
)
return int(consumed_samples) * seq_length
return int(consumed_samples) # we log the epoch numbers in sequences, not tokens

class MLPerfCallback(pl.Callback):
def __init__(
45 changes: 26 additions & 19 deletions large_language_model_pretraining/nemo/pretrain_llama31.py
@@ -71,16 +71,14 @@ def slurm_executor(
),
nodes=nodes,
ntasks_per_node=devices,
gpus_per_node=devices,
mem="0",
exclusive=True,
gres="gpu:8",
packager=run.GitArchivePackager(),
dependencies=dependencies,
)

if devices != 0:
executor.gpus_per_node=devices
executor.gres = "gpu:8"

executor.launcher = None
executor.container_image = container_image
executor.container_mounts = mounts
@@ -151,7 +149,9 @@ def get_pretrain(
warmup_tokens = 8000 * base_gbs * 8192

max_lr = (gbs / base_gbs) * base_lr
max_lr = round(max_lr, 8) # round max_lr to 8 decimal places

# Code tracing shows that this is AdamW
pretrain.optim = distributed_fused_adam_with_cosine_annealing(
max_lr = max_lr,
warmup_steps = math.ceil(warmup_tokens / 8192 / gbs),
@@ -217,10 +217,10 @@ def get_data(
data_paths = {
"train": train_datasets,
"validation": [
"/preproc_data/c4-validation.en_text_document"
"/preproc_data/c4-validation-91205-samples.en_text_document"
],
"test": [
"/preproc_data/c4-validation.en_text_document"
"/preproc_data/c4-validation-91205-samples.en_text_document"
],
}

@@ -300,8 +300,8 @@ def get_parser() -> argparse.ArgumentParser:

data_group.add_argument("--gbs", type=int, default=1152, help="Global batch size, should be divisible by PP")
data_group.add_argument("--mbs", type=int, default=1, help="Micro batch size")
data_group.add_argument("--eval_every", type=int, default=377_487_360, help="Evaluate at least every N training tokens")
data_group.add_argument("--eval_tokens", type=int, default=47_185_920, help="Evaluate using at least N evaluation tokens")
data_group.add_argument("--eval_every", type=int, default=46080, help="Evaluate at least every N training sequences")
data_group.add_argument("--eval_tokens", type=int, default=5760, help="Evaluate using at least N evaluation sequences")
data_group.add_argument('--max_steps', type=int, default=None, help="Maximum number of steps that each experiment partition will train on. None means no restriction on max steps. ")
data_group.add_argument("--use_full_dataset", action="store_true", help="If set, then we use the full dataset, instead of the last 256/1024 shards")
data_group.add_argument("--tokenizer_path", type=str, help="Tokenizer path that's used to tokenize the dataset")
@@ -312,6 +312,7 @@ def get_parser() -> argparse.ArgumentParser:
experiment_group.add_argument("--num_exps", type=int, default=1)
experiment_group.add_argument("--num_pars", type=int, default=1)
experiment_group.add_argument("--target_log_ppl", type=float, default=5.6)
experiment_group.add_argument("--step_time_atol", type=int, default=1600, help="train step time atol")

return parser

@@ -351,8 +352,8 @@ def get_parser() -> argparse.ArgumentParser:
use_full_dataset=args.use_full_dataset,
)

eval_every_n_batches = math.ceil(args.eval_every / (args.gbs * 8192))
eval_batches = math.ceil(args.eval_tokens / (args.gbs * 8192))
eval_every_n_batches = math.ceil(args.eval_every / (args.gbs))
eval_batches = math.ceil(args.eval_tokens / (args.gbs))

exp_prefix, pretrain = get_pretrain(
size=args.size,
@@ -383,12 +384,12 @@ def get_parser() -> argparse.ArgumentParser:
constants.EVAL_SAMPLES: args.eval_tokens,

# Optimizers
constants.OPT_NAME: pretrain.optim.config.optimizer,
constants.OPT_NAME: "adamw",
constants.OPT_BASE_LR: pretrain.optim.config.lr,
constants.OPT_ADAM_BETA_1: pretrain.optim.config.adam_beta1,
constants.OPT_ADAM_BETA_2: pretrain.optim.config.adam_beta2,
constants.OPT_ADAM_EPSILON: pretrain.optim.config.adam_eps,
constants.OPT_WEIGHT_DECAY: pretrain.optim.config.weight_decay,
constants.OPT_ADAMW_BETA_1: pretrain.optim.config.adam_beta1,
constants.OPT_ADAMW_BETA_2: pretrain.optim.config.adam_beta2,
constants.OPT_ADAMW_EPSILON: pretrain.optim.config.adam_eps,
constants.OPT_ADAMW_WEIGHT_DECAY: pretrain.optim.config.weight_decay,
constants.OPT_GRADIENT_CLIP_NORM: pretrain.optim.config.clip_grad,

# Schedulers
@@ -467,9 +468,14 @@ def get_parser() -> argparse.ArgumentParser:
checkpoint_name = "checkpoint" + f"-seed-{seed}-par-{j}{ending_steps}"
experiment_write_to_path = static_write_to_path + "/" + checkpoint_name

if not args.resume_from_hf:
pretrain.resume.resume_from_directory = experiment_read_from_path
pretrain.resume.resume_from_path = experiment_read_from_path
if not (args.resume_from_hf and j == 0):
pretrain.resume = run.Config(
nl.AutoResume,
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
resume_from_path = experiment_read_from_path,
resume_from_directory = experiment_read_from_path,
)
else:
pretrain.resume = run.Config(nl.AutoResume, restore_config = run.Config(nl.RestoreConfig, path=experiment_read_from_path))
pretrain.log.ckpt.train_time_interval = None
@@ -502,7 +508,8 @@ def get_parser() -> argparse.ArgumentParser:
init_global_step=start_step,
global_batch_size=args.gbs,
seq_length=8192,
target_log_ppl=args.target_log_ppl
target_log_ppl=args.target_log_ppl,
train_step_time_atol=args.step_time_atol,
),
]

14 changes: 3 additions & 11 deletions large_language_model_pretraining/nemo/run_llama31.sh
@@ -59,18 +59,17 @@ git config --global --add safe.directory /workspace/llama31
: "${START_STEPS:=0}"

# Dataloader settings
: "${EVAL_EVERY:=""}"
: "${EVAL_TOKENS:=""}"
: "${MAX_STEPS:=""}"

# Experiment settings
: "${SEEDS:=""}"
IFS=" " read -ra seeds <<< $SEEDS
: "${NEXP:=1}"
: "${NPAR:=1}"
: "${SAVE_CKPT:=1}"
: "${SAVE_CKPT:=0}"
: "${TAG:=""}"
: "${TARGET:="5.6"}"
: "${STEP_TIME_ATOL:="7200"}" # maximum tolerable step time, setting to 2hr by default

# Run

@@ -107,14 +106,6 @@ if [ ! $DEPENDENCIES = "" ]; then
CMD_SUFFIX="${CMD_SUFFIX} --dependencies ${DEPENDENCIES}"
fi

if [ ! $EVAL_EVERY = "" ]; then
CMD_SUFFIX="${CMD_SUFFIX} --eval_every ${EVAL_EVERY}"
fi

if [ ! $EVAL_TOKENS = "" ]; then
CMD_SUFFIX="${CMD_SUFFIX} --eval_tokens ${EVAL_TOKENS}"
fi

if [ ! $MAX_STEPS = "" ]; then
CMD_SUFFIX="${CMD_SUFFIX} --max_steps ${MAX_STEPS}"
fi
@@ -145,6 +136,7 @@ python3 pretrain_llama31.py \
--continual_ckpt_path /continual \
--tokenizer_path /tokenizer \
--target_log_ppl $TARGET \
--step_time_atol $STEP_TIME_ATOL \
--ckpt_start_step $START_STEPS \
--max_retries $MAX_RETRIES \
$CMD_SUFFIX
large_language_model_pretraining/nemo/utils/consolidate_data.sh
@@ -2,6 +2,11 @@ set -e

: "${C4_PATH:?C4_PATH not set}"
: "${MERGED_C4_PATH:?MERGED_C4_PATH not set}"
: "${N_VALIDATION_SAMPLES:=91205}"
# defaults the N_VALIDATION_SAMPLES to 91205
# C4 validation dataset: each sample on average tokenizes to 518 tokens
# thus, to reach 47,185,920 validation tokens, we need to use at least 91205 samples,
# which, after tokenization, will yield 47,186,855 tokens.

# create softlinks to store each shard before merging
mkdir -p softlinks
@@ -26,4 +31,7 @@ for shard in {0..7}; do
cat softlinks/en_${shard}/*gz > ${MERGED_C4_PATH}/c4-train.en_${shard}.json.gz
done

cat softlinks/en_validation/*gz > ${MERGED_C4_PATH}/c4-validation.en.json.gz

# select the first N_VALIDATION_SAMPLES number of samples
zcat ${MERGED_C4_PATH}/c4-validation.en.json.gz | head -n $N_VALIDATION_SAMPLES | gzip > ${MERGED_C4_PATH}/c4-validation-${N_VALIDATION_SAMPLES}-samples.en.json.gz
4 changes: 2 additions & 2 deletions large_language_model_pretraining/nemo/utils/preprocess.sh
@@ -27,8 +27,8 @@ srun --nodes=1 --ntasks-per-node=1 \
--container-image=$CONT_IMAGE_URL --container-mounts $container_maps --no-container-entrypoint \
--output preprocess_outputs/dataset_preprocess_validation.out \
python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \
--input "/dataset/c4-validation.en.json.gz" \
--output-prefix "/outputs/c4-validation.en" \
--input "/dataset/c4-validation-91205-samples.en.json.gz" \
--output-prefix "/outputs/c4-validation-91205-samples.en" \
--tokenizer-library huggingface --tokenizer-type /tokenizer \
--dataset-impl mmap --workers 128 &
wait
12 changes: 8 additions & 4 deletions mixture_of_experts_pretraining/README.md
@@ -492,12 +492,16 @@ You can then navigate in the terminal to your desired download directory and run

## Text Datasets
**Dataset**
* Train Dataset`c4/en_json/3.0.1`
* Eval Dataset `c4/en_val_subset_json`
* Preprocessed GPU dataset `preprocessed_c4`
* Train Dataset `original/en_json/3.0.1`
* Eval Dataset `original/en_val_subset_json`
* Preprocessed GPU dataset `mixtral_8x22b_preprocessed`
```
mkdir -p datasets
rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/datasets ./datasets -P
rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4 ./datasets -P

# move them to the original naming convention so that the code does not break
mv ./datasets/original ./datasets/c4
mv ./datasets/mixtral_8x22b_preprocessed ./datasets/preprocessed_c4
```
## Checkpoints
* Mixtral-8x22B-v0.1-fsdp: use for `tensor_parallelism=1`