diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py
index f08f4cc..18547ce 100644
--- a/CodonTransformer/CodonUtils.py
+++ b/CodonTransformer/CodonUtils.py
@@ -3,7 +3,7 @@
 ---------------------
 Includes constants and helper functions used by other Python scripts.
 """
-
+import json
 import itertools
 import os
 import pickle
@@ -15,6 +15,7 @@
 import pandas as pd
 import requests
 import torch
+import torch.utils.data
 
 # List of all amino acids
 AMINO_ACIDS: List[str] = [
@@ -509,35 +510,55 @@ def __init__(self, dist_env: Optional[str] = None):
             "slurm": ("SLURM_NTASKS", "SLURM_PROCID")
         }.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK"))
 
+    @property
+    def total_examples(self) -> int:
+        """Number of examples in dataset. Must be implemented by subclasses."""
+        raise NotImplementedError
+
     @property
     def iterator(self) -> Iterator:
-        """Define the stream logic for the dataset. Implement in subclasses."""
+        """Iterator over dataset. Must be implemented by subclasses."""
        raise NotImplementedError
 
     def __iter__(self) -> Iterator:
-        """
-        Create an iterator for the dataset, handling multi-processing contexts.
-
-        Returns:
-            Iterator: The iterator for the dataset.
-        """
+        """Create iterator with proper work distribution in multi-processing contexts."""
         worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None:
+
+        if worker_info is None:  # Single-process data loading
             return self.iterator
 
-        # In multi-processing context, use 'os.environ' to
-        # find global worker rank. Then use 'islice' to allocate
-        # the items of the stream to the workers.
-        world_size = int(os.environ.get(self.world_size_handle))
-        global_rank = int(os.environ.get(self.rank_handle))
+        # In multi-processing context, calculate global worker rank
+        world_size = int(os.environ.get(self.world_size_handle, 1))
+        global_rank = int(os.environ.get(self.rank_handle, 0))
         local_rank = worker_info.id
         local_num_workers = worker_info.num_workers
 
-        # Assume that each process has the same number of local workers.
+        # Calculate worker rank and total workers
         worker_rk = global_rank * local_num_workers + local_rank
         worker_nb = world_size * local_num_workers
+
         return itertools.islice(self.iterator, worker_rk, None, worker_nb)
 
+    def get_total_steps(self, batch_size: int, n_gpus: int, grad_accum: int = 1) -> int:
+        """
+        Calculate total training steps for scheduler configuration.
+
+        Args:
+            batch_size: Training batch size per GPU
+            n_gpus: Number of GPUs being used
+            grad_accum: Gradient accumulation steps
+
+        Returns:
+            int: Total number of training steps
+        """
+        effective_batch = batch_size * n_gpus * grad_accum
+        total_steps = self.total_examples // effective_batch
+
+        # Add one more step if there's a remainder
+        if self.total_examples % effective_batch != 0:
+            total_steps += 1
+
+        return total_steps
+
 
 class IterableJSONData(IterableData):
     """
@@ -549,10 +570,37 @@ class IterableJSONData(IterableData):
         **kwargs: Additional keyword arguments for the base class.
     """
 
-    def __init__(self, data_path: str, train: bool = True, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, data_path: str, train: bool = True, dist_env: Optional[str] = None):
+        super().__init__(dist_env=dist_env)
         self.data_path = data_path
         self.train = train
+        self._total_examples = None
+
+        if not os.path.exists(data_path):
+            raise FileNotFoundError(f"Data file not found: {data_path}")
+
+    @property
+    def total_examples(self) -> int:
+        """Calculate and cache total number of examples."""
+        if self._total_examples is None:
+            with open(self.data_path) as f:
+                self._total_examples = sum(1 for _ in f)
+        return self._total_examples
+
+    @property
+    def iterator(self) -> Iterator[Dict[str, Any]]:
+        """
+        Iterate over JSON lines in the data file.
+
+        Returns:
+            Iterator[Dict[str, Any]]: Iterator yielding parsed JSON objects
+        """
+        with open(self.data_path, 'r') as f:
+            for line in f:
+                try:
+                    yield json.loads(line.strip())
+                except json.JSONDecodeError as e:
+                    print(f"Warning: Skipping malformed JSON line: {e}")
+                    continue
 
 
 class ConfigManager(ABC):
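
Note on usage: the new total_examples / get_total_steps pair is what lets the training scripts size their LR schedule up front. A minimal sanity-check sketch; the /tmp path and toy numbers are illustrative, not part of the patch:

    import json

    from CodonTransformer.CodonUtils import IterableJSONData

    # Write a tiny JSON-lines file (hypothetical path, illustration only).
    with open("/tmp/toy_dataset.json", "w") as f:
        for i in range(10):
            f.write(json.dumps({"idx": i}) + "\n")

    data = IterableJSONData("/tmp/toy_dataset.json", train=True)

    # total_examples lazily counts lines and caches the result.
    assert data.total_examples == 10

    # ceil(10 / (batch 2 * 2 GPUs * accum 1)) == 3 optimizer steps.
    assert data.get_total_steps(batch_size=2, n_gpus=2, grad_accum=1) == 3

    # In a single process, iteration yields parsed dicts in file order.
    assert next(iter(data)) == {"idx": 0}
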
""" - def __init__(self, data_path: str, train: bool = True, **kwargs): - super().__init__(**kwargs) + def __init__(self, data_path: str, train: bool = True, dist_env: Optional[str] = None): + super().__init__(dist_env=dist_env) self.data_path = data_path self.train = train + self._total_examples = None + + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found: {data_path}") + + @property + def total_examples(self) -> int: + """Calculate and cache total number of examples.""" + if self._total_examples is None: + self._total_examples = sum(1 for _ in open(self.data_path)) + return self._total_examples + + @property + def iterator(self) -> Iterator[Dict[str, Any]]: + """ + Iterate over JSON lines in the data file. + + Returns: + Iterator[Dict[str, Any]]: Iterator yielding parsed JSON objects + """ + with open(self.data_path, 'r') as f: + for line in f: + try: + yield json.loads(line.strip()) + except json.JSONDecodeError as e: + print(f"Warning: Skipping malformed JSON line: {e}") + continue class ConfigManager(ABC): diff --git a/finetune.py b/finetune.py index fce46f3..bbce9ab 100644 --- a/finetune.py +++ b/finetune.py @@ -11,6 +11,7 @@ import argparse import os +from typing import Optional import pytorch_lightning as pl import torch @@ -72,23 +73,59 @@ def __call__(self, examples): class plTrainHarness(pl.LightningModule): - def __init__(self, model, learning_rate, warmup_fraction): + """Lightning module for training the model.""" + + def __init__( + self, + model: torch.nn.Module, + learning_rate: float, + warmup_fraction: float, + train_dataset: Optional[IterableJSONData] = None, + batch_size: Optional[int] = None, + n_gpus: Optional[int] = None, + grad_accum: Optional[int] = None + ): super().__init__() self.model = model self.learning_rate = learning_rate self.warmup_fraction = warmup_fraction + # Store dataset info for scheduler configuration + self.train_dataset = train_dataset + self.batch_size = batch_size + self.n_gpus = n_gpus + self.grad_accum = grad_accum + def configure_optimizers(self): optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.learning_rate, ) + + # Calculate total steps from dataset if possible + if all(x is not None for x in [ + self.train_dataset, + self.batch_size, + self.n_gpus, + self.grad_accum + ]): + total_steps = self.train_dataset.get_total_steps( + batch_size=self.batch_size, + n_gpus=self.n_gpus, + grad_accum=self.grad_accum + ) + else: + total_steps = self.trainer.estimated_stepping_batches + + print(f"\nConfiguring OneCycleLR with total_steps: {total_steps}\n") + lr_scheduler = { "scheduler": torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=self.learning_rate, - total_steps=self.trainer.estimated_stepping_batches, + total_steps=total_steps, pct_start=self.warmup_fraction, + anneal_strategy='linear' ), "interval": "step", "frequency": 1, @@ -96,6 +133,7 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] def training_step(self, batch, batch_idx): + """Execute a single training step.""" self.model.bert.set_attention_type("block_sparse") outputs = self.model(**batch) self.log_dict( @@ -128,13 +166,42 @@ def main(args): pl.seed_everything(args.seed) torch.set_float32_matmul_precision("medium") - # Load the tokenizer and model + # Load tokenizer and model from HuggingFace tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") - model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer-base") - harnessed_model = plTrainHarness(model, args.learning_rate, 
diff --git a/pretrain.py b/pretrain.py
index 1b253cc..d1f8af2 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -10,6 +10,7 @@
 import argparse
 import os
+from typing import Optional
 
 import pytorch_lightning as pl
 import torch
@@ -72,23 +73,57 @@ def __call__(self, examples):
 
 
 class plTrainHarness(pl.LightningModule):
-    def __init__(self, model, learning_rate, warmup_fraction):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        learning_rate: float,
+        warmup_fraction: float,
+        train_dataset: Optional[IterableJSONData] = None,
+        batch_size: Optional[int] = None,
+        n_gpus: Optional[int] = None,
+        grad_accum: Optional[int] = None
+    ):
         super().__init__()
         self.model = model
         self.learning_rate = learning_rate
         self.warmup_fraction = warmup_fraction
 
+        # Store dataset info for scheduler configuration
+        self.train_dataset = train_dataset
+        self.batch_size = batch_size
+        self.n_gpus = n_gpus
+        self.grad_accum = grad_accum
+
     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
             self.model.parameters(),
             lr=self.learning_rate,
         )
+
+        # Calculate total steps from dataset if possible
+        if all(x is not None for x in [
+            self.train_dataset,
+            self.batch_size,
+            self.n_gpus,
+            self.grad_accum
+        ]):
+            total_steps = self.train_dataset.get_total_steps(
+                batch_size=self.batch_size,
+                n_gpus=self.n_gpus,
+                grad_accum=self.grad_accum
+            )
+        else:
+            total_steps = self.trainer.estimated_stepping_batches
+
+        print(f"\nConfiguring OneCycleLR with total_steps: {total_steps}\n")
+
         lr_scheduler = {
             "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                 optimizer,
                 max_lr=self.learning_rate,
-                total_steps=self.trainer.estimated_stepping_batches,
+                total_steps=total_steps,
                 pct_start=self.warmup_fraction,
+                anneal_strategy='linear'
             ),
             "interval": "step",
             "frequency": 1,
@@ -96,6 +131,7 @@ def configure_optimizers(self):
         return [optimizer], [lr_scheduler]
 
     def training_step(self, batch, batch_idx):
+        """Execute a single training step."""
         self.model.bert.set_attention_type("block_sparse")
         outputs = self.model(**batch)
         self.log_dict(
@@ -147,10 +183,39 @@ def main(args):
         sep_token_id=2,
     )
     model = BigBirdForMaskedLM(config=config)
-    harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)
 
     # Load the training data
-    train_data = IterableJSONData(args.train_data_path, dist_env="slurm")
+    train_data = IterableJSONData(
+        data_path=args.train_data_path,
+        train=True,
+        dist_env="slurm" if not args.debug else None
+    )
+
+    # Print dataset info
+    print(f"\nTotal examples in dataset: {train_data.total_examples}")
+    print(f"Batch size per GPU: {args.batch_size}")
+    print(f"Number of GPUs: {args.num_gpus}")
+    print(f"Gradient accumulation steps: {args.accumulate_grad_batches}")
+
+    total_steps = train_data.get_total_steps(
+        batch_size=args.batch_size,
+        n_gpus=args.num_gpus,
+        grad_accum=args.accumulate_grad_batches
+    )
+    print(f"Calculated total training steps: {total_steps}\n")
+
+    # Create harness
+    harnessed_model = plTrainHarness(
+        model=model,
+        learning_rate=args.learning_rate,
+        warmup_fraction=args.warmup_fraction,
+        train_dataset=train_data,
+        batch_size=args.batch_size,
+        n_gpus=args.num_gpus,
+        grad_accum=args.accumulate_grad_batches
+    )
+
+    # Create data loader
     data_loader = DataLoader(
         dataset=train_data,
         collate_fn=MaskedTokenizerCollator(tokenizer),
@@ -161,6 +226,7 @@ def main(args):
 
     # Setup trainer and callbacks
     save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval)
+
     trainer = pl.Trainer(
         default_root_dir=args.checkpoint_dir,
         strategy="ddp_find_unused_parameters_true",
@@ -172,9 +238,11 @@ def main(args):
         max_epochs=args.max_epochs,
         deterministic=False,
         enable_checkpointing=True,
         callbacks=[save_checkpoint],
         accumulate_grad_batches=args.accumulate_grad_batches,
+        gradient_clip_val=1.0,  # Add gradient clipping
+        log_every_n_steps=10,
     )
 
-    # Pretrain the model
+    # Train the model
     trainer.fit(harnessed_model, data_loader)
 
@@ -237,3 +305,4 @@ def main(args):
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
     args = parser.parse_args()
     main(args)
+