"""
File: CodonRestrictionSites.py
------------------------------
Includes functions for handling forbidden sequences and restriction sites in DNA sequences.
"""

import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple, Union

try:
    from Bio.Seq import Seq
except ImportError:  # Only the protein-preservation check needs Biopython.
    Seq = None


# Complement table for IUPAC DNA letters (upper and lower case); any other
# character is left unchanged by str.translate.
_COMPLEMENT_TABLE = str.maketrans(
    "ACGTUMRWSYKVHDBXNacgtumrwsykvhdbxn",
    "TGCAAKYWSRMBDHVXNtgcaakywsrmbdhvxn",
)

# Single-base substitutions tried when mutating a forbidden sequence.
_SUBSTITUTIONS = {
    'A': ('T', 'C', 'G'),
    'T': ('A', 'C', 'G'),
    'C': ('A', 'T', 'G'),
    'G': ('A', 'T', 'C'),
}


@dataclass
class ForbiddenSequenceViolation:
    """Represents a violation of forbidden sequence rules."""
    sequence: str  # The matched sequence as it appears in the checked DNA
    position: int  # 0-based start index of the match
    type: str  # 'restriction_site' or 'user_defined'
    is_reverse_complement: bool = False


@dataclass
class ForbiddenSequenceConfig:
    """Configuration for forbidden sequence handling."""
    # Pre-defined restriction sites to check (enzyme name -> recognition site)
    restriction_sites: Dict[str, str] = field(default_factory=lambda: {
        'BsaI': 'GGTCTC',    # BsaI recognition site
        'BbsI': 'GAAGAC',    # BbsI recognition site
        'BsmBI': 'CGTCTC',   # BsmBI recognition site
        'EcoRI': 'GAATTC',   # EcoRI recognition site
        'BamHI': 'GGATCC',   # BamHI recognition site
        'XhoI': 'CTCGAG',    # XhoI recognition site
        'NotI': 'GCGGCCGC',  # NotI recognition site
        'XbaI': 'TCTAGA',    # XbaI recognition site
    })

    # Additional user-defined forbidden sequences
    additional_sequences: Set[str] = field(default_factory=set)

    # Strategy for handling forbidden sequences
    strategy: str = "hybrid"  # 'mutate', 'regenerate', or 'hybrid'

    # Maximum number of regeneration attempts
    max_attempts: int = 10

    # Maximum number of mutations allowed per sequence
    max_mutations: int = 3

    # Whether to check reverse complement sequences
    check_reverse_complement: bool = True

    def __post_init__(self):
        """Validate configuration after initialization."""
        valid_strategies = {'mutate', 'regenerate', 'hybrid'}
        if self.strategy not in valid_strategies:
            raise ValueError(f"Strategy must be one of {valid_strategies}")

        if self.max_attempts < 1:
            raise ValueError("max_attempts must be positive")

        if self.max_mutations < 0:
            raise ValueError("max_mutations must be non-negative")


def get_reverse_complement(sequence: str) -> str:
    """
    Get the reverse complement of a DNA sequence.

    IUPAC ambiguity codes are complemented and case is preserved; characters
    that are not IUPAC DNA letters are left unchanged.

    Args:
        sequence (str): DNA sequence

    Returns:
        str: Reverse complement sequence
    """
    return sequence.translate(_COMPLEMENT_TABLE)[::-1]


def _forbidden_patterns(config: ForbiddenSequenceConfig) -> Dict[str, str]:
    """Map every upper-cased forbidden sequence in ``config`` to its violation type."""
    sites = config.restriction_sites
    # The field used to be annotated as a set; accept a plain collection of
    # sequences as well as the name -> sequence mapping.
    site_sequences = sites.values() if isinstance(sites, dict) else sites

    patterns: Dict[str, str] = {}
    for seq in site_sequences:
        if seq:  # An empty pattern would match at every position.
            patterns[seq.upper()] = 'restriction_site'
    # User-defined entries override the type of an identical restriction site.
    for seq in config.additional_sequences:
        if seq:
            patterns[seq.upper()] = 'user_defined'
    return patterns


