
Commit 34086ac

fix(moe): Fix OOM and HF requirements for CUDA path
Parent: 392fc9f

File tree

4 files changed (+10, -37 lines):

mixture_of_experts_pretraining/clm_datasets.py
mixture_of_experts_pretraining/docker/gpu/Dockerfile
mixture_of_experts_pretraining/model_utils_gpu.py
mixture_of_experts_pretraining/run_clm.py

mixture_of_experts_pretraining/clm_datasets.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -251,7 +251,7 @@ def mask_pad(examples):
 # need to run in cpu with single process
 # to walk around undefined `OmegaConf.register_new_resolver` need to overwrite `run_dir` `global_train_batch_size` `global_eval_batch_size`
 # python clm_datasets.py model.name_or_path=mistralai/Mixtral-8x22B-v0.1 run_dir=/tmp global_train_batch_size=1 global_eval_batch_size=1 max_length=32768
-@hydra.main(config_path="config", config_name="config")
+@hydra.main(version_base=None, config_path="config", config_name="config")
 def main(config: DictConfig):
     tokenizer = AutoTokenizer.from_pretrained(
         config.model.name_or_path,
```
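
The `version_base=None` argument matters on newer Hydra releases (>= 1.2): it opts into the legacy pre-1.1 defaults and suppresses the deprecation warning raised when the argument is omitted, so the standalone invocation shown in the comment keeps working. A minimal sketch of the decorator change (the config keys are whatever `config/config.yaml` defines, not part of this diff):

```python
# sketch only: standalone Hydra entry point with version_base pinned (Hydra >= 1.2)
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="config", config_name="config")
def main(config: DictConfig) -> None:
    # version_base=None keeps Hydra's legacy (pre-1.1) behaviour and silences
    # the "missing version_base" deprecation warning
    print(OmegaConf.to_yaml(config))


if __name__ == "__main__":
    main()
```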

mixture_of_experts_pretraining/docker/gpu/Dockerfile

Lines changed: 1 addition & 0 deletions
```diff
@@ -41,6 +41,7 @@ ENV PYTHONPATH "${PYTHONPATH}:/app/Megatron-LM"
 RUN pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger
 RUN pip install datasets==2.20.0 hydra-core sentencepiece
 RUN pip install "git+https://github.com/mlperf/logging.git"
+RUN pip install git+https://github.com/NVIDIA/NeMo-Run.git
 
 WORKDIR /app/training
 ADD . /app/training
```
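
The new dependency is pulled straight from GitHub at image build time. A quick sanity-check sketch, assuming the NeMo-Run distribution installs under the `nemo_run` module name (the import name is not stated in this diff):

```python
# sanity check sketch: fail early if NeMo-Run is missing from the image
import importlib.util

# "nemo_run" as the import name is an assumption about the NeMo-Run package
if importlib.util.find_spec("nemo_run") is None:
    raise SystemExit("NeMo-Run does not appear to be installed in this image")
```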

mixture_of_experts_pretraining/model_utils_gpu.py

Lines changed: 1 addition & 29 deletions
```diff
@@ -14,7 +14,6 @@
 limitations under the License.
 """
 
-import os
 
 import torch
 from megatron.core.optimizer import OptimizerConfig
@@ -26,37 +25,9 @@
 
 def setup_distributed(config):
     """Initialize torch.distributed."""
-    # Get rank and world size.
-    local_rank = int(os.getenv("LOCAL_RANK", 0))
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
-    logging.info(
-        f"Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}"
-    )
-
-    # Set the device id.
-    device = rank % torch.cuda.device_count()
-    if local_rank is not None:
-        device = local_rank
-    torch.cuda.set_device(device)
-
-    # Call the init process.
-    init_method = "tcp://"
-    master_ip = os.getenv("MASTER_ADDR", "localhost")
-    master_port = os.getenv("MASTER_PORT", "6000")
-    import datetime
-
-    DEFAULT_TIMEOUT = datetime.timedelta(minutes=60)
-    init_method += master_ip + ":" + master_port
     torch.distributed.init_process_group(
         backend="nccl",
-        timeout=DEFAULT_TIMEOUT,
-        world_size=world_size,
-        rank=rank,
-        init_method=init_method,
     )
-    return local_rank, rank, world_size
 
 
 def setup_model_and_trainer(
```
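
With the hand-rolled TCP rendezvous removed, `init_process_group(backend="nccl")` falls back to the default `env://` initialization and reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment, which torchrun already exports. A minimal sketch of the equivalent behaviour (the explicit device pinning here is an assumption, not part of this commit):

```python
# sketch of env:// initialization under torchrun; device pinning is assumed
import os

import torch


def setup_distributed() -> int:
    # With no init_method/rank/world_size arguments, init_process_group uses
    # the env:// rendezvous and reads the launcher-provided variables.
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)  # each rank pins its own GPU
    return local_rank
```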
```diff
@@ -124,6 +95,7 @@ def setup_model_and_trainer(
         fp16=False,
         params_dtype=torch.bfloat16,
         clip_grad=max_grad_norm,
+        use_distributed_optimizer=True,
     )
 
     if scheduler.name == "CosineAnnealing":
```
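
`use_distributed_optimizer=True` is presumably the OOM fix from the commit title: Megatron's distributed optimizer shards the fp32 master weights and Adam moments across the data-parallel group instead of replicating them on every rank. A back-of-the-envelope sketch of the per-rank saving (byte counts and the parallel layout are simplifying assumptions; tensor/pipeline/expert parallelism already splits the parameters themselves):

```python
# back-of-the-envelope only: typical Adam + bf16 mixed-precision byte counts
def optimizer_state_gib(num_params: float, dp_size: int, distributed: bool) -> float:
    per_param_bytes = 4 + 8  # fp32 master copy + two fp32 Adam moments
    if distributed:
        per_param_bytes /= dp_size  # sharded across the data-parallel group
    return num_params * per_param_bytes / 2**30


params = 141e9  # roughly Mixtral-8x22B scale, illustrative only
print(optimizer_state_gib(params, dp_size=8, distributed=False))  # ~1576 GiB per rank
print(optimizer_state_gib(params, dp_size=8, distributed=True))   # ~197 GiB per rank
```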

mixture_of_experts_pretraining/run_clm.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -85,16 +85,16 @@ def main(config: DictConfig):
     )
     logger.info(f"{config.eval_frequency=}")
 
-    tokenizer = AutoTokenizer.from_pretrained(
-        config.model.name_or_path,
-        add_eos_token=False,
-        add_bos_token=False,
-        use_fast=False,
-    )
-
     clmlogger = ClmLogger(config, filename="output.txt")
 
     if not USE_CUDA:
+        tokenizer = AutoTokenizer.from_pretrained(
+            config.model.name_or_path,
+            add_eos_token=False,
+            add_bos_token=False,
+            use_fast=False,
+        )
+
         config_path = os.path.join(config.run_dir, "config.yaml")
         with get_file(config_path, "w") as f:
             OmegaConf.save(config, f)
```
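
Moving the tokenizer construction under `if not USE_CUDA:` is presumably the "HF requirements" part of the fix: the gated mistralai/Mixtral-8x22B-v0.1 repository (and a Hugging Face token) is now only touched on the non-CUDA branch. A hypothetical helper expressing the same guard, not part of the commit:

```python
# hypothetical helper: only the non-CUDA path needs the gated HF tokenizer download
from transformers import AutoTokenizer


def maybe_load_tokenizer(use_cuda: bool, name_or_path: str):
    if use_cuda:
        return None  # CUDA/NeMo path: no HF download, no HF token required
    return AutoTokenizer.from_pretrained(
        name_or_path,
        add_eos_token=False,
        add_bos_token=False,
        use_fast=False,
    )
```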
