Iterable data not implemented error #23

Open: wants to merge 3 commits into base branch main
82 changes: 65 additions & 17 deletions CodonTransformer/CodonUtils.py
@@ -3,7 +3,7 @@
---------------------
Includes constants and helper functions used by other Python scripts.
"""

import json
import itertools
import os
import pickle
@@ -15,6 +15,7 @@
import pandas as pd
import requests
import torch
import torch.utils.data

# List of all amino acids
AMINO_ACIDS: List[str] = [
@@ -509,35 +510,55 @@ def __init__(self, dist_env: Optional[str] = None):
"slurm": ("SLURM_NTASKS", "SLURM_PROCID")
}.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK"))

@property
def total_examples(self) -> int:
"""Number of examples in dataset. Must be implemented by subclasses."""
raise NotImplementedError

@property
def iterator(self) -> Iterator:
"""Define the stream logic for the dataset. Implement in subclasses."""
"""Iterator over dataset. Must be implemented by subclasses."""
raise NotImplementedError

def __iter__(self) -> Iterator:
"""
Create an iterator for the dataset, handling multi-processing contexts.

Returns:
Iterator: The iterator for the dataset.
"""
"""Create iterator with proper work distribution in multi-processing contexts."""
worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None:
+
+        if worker_info is None:  # Single-process data loading
return self.iterator

-        # In multi-processing context, use 'os.environ' to
-        # find global worker rank. Then use 'islice' to allocate
-        # the items of the stream to the workers.
-        world_size = int(os.environ.get(self.world_size_handle))
-        global_rank = int(os.environ.get(self.rank_handle))
+        # In multi-processing context, calculate global worker rank
+        world_size = int(os.environ.get(self.world_size_handle, 1))
+        global_rank = int(os.environ.get(self.rank_handle, 0))
local_rank = worker_info.id
local_num_workers = worker_info.num_workers

-        # Assume that each process has the same number of local workers.
+        # Calculate worker rank and total workers
worker_rk = global_rank * local_num_workers + local_rank
worker_nb = world_size * local_num_workers

return itertools.islice(self.iterator, worker_rk, None, worker_nb)
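
For context, the striding above is what splits a single stream across every DataLoader worker on every rank. A minimal sketch (not part of this patch, with made-up worker counts) of how the islice arithmetic shards a stream:

import itertools

# Hypothetical setup: 2 processes (world_size) x 2 DataLoader workers per process.
world_size, local_num_workers = 2, 2
stream = list(range(10))  # stands in for the JSON-line stream

for global_rank in range(world_size):
    for local_rank in range(local_num_workers):
        worker_rk = global_rank * local_num_workers + local_rank
        worker_nb = world_size * local_num_workers
        shard = list(itertools.islice(iter(stream), worker_rk, None, worker_nb))
        print(f"worker {worker_rk}: {shard}")
# worker 0: [0, 4, 8]   worker 1: [1, 5, 9]   worker 2: [2, 6]   worker 3: [3, 7]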

def get_total_steps(self, batch_size: int, n_gpus: int, grad_accum: int = 1) -> int:
"""
Calculate total training steps for scheduler configuration.

Args:
batch_size: Training batch size per GPU
n_gpus: Number of GPUs being used
grad_accum: Gradient accumulation steps

Returns:
int: Total number of training steps
"""
effective_batch = batch_size * n_gpus * grad_accum
total_steps = self.total_examples // effective_batch

# Add one more step if there's a remainder
if self.total_examples % effective_batch != 0:
total_steps += 1

return total_steps
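
A quick worked example of the step arithmetic above (made-up numbers), showing the effective batch size and the ceiling division:

# Suppose 10_000 examples, batch_size=8 per GPU, n_gpus=2, grad_accum=4.
effective_batch = 8 * 2 * 4              # 64 examples consumed per optimizer step
total_steps = 10_000 // effective_batch  # 156 full steps ...
if 10_000 % effective_batch != 0:        # ... plus one step for the remaining 16 examples
    total_steps += 1
assert total_steps == 157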

class IterableJSONData(IterableData):
"""
@@ -549,10 +570,37 @@ class IterableJSONData(IterableData):
**kwargs: Additional keyword arguments for the base class.
"""

-    def __init__(self, data_path: str, train: bool = True, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, data_path: str, train: bool = True, dist_env: Optional[str] = None):
+        super().__init__(dist_env=dist_env)
self.data_path = data_path
self.train = train
self._total_examples = None

if not os.path.exists(data_path):
raise FileNotFoundError(f"Data file not found: {data_path}")

@property
def total_examples(self) -> int:
"""Calculate and cache total number of examples."""
if self._total_examples is None:
self._total_examples = sum(1 for _ in open(self.data_path))
return self._total_examples

@property
def iterator(self) -> Iterator[Dict[str, Any]]:
"""
Iterate over JSON lines in the data file.

Returns:
Iterator[Dict[str, Any]]: Iterator yielding parsed JSON objects
"""
with open(self.data_path, 'r') as f:
for line in f:
try:
yield json.loads(line.strip())
except json.JSONDecodeError as e:
print(f"Warning: Skipping malformed JSON line: {e}")
continue
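
As a usage sketch (hypothetical file path and record fields, not part of this diff), the class can be pointed at a JSON-lines file and iterated directly or through a DataLoader:

# train.jsonl holds one JSON object per line, e.g. {"codons": "ATG GCT TAA"}
data = IterableJSONData(data_path="train.jsonl", train=True)
print(data.total_examples)   # number of lines in the file

for record in data:          # single-process: __iter__ falls back to self.iterator
    print(record["codons"])  # "codons" is a hypothetical field name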


class ConfigManager(ABC):
94 changes: 87 additions & 7 deletions finetune.py
@@ -11,6 +11,7 @@

import argparse
import os
from typing import Optional

import pytorch_lightning as pl
import torch
@@ -72,30 +73,67 @@ def __call__(self, examples):


class plTrainHarness(pl.LightningModule):
-    def __init__(self, model, learning_rate, warmup_fraction):
+    """Lightning module for training the model."""

def __init__(
self,
model: torch.nn.Module,
learning_rate: float,
warmup_fraction: float,
train_dataset: Optional[IterableJSONData] = None,
batch_size: Optional[int] = None,
n_gpus: Optional[int] = None,
grad_accum: Optional[int] = None
):
super().__init__()
self.model = model
self.learning_rate = learning_rate
self.warmup_fraction = warmup_fraction

# Store dataset info for scheduler configuration
self.train_dataset = train_dataset
self.batch_size = batch_size
self.n_gpus = n_gpus
self.grad_accum = grad_accum

def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.learning_rate,
)

# Calculate total steps from dataset if possible
if all(x is not None for x in [
self.train_dataset,
self.batch_size,
self.n_gpus,
self.grad_accum
]):
total_steps = self.train_dataset.get_total_steps(
batch_size=self.batch_size,
n_gpus=self.n_gpus,
grad_accum=self.grad_accum
)
else:
total_steps = self.trainer.estimated_stepping_batches

print(f"\nConfiguring OneCycleLR with total_steps: {total_steps}\n")

lr_scheduler = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=self.learning_rate,
-                total_steps=self.trainer.estimated_stepping_batches,
+                total_steps=total_steps,
pct_start=self.warmup_fraction,
anneal_strategy='linear'
),
"interval": "step",
"frequency": 1,
}
return [optimizer], [lr_scheduler]

def training_step(self, batch, batch_idx):
"""Execute a single training step."""
self.model.bert.set_attention_type("block_sparse")
outputs = self.model(**batch)
self.log_dict(
@@ -128,13 +166,42 @@ def main(args):
pl.seed_everything(args.seed)
torch.set_float32_matmul_precision("medium")

-    # Load the tokenizer and model
+    # Load tokenizer and model from HuggingFace
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer-base")
harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")

# Load the training data
-    train_data = IterableJSONData(args.dataset_dir, dist_env="slurm")
+    train_data = IterableJSONData(
+        data_path=args.dataset_dir,
+        train=True,
+        dist_env="slurm" if not args.debug else None
+    )

# Print dataset info
print(f"\nTotal examples in dataset: {train_data.total_examples}")
print(f"Batch size per GPU: {args.batch_size}")
print(f"Number of GPUs: {args.num_gpus}")
print(f"Gradient accumulation steps: {args.accumulate_grad_batches}")

total_steps = train_data.get_total_steps(
batch_size=args.batch_size,
n_gpus=args.num_gpus,
grad_accum=args.accumulate_grad_batches
)
print(f"Calculated total training steps: {total_steps}\n")

# Create harness
harnessed_model = plTrainHarness(
model=model,
learning_rate=args.learning_rate,
warmup_fraction=args.warmup_fraction,
train_dataset=train_data,
batch_size=args.batch_size,
n_gpus=args.num_gpus,
grad_accum=args.accumulate_grad_batches
)

# Create data loader
data_loader = DataLoader(
dataset=train_data,
collate_fn=MaskedTokenizerCollator(tokenizer),
@@ -149,6 +216,16 @@
checkpoint_filename=args.checkpoint_filename,
every_n_train_steps=args.save_every_n_steps,
)

# Early stopping callback
early_stop = pl.callbacks.EarlyStopping(
monitor='loss',
min_delta=0.00,
patience=3,
verbose=True,
mode='min'
)

trainer = pl.Trainer(
default_root_dir=args.checkpoint_dir,
strategy="ddp_find_unused_parameters_true",
@@ -158,8 +235,10 @@
max_epochs=args.max_epochs,
deterministic=False,
enable_checkpointing=True,
-        callbacks=[save_checkpoint],
+        callbacks=[save_checkpoint, early_stop],
accumulate_grad_batches=args.accumulate_grad_batches,
gradient_clip_val=1.0, # Add gradient clipping
log_every_n_steps=10,
)

# Finetune the model
@@ -228,3 +307,4 @@ def main(args):
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
main(args)