def find_forbidden_sequences(
    dna: str,
    config: ForbiddenSequenceConfig
) -> List[ForbiddenSequenceViolation]:
    """
    Find all forbidden sequences in a DNA sequence.

    Sequences are matched literally and case-insensitively. Overlapping
    occurrences are all reported. A given (sequence, position) pair is
    reported only once, so a palindromic site is not counted a second time
    as its own reverse complement.

    Args:
        dna (str): DNA sequence to check
        config (ForbiddenSequenceConfig): Configuration for checking

    Returns:
        List[ForbiddenSequenceViolation]: List of violations found
    """
    dna = dna.upper()
    violations: List[ForbiddenSequenceViolation] = []
    seen: Set[Tuple[str, int]] = set()

    def scan(pattern: str, seq_type: str, is_reverse_complement: bool) -> None:
        # A zero-width lookahead lets finditer report overlapping hits.
        for match in re.finditer(f"(?={re.escape(pattern)})", dna):
            key = (pattern, match.start())
            if key in seen:
                continue
            seen.add(key)
            violations.append(ForbiddenSequenceViolation(
                sequence=pattern,
                position=match.start(),
                type=seq_type,
                is_reverse_complement=is_reverse_complement
            ))

    # Check each sequence and, if enabled, its reverse complement
    for seq, seq_type in _forbidden_patterns(config).items():
        scan(seq, seq_type, False)
        if config.check_reverse_complement:
            scan(get_reverse_complement(seq), seq_type, True)

    return violations


def suggest_mutations(
    dna: str,
    violation: ForbiddenSequenceViolation,
    codon_boundaries: bool = True
) -> List[Tuple[str, int]]:
    """
    Suggest possible mutations to fix a forbidden sequence violation.

    Every single-base substitution inside the violating window is tried. A
    candidate is kept only if no occurrence of the forbidden sequence
    overlaps the mutated base afterwards. Bases other than A/C/G/T are
    skipped. The returned sequences are upper-case.

    Args:
        dna (str): Original DNA sequence
        violation (ForbiddenSequenceViolation): Violation to fix
        codon_boundaries (bool): If True, only positions whose index is a
            multiple of 3 (the first base of each codon) are mutated

    Returns:
        List[Tuple[str, int]]: List of (mutated sequence, position) pairs
    """
    dna = dna.upper()
    target = violation.sequence.upper()
    seq_len = len(target)
    start = max(0, violation.position)
    stop = min(start + seq_len, len(dna))

    mutations: List[Tuple[str, int]] = []
    for current_pos in range(start, stop):
        # Skip if not on codon boundary when required
        if codon_boundaries and current_pos % 3 != 0:
            continue

        for new_base in _SUBSTITUTIONS.get(dna[current_pos], ()):
            mutated_seq = dna[:current_pos] + new_base + dna[current_pos + 1:]

            # Every occurrence that could contain the mutated base lies
            # within seq_len - 1 bases on either side of it.
            window = mutated_seq[
                max(0, current_pos - seq_len + 1):current_pos + seq_len
            ]
            if target not in window:
                mutations.append((mutated_seq, current_pos))

    return mutations


