Skip to content

Restriction sites #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 193 additions & 77 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@
PreTrainedTokenizerFast,
)

# Local imports - adjust paths based on your project structure
try:
from CodonComplexity import ( # Try direct import first
ComplexityConfig,
check_sequence_complexity,
format_complexity_report,
get_total_complexity_score,
)
from CodonRestrictionSites import ( # Import restriction site handling
ForbiddenSequenceConfig,
find_forbidden_sequences,
fix_forbidden_sequences,
format_forbidden_sequence_report,
)
except ImportError:
from .CodonComplexity import ( # Try relative import
ComplexityConfig,
check_sequence_complexity,
format_complexity_report,
get_total_complexity_score,
)
from .CodonRestrictionSites import (
ForbiddenSequenceConfig,
find_forbidden_sequences,
fix_forbidden_sequences,
format_forbidden_sequence_report,
)

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACID_TO_INDEX,
Expand All @@ -32,18 +60,28 @@


def predict_dna_sequence(
protein: str,
organism: Union[int, str],
device: torch.device,
tokenizer: Union[str, PreTrainedTokenizerFast] = None,
model: Union[str, torch.nn.Module] = None,
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
top_p: float = 0.95,
num_sequences: int = 1,
match_protein: bool = False,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
protein: str,
organism: Union[int, str],
device: torch.device,
tokenizer: Union[str, PreTrainedTokenizerFast] = None,
model: Union[str, torch.nn.Module] = None,
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
top_p: float = 0.95,
num_sequences: int = 1,
match_protein: bool = False,
check_complexity: bool = False,
complexity_config: Optional[ComplexityConfig] = None,
check_forbidden: bool = False,
forbidden_config: Optional[ForbiddenSequenceConfig] = None,
max_attempts: int = 10
) -> Union[
DNASequencePrediction,
List[DNASequencePrediction],
Tuple[DNASequencePrediction, str],
Tuple[DNASequencePrediction, str, str]
]:
"""
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.

Expand Down Expand Up @@ -72,30 +110,36 @@ def predict_dna_sequence(
temperature (float, optional): A value controlling the randomness of predictions
during non-deterministic decoding. Lower values (e.g., 0.2) make the model
more conservative, while higher values (e.g., 0.8) increase randomness.
Using high temperatures may result in prediction of DNA sequences that
do not translate to the input protein.
Recommended values are:
- Low randomness: 0.2
- Medium randomness: 0.5
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
Tokens with cumulative probability up to top_p are considered for sampling.
This parameter helps balance diversity and coherence in the predicted DNA sequences.
The value must be a float between 0 and 1. Defaults to 0.95.
This parameter helps balance diversity and coherence. Defaults to 0.95.
num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
when deterministic is False. Defaults to 1.
match_protein (bool, optional): Ensures the predicted DNA sequence is translated
to the input protein sequence by sampling from only the respective codons of
given amino acids. Defaults to False.
check_complexity (bool, optional): Whether to check sequence complexity and attempt
regeneration if complexity checks fail. Defaults to False.
complexity_config (Optional[ComplexityConfig], optional): Configuration for complexity
checking. If None, default configuration is used. See ComplexityConfig class.
check_forbidden (bool, optional): Whether to check for forbidden sequences like
restriction sites. Defaults to False.
forbidden_config (Optional[ForbiddenSequenceConfig], optional): Configuration for
forbidden sequence handling. If None, default configuration is used.
max_attempts (int, optional): Maximum number of attempts to generate valid sequences.
Used for both complexity and forbidden sequence checks. Defaults to 10.

Returns:
Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
containing the prediction results:
- organism (str): Name of the organism used for prediction.
- protein (str): Input protein sequence for which DNA sequence is predicted.
- processed_input (str): Processed input sequence (merged protein and DNA).
- predicted_dna (str): Predicted DNA sequence.
Union[DNASequencePrediction, List[DNASequencePrediction], Tuple[DNASequencePrediction, str],
Tuple[DNASequencePrediction, str, str]]:
- If no checks enabled: Single prediction or list of predictions
- If one check enabled: Tuple of (prediction, report)
- If both checks enabled: Tuple of (prediction, complexity_report, forbidden_report)

Raises:
ValueError: If the protein sequence is empty, if the organism is invalid,
Expand All @@ -104,59 +148,15 @@ def predict_dna_sequence(

Note:
This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries
imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to
respective codons. AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices
of codon tokens that translate to it.

Example:
>>> import torch
>>> from transformers import AutoTokenizer, BigBirdForMaskedLM
>>> from CodonTransformer.CodonPrediction import predict_dna_sequence
>>> from CodonTransformer.CodonJupyter import format_model_output
>>>
>>> # Set up device
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>>
>>> # Load tokenizer and model
>>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
>>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
>>> model = model.to(device)
>>>
>>> # Define protein sequence and organism
>>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
>>> organism = "Escherichia coli general"
>>>
>>> # Predict DNA sequence with deterministic decoding (single sequence)
>>> output = predict_dna_sequence(
... protein=protein,
... organism=organism,
... device=device,
... tokenizer=tokenizer,
... model=model,
... attention_type="original_full",
... deterministic=True
... )
>>>
>>> # Predict multiple DNA sequences with low randomness and top_p sampling
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
... device=device,
... tokenizer=tokenizer,
... model=model,
... attention_type="original_full",
... deterministic=False,
... temperature=0.2,
... top_p=0.95,
... num_sequences=3
... )
>>>
>>> print(format_model_output(output))
>>> for i, seq in enumerate(output_random, 1):
... print(f"Sequence {i}:")
... print(format_model_output(seq))
... print()
from CodonTransformer.CodonUtils. When complexity checks are enabled, it checks:
- Repeat sequence detection (≥20bp or Tm ≥60°C)
- GC content checks (25-65% globally, max 52% deviation)
- Homopolymer identification
- HIS tag pattern validation
When forbidden sequence checks are enabled, it checks:
- Common restriction sites (BsaI, BbsI, etc.)
- User-defined forbidden sequences
- Both forward and reverse complement sequences
"""
if not protein:
raise ValueError("Protein sequence cannot be empty.")
Expand All @@ -171,10 +171,126 @@ def predict_dna_sequence(
raise ValueError("num_sequences must be a positive integer.")

if deterministic and num_sequences > 1:
raise ValueError(
"Multiple sequences can only be generated in non-deterministic mode."
raise ValueError("Multiple sequences can only be generated in non-deterministic mode.")

# Initialize configurations
if check_complexity and complexity_config is None:
complexity_config = ComplexityConfig()
if check_forbidden and forbidden_config is None:
forbidden_config = ForbiddenSequenceConfig()

def generate_single_prediction(current_temperature: float = temperature):
    """
    Produce one prediction by re-invoking predict_dna_sequence without checks.

    The nested call always requests a single sequence and disables both the
    complexity check and the forbidden-sequence check. All other settings
    (protein, organism, device, tokenizer, model, attention_type, top_p,
    match_protein) are taken unchanged from the enclosing call.

    Args:
        current_temperature (float): Sampling temperature forwarded to the
            nested call. Defaults to the enclosing call's `temperature`.

    Returns:
        The nested call's result if it is a DNASequencePrediction, otherwise
        the first element of that result.
    """
    # Force non-deterministic mode for multiple attempts: whenever a check is
    # enabled the caller's `deterministic` setting is overridden with False.
    if check_complexity or check_forbidden:
        use_deterministic = False
    else:
        use_deterministic = deterministic

    call_kwargs = dict(
        protein=protein,
        organism=organism,
        device=device,
        tokenizer=tokenizer,
        model=model,
        attention_type=attention_type,
        deterministic=use_deterministic,
        temperature=current_temperature,
        top_p=top_p,
        num_sequences=1,
        match_protein=match_protein,
        check_complexity=False,
        check_forbidden=False,
    )
    result = predict_dna_sequence(**call_kwargs)

    # Normalise to a single prediction object: unwrap when the nested call
    # handed back something other than a bare DNASequencePrediction.
    if isinstance(result, DNASequencePrediction):
        return result
    return result[0]

# Handle complexity and forbidden sequence checking
if check_complexity or check_forbidden:
best_prediction = None
best_complexity_score = float('inf')
best_forbidden_count = float('inf')
best_complexity_report = None
best_forbidden_report = None

for attempt in range(max_attempts):
# Gradually increase temperature for more diversity
current_temperature = temperature + (attempt * 0.1)

prediction = generate_single_prediction(current_temperature)

# Check complexity if enabled
complexity_score = float('inf')
complexity_report = None
if check_complexity:
violations = check_sequence_complexity(
prediction.predicted_dna,
config=complexity_config
)
complexity_score = get_total_complexity_score(violations)
complexity_report = format_complexity_report(
prediction.predicted_dna,
violations
)

# Check forbidden sequences if enabled
forbidden_count = float('inf')
forbidden_report = None
if check_forbidden:
violations = find_forbidden_sequences(
prediction.predicted_dna,
forbidden_config
)
forbidden_count = len(violations)

if violations and forbidden_config.strategy in ('mutate', 'hybrid'):
fixed_dna, changes = fix_forbidden_sequences(
prediction.predicted_dna,
forbidden_config,
protein
)
if changes: # If fixes were made
prediction = DNASequencePrediction(
organism=prediction.organism,
protein=prediction.protein,
processed_input=prediction.processed_input,
predicted_dna=fixed_dna
)
# Recheck violations after fixes
violations = find_forbidden_sequences(fixed_dna, forbidden_config)
forbidden_count = len(violations)

forbidden_report = format_forbidden_sequence_report(
violations,
changes if 'changes' in locals() else None
)

# Update best prediction if better
if complexity_score < best_complexity_score or forbidden_count < best_forbidden_count:
best_prediction = prediction
best_complexity_score = complexity_score
best_forbidden_count = forbidden_count
best_complexity_report = complexity_report
best_forbidden_report = forbidden_report

# Check if prediction passes all enabled checks
passes_complexity = not check_complexity or complexity_score < 10
passes_forbidden = not check_forbidden or forbidden_count == 0

if passes_complexity and passes_forbidden:
if check_complexity and check_forbidden:
return prediction, complexity_report, forbidden_report
elif check_complexity:
return prediction, complexity_report
else:
return prediction, forbidden_report

# Return best attempt if no perfect solution found
if check_complexity and check_forbidden:
return best_prediction, best_complexity_report, best_forbidden_report
elif check_complexity:
return best_prediction, best_complexity_report
else:
return best_prediction, best_forbidden_report

# Standard prediction logic for when no checks are needed
# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down
Loading
Loading