
Commit 42aaab3

adding initial code drop for llm finetune (#698)
* adding initial code drop for llm finetune
* (a) fixing padding issue; (b) masking input tokens for eval dataset; (c) adding support for mllogger
* fix masking bug
* adding more logger support
* bug fix
* fix logging bug and update HP
* adding patch for memory issue and fused model enablement
* fixing dataset and model links and updating bash script and readme
* Fix eval batch size, add Dockerfile, improve logging, remove unused code
* Fix eval batch size, add Dockerfile, improve logging, remove unused code
* Remove training_step
* renaming directory and adding more HP values to logger
* adding weight decay to TrainingArguments and BLOCK_START BLOCK_STOP
* editing logging to resolve all checker issues
* fix issue in steps_num logging
* updating bash script for GBS=8

---------

Co-authored-by: Michal Futrega <[email protected]>
1 parent 2d0e7ae commit 42aaab3

10 files changed: +1262 -0 lines changed

llama2_70b_lora/Dockerfile

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:24.01-py3
FROM ${FROM_IMAGE_NAME}

WORKDIR /workspace/ft-llm
ADD . /workspace/ft-llm

RUN pip install -r requirements.txt
RUN pip install flash-attn==2.4.1 --no-build-isolation

llama2_70b_lora/README.md

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# LoRA benchmark

LoRA benchmark on GPU (Nvidia A100 80GB). Inspired by [this blog post](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) and [this script](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/train.py).


## Setup

Run the following:
```bash
sudo ./run_docker.sh
cd lora
pip install -r requirements.txt
```

> The Docker run command contains `-v /home/regis_huggingface_co/workspace:/root/workspace --workdir /root/workspace`. Feel free to change these flags to suit your setup.

You will also need to run the following to install flash attention:
```
pip install flash-attn --no-build-isolation
```

> For flash attention, make sure that the following command returns 0:
> ```
> ninja --version >/dev/null && echo $?
> ```
> If not, run
> ```
> pip uninstall -y ninja && pip install ninja
> ```
> and install `flash-attn` again.
> More information [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

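If the wheel built correctly, a quick import from Python inside the container is another easy check (a minimal sketch; it only assumes `flash-attn` and PyTorch are installed):

```python
# Optional sanity check: flash-attn should import cleanly and report its version.
import flash_attn

print(flash_attn.__version__)
```
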
Make sure you have requested permission to download the Llama2 weights on the Hugging Face Hub: https://huggingface.co/meta-llama/Llama-2-7b-hf
Then, log in to your Hugging Face account with a read token by running:
```
huggingface-cli login
```
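Equivalently, you can log in from Python (for example inside a notebook) using `huggingface_hub`, which ships as a dependency of `transformers`; the call below simply prompts for the same read token:

```python
# Programmatic alternative to `huggingface-cli login`.
from huggingface_hub import login

login()  # prompts for a token; or pass it explicitly with login(token="hf_...")
```
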
Finally, install the MLPerf logger:
```
git clone https://github.com/mlperf/logging.git mlperf-logging
pip install -e mlperf-logging
```
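For orientation, the package is used roughly as sketched below; the keys and values here are illustrative only and are not necessarily the exact set emitted by `scripts/train.py`:

```python
# Minimal sketch of mlperf-logging usage (illustrative keys/values only).
from mlperf_logging import mllog

mllog.config(filename="result.log")  # assumed output file name
mllogger = mllog.get_mllogger()

mllogger.event(key=mllog.constants.SUBMISSION_BENCHMARK, value="llama2_70b_lora")
mllogger.start(key=mllog.constants.RUN_START)
# ... training happens here ...
mllogger.end(key=mllog.constants.RUN_STOP, metadata={"status": "success"})
```
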
## Download Data and Model

The data and model can be downloaded from:
- [mlperf drive - train data](https://drive.google.com/file/d/1-JgY1mEafcJ7qhggt6UR3OEKAciIPd5s/view?usp=sharing)
- [mlperf drive - validation data](https://drive.google.com/file/d/1jrm6Lacrq49AYv0uB_Qy22xRmfPixQvs/view?usp=sharing)
- [mlperf drive - llama-v2 model](https://drive.google.com/drive/folders/1sTeuxkPhwkNPKIPFnOLIYCcK53oB3Ypc?usp=sharing)

By default, the scripts assume the model is at `./llama-v2-fused-qkv` and that both the train and validation data are under the `dataset` folder.

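If you prefer fetching these from a shell instead of a browser, one option (not part of this repo) is the `gdown` package; the file/folder IDs below come from the links above, while the output locations just mirror the defaults mentioned:

```python
# Download the Drive files/folder with gdown (pip install gdown).
# Output locations are assumptions matching the defaults described above.
import os
import gdown

os.makedirs("dataset", exist_ok=True)
gdown.download(id="1-JgY1mEafcJ7qhggt6UR3OEKAciIPd5s", output="dataset/")  # train data
gdown.download(id="1jrm6Lacrq49AYv0uB_Qy22xRmfPixQvs", output="dataset/")  # validation data
gdown.download_folder(id="1sTeuxkPhwkNPKIPFnOLIYCcK53oB3Ypc", output="llama-v2-fused-qkv")  # fused-QKV model
```
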
## Llama2-70B on 8 devices

Run:
```bash
accelerate launch --config_file configs/default_config.yaml scripts/train.py \
--model_name meta-llama/Llama-2-70b-hf \
--dataset_name "tau/scrolls" --dataset_config_name "gov_report" \
--max_seq_len 8192 \
--bf16 True \
--logging_steps 1 \
--eval_steps 22 \
--output_dir "/tmp/llama-70b" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "input" \
--lr_scheduler_type "cosine" \
--learning_rate 1e-3 \
--warmup_ratio 0.03 \
--use_gradient_checkpointing True \
--use_peft_lora True \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.1 \
--max_steps 440 \
--use_flash_attn \
--lora_target_modules "q_proj,v_proj,k_proj,o_proj"
```
where the Accelerate config file is [this one](https://github.com/regisss/lora/blob/main/configs/default_config.yaml).

> Using flash attention with `--use_flash_attn` is necessary for training on 8k-token sequences.

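For reference, the LoRA-specific flags above correspond roughly to the following PEFT configuration; this is a sketch of the mapping, not the literal code in `scripts/train.py`:

```python
# Approximate PEFT equivalent of the --use_peft_lora / --lora_* flags above.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                                                     # --lora_r
    lora_alpha=32,                                            # --lora_alpha
    lora_dropout=0.1,                                         # --lora_dropout
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # --lora_target_modules
    task_type="CAUSAL_LM",
)
```
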
Learning curves of such a run can be found here: https://huggingface.co/regisss/test_5/tensorboard

## Evaluation

To run evaluation on the text-summarization task, run one of the following:
- Without LoRA adapter weights:
```
python scripts/eval.py --model_name meta-llama/Llama-2-70b-hf --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```
- With LoRA adapter weights:
```
python scripts/eval.py --peft_model_name path_to_my_lora_model --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```

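Conceptually, passing `--peft_model_name` amounts to loading the saved adapter weights on top of the base model, roughly as below; `scripts/eval.py` handles this internally, and `path_to_my_lora_model` is the same placeholder as in the command above:

```python
# Rough illustration of loading LoRA adapter weights with PEFT.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "path_to_my_lora_model")  # placeholder path
model = model.merge_and_unload()  # optional: fold the adapters into the base weights
```
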
## Expected outcome

A clean output (train and eval loss) of a single run with 440 steps can be found under
```
convergence_example.txt
```

llama2_70b_lora/configs/default_config.yaml

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