def fix_forbidden_sequences(
    dna: str,
    config: ForbiddenSequenceConfig,
    protein: Optional[str] = None
) -> Tuple[str, List[Tuple[int, str, str]]]:
    """
    Attempt to fix forbidden sequences in a DNA sequence.

    Violations are processed in the order they are found. For each one, the
    first single-base substitution is applied that (a) keeps the translation
    equal to ``protein`` when it is given and (b) strictly lowers the total
    number of violations, so a fix never swaps one forbidden site for another.
    At most ``config.max_mutations`` substitutions are made.

    Args:
        dna (str): DNA sequence to fix
        config (ForbiddenSequenceConfig): Configuration for fixing
        protein (Optional[str]): Original protein sequence to maintain

    Returns:
        Tuple[str, List[Tuple[int, str, str]]]:
            - Fixed DNA sequence (upper-case if any change was made,
              otherwise the input unchanged)
            - List of (position, original, new) changes made

    Raises:
        ImportError: If ``protein`` is given but Biopython is not installed.
    """
    violations = find_forbidden_sequences(dna, config)
    if not violations:
        return dna, []

    if protein and Seq is None:
        raise ImportError(
            "Biopython is required to preserve the protein sequence while "
            "fixing forbidden sequences."
        )

    changes_made: List[Tuple[int, str, str]] = []
    fixed_dna = dna.upper()
    violations_remaining = len(violations)
    mutations_remaining = config.max_mutations

    for violation in violations:
        if mutations_remaining <= 0 or violations_remaining == 0:
            break

        # An earlier substitution may already have destroyed this site.
        end = violation.position + len(violation.sequence)
        if fixed_dna[violation.position:end] != violation.sequence:
            continue

        # All positions are candidates; the translation check below is what
        # guarantees that the protein is preserved.
        candidates = suggest_mutations(
            fixed_dna,
            violation,
            codon_boundaries=False
        )

        for mutated_seq, position in candidates:
            # Verify protein sequence is maintained if provided
            if protein and str(Seq(mutated_seq).translate()) != protein:
                continue

            # Reject candidates that merely create a different violation.
            new_count = len(find_forbidden_sequences(mutated_seq, config))
            if new_count >= violations_remaining:
                continue

            changes_made.append((
                position,
                fixed_dna[position],
                mutated_seq[position]
            ))
            fixed_dna = mutated_seq
            violations_remaining = new_count
            mutations_remaining -= 1
            break

    return (fixed_dna if changes_made else dna), changes_made


def format_forbidden_sequence_report(
    violations: List[ForbiddenSequenceViolation],
    changes: Optional[List[Tuple[int, str, str]]] = None
) -> str:
    """
    Format a report of forbidden sequence violations and fixes.

    The "Changes made" section is included whenever ``changes`` is non-empty,
    including when ``violations`` is empty because every site was fixed.

    Args:
        violations (List[ForbiddenSequenceViolation]): Found violations
        changes (List[Tuple[int, str, str]], optional): Changes made to fix violations

    Returns:
        str: Formatted report
    """
    report = ["Forbidden Sequence Analysis", "=" * 40]

    if violations:
        report.append(f"Found {len(violations)} forbidden sequence(s):")
        for v in violations:
            report.append(f"\n- Sequence: {v.sequence}")
            report.append(f"  Position: {v.position}")
            report.append(f"  Type: {v.type}")
            if v.is_reverse_complement:
                report.append("  (Reverse Complement)")
    else:
        report.append("No forbidden sequences found.")

    if changes:
        report.append("\nChanges made:")
        report.extend(f"Position {pos}: {old} → {new}" for pos, old, new in changes)

    return "\n".join(report)
format_complexity_report, + get_total_complexity_score, + ) + from .CodonRestrictionSites import ( + ForbiddenSequenceConfig, + find_forbidden_sequences, + fix_forbidden_sequences, + format_forbidden_sequence_report, + ) + from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( AMINO_ACID_TO_INDEX, @@ -32,18 +60,28 @@ def predict_dna_sequence( - protein: str, - organism: Union[int, str], - device: torch.device, - tokenizer: Union[str, PreTrainedTokenizerFast] = None, - model: Union[str, torch.nn.Module] = None, - attention_type: str = "original_full", - deterministic: bool = True, - temperature: float = 0.2, - top_p: float = 0.95, - num_sequences: int = 1, - match_protein: bool = False, -) -> Union[DNASequencePrediction, List[DNASequencePrediction]]: + protein: str, + organism: Union[int, str], + device: torch.device, + tokenizer: Union[str, PreTrainedTokenizerFast] = None, + model: Union[str, torch.nn.Module] = None, + attention_type: str = "original_full", + deterministic: bool = True, + temperature: float = 0.2, + top_p: float = 0.95, + num_sequences: int = 1, + match_protein: bool = False, + check_complexity: bool = False, + complexity_config: Optional[ComplexityConfig] = None, + check_forbidden: bool = False, + forbidden_config: Optional[ForbiddenSequenceConfig] = None, + max_attempts: int = 10 +) -> Union[ + DNASequencePrediction, + List[DNASequencePrediction], + Tuple[DNASequencePrediction, str], + Tuple[DNASequencePrediction, str, str] +]: """ Predict the DNA sequence(s) for a given protein using the CodonTransformer model. @@ -72,8 +110,6 @@ def predict_dna_sequence( temperature (float, optional): A value controlling the randomness of predictions during non-deterministic decoding. Lower values (e.g., 0.2) make the model more conservative, while higher values (e.g., 0.8) increase randomness. - Using high temperatures may result in prediction of DNA sequences that - do not translate to the input protein. 
Recommended values are: - Low randomness: 0.2 - Medium randomness: 0.5 @@ -81,21 +117,29 @@ def predict_dna_sequence( The temperature must be a positive float. Defaults to 0.2. top_p (float, optional): The cumulative probability threshold for nucleus sampling. Tokens with cumulative probability up to top_p are considered for sampling. - This parameter helps balance diversity and coherence in the predicted DNA sequences. - The value must be a float between 0 and 1. Defaults to 0.95. + This parameter helps balance diversity and coherence. Defaults to 0.95. num_sequences (int, optional): The number of DNA sequences to generate. Only applicable when deterministic is False. Defaults to 1. match_protein (bool, optional): Ensures the predicted DNA sequence is translated to the input protein sequence by sampling from only the respective codons of given amino acids. Defaults to False. + check_complexity (bool, optional): Whether to check sequence complexity and attempt + regeneration if complexity checks fail. Defaults to False. + complexity_config (Optional[ComplexityConfig], optional): Configuration for complexity + checking. If None, default configuration is used. See ComplexityConfig class. + check_forbidden (bool, optional): Whether to check for forbidden sequences like + restriction sites. Defaults to False. + forbidden_config (Optional[ForbiddenSequenceConfig], optional): Configuration for + forbidden sequence handling. If None, default configuration is used. + max_attempts (int, optional): Maximum number of attempts to generate valid sequences. + Used for both complexity and forbidden sequence checks. Defaults to 10. Returns: - Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects - containing the prediction results: - - organism (str): Name of the organism used for prediction. - - protein (str): Input protein sequence for which DNA sequence is predicted. - - processed_input (str): Processed input sequence (merged protein and DNA). 
- - predicted_dna (str): Predicted DNA sequence. + Union[DNASequencePrediction, List[DNASequencePrediction], Tuple[DNASequencePrediction, str], + Tuple[DNASequencePrediction, str, str]]: + - If no checks enabled: Single prediction or list of predictions + - If one check enabled: Tuple of (prediction, report) + - If both checks enabled: Tuple of (prediction, complexity_report, forbidden_report) Raises: ValueError: If the protein sequence is empty, if the organism is invalid, @@ -104,59 +148,15 @@ def predict_dna_sequence( Note: This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries - imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their - corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to - respective codons. AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices - of codon tokens that translate to it. - - Example: - >>> import torch - >>> from transformers import AutoTokenizer, BigBirdForMaskedLM - >>> from CodonTransformer.CodonPrediction import predict_dna_sequence - >>> from CodonTransformer.CodonJupyter import format_model_output - >>> - >>> # Set up device - >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - >>> - >>> # Load tokenizer and model - >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") - >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") - >>> model = model.to(device) - >>> - >>> # Define protein sequence and organism - >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" - >>> organism = "Escherichia coli general" - >>> - >>> # Predict DNA sequence with deterministic decoding (single sequence) - >>> output = predict_dna_sequence( - ... protein=protein, - ... organism=organism, - ... device=device, - ... tokenizer=tokenizer, - ... model=model, - ... attention_type="original_full", - ... deterministic=True - ... 
) - >>> - >>> # Predict multiple DNA sequences with low randomness and top_p sampling - >>> output_random = predict_dna_sequence( - ... protein=protein, - ... organism=organism, - ... device=device, - ... tokenizer=tokenizer, - ... model=model, - ... attention_type="original_full", - ... deterministic=False, - ... temperature=0.2, - ... top_p=0.95, - ... num_sequences=3 - ... ) - >>> - >>> print(format_model_output(output)) - >>> for i, seq in enumerate(output_random, 1): - ... print(f"Sequence {i}:") - ... print(format_model_output(seq)) - ... print() + from CodonTransformer.CodonUtils. When complexity checks are enabled, it checks: + - Repeat sequence detection (≥20bp or Tm ≥60°C) + - GC content checks (25-65% globally, max 52% deviation) + - Homopolymer identification + - HIS tag pattern validation + When forbidden sequence checks are enabled, it checks: + - Common restriction sites (BsaI, BbsI, etc.) + - User-defined forbidden sequences + - Both forward and reverse complement sequences """ if not protein: raise ValueError("Protein sequence cannot be empty.") @@ -171,10 +171,126 @@ def predict_dna_sequence( raise ValueError("num_sequences must be a positive integer.") if deterministic and num_sequences > 1: - raise ValueError( - "Multiple sequences can only be generated in non-deterministic mode." 
+ raise ValueError("Multiple sequences can only be generated in non-deterministic mode.") + + # Initialize configurations + if check_complexity and complexity_config is None: + complexity_config = ComplexityConfig() + if check_forbidden and forbidden_config is None: + forbidden_config = ForbiddenSequenceConfig() + + def generate_single_prediction(current_temperature: float = temperature): + """Helper function to generate one prediction.""" + nonlocal deterministic + # Force non-deterministic mode for multiple attempts + local_deterministic = False if (check_complexity or check_forbidden) else deterministic + + prediction = predict_dna_sequence( + protein=protein, + organism=organism, + device=device, + tokenizer=tokenizer, + model=model, + attention_type=attention_type, + deterministic=local_deterministic, + temperature=current_temperature, + top_p=top_p, + num_sequences=1, + match_protein=match_protein, + check_complexity=False, + check_forbidden=False ) + return prediction if isinstance(prediction, DNASequencePrediction) else prediction[0] + + # Handle complexity and forbidden sequence checking + if check_complexity or check_forbidden: + best_prediction = None + best_complexity_score = float('inf') + best_forbidden_count = float('inf') + best_complexity_report = None + best_forbidden_report = None + + for attempt in range(max_attempts): + # Gradually increase temperature for more diversity + current_temperature = temperature + (attempt * 0.1) + + prediction = generate_single_prediction(current_temperature) + + # Check complexity if enabled + complexity_score = float('inf') + complexity_report = None + if check_complexity: + violations = check_sequence_complexity( + prediction.predicted_dna, + config=complexity_config + ) + complexity_score = get_total_complexity_score(violations) + complexity_report = format_complexity_report( + prediction.predicted_dna, + violations + ) + + # Check forbidden sequences if enabled + forbidden_count = float('inf') + 
forbidden_report = None + if check_forbidden: + violations = find_forbidden_sequences( + prediction.predicted_dna, + forbidden_config + ) + forbidden_count = len(violations) + + if violations and forbidden_config.strategy in ('mutate', 'hybrid'): + fixed_dna, changes = fix_forbidden_sequences( + prediction.predicted_dna, + forbidden_config, + protein + ) + if changes: # If fixes were made + prediction = DNASequencePrediction( + organism=prediction.organism, + protein=prediction.protein, + processed_input=prediction.processed_input, + predicted_dna=fixed_dna + ) + # Recheck violations after fixes + violations = find_forbidden_sequences(fixed_dna, forbidden_config) + forbidden_count = len(violations) + + forbidden_report = format_forbidden_sequence_report( + violations, + changes if 'changes' in locals() else None + ) + # Update best prediction if better + if complexity_score < best_complexity_score or forbidden_count < best_forbidden_count: + best_prediction = prediction + best_complexity_score = complexity_score + best_forbidden_count = forbidden_count + best_complexity_report = complexity_report + best_forbidden_report = forbidden_report + + # Check if prediction passes all enabled checks + passes_complexity = not check_complexity or complexity_score < 10 + passes_forbidden = not check_forbidden or forbidden_count == 0 + + if passes_complexity and passes_forbidden: + if check_complexity and check_forbidden: + return prediction, complexity_report, forbidden_report + elif check_complexity: + return prediction, complexity_report + else: + return prediction, forbidden_report + + # Return best attempt if no perfect solution found + if check_complexity and check_forbidden: + return best_prediction, best_complexity_report, best_forbidden_report + elif check_complexity: + return best_prediction, best_complexity_report + else: + return best_prediction, best_forbidden_report + + # Standard prediction logic for when no checks are needed # Load tokenizer if not 
isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer)