Quickstart: CodonFM
Estimated time to complete: 1 hour 20 minutes on a T4 GPU
Learning Goals
- Predict mRFP expression from mRNA sequences
- Learn how to use a pretrained Encodon model together with a Random Forest regressor
- Use EncodonInference wrapper for embedding extraction
Prerequisites
- Python version 3.12
- T4 GPU
Introduction
CodonFM Encodon is a suite of transformer-based models that predict masked codons in mRNA sequences to enable variant effect interpretation and codon optimization. The models process sequences up to 2,046 codons (6,138 nucleotides) and output codon probability distributions for each position.
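Encodon operates on codons rather than single nucleotides. To make that framing concrete, here is a minimal, dependency-free sketch (not CodonFM's actual tokenizer, whose behavior may differ) of splitting a coding sequence into codons:

```python
def split_into_codons(seq: str) -> list[str]:
    """Split a coding nucleotide sequence into non-overlapping codons."""
    seq = seq.upper().replace("U", "T")  # accept mRNA or DNA alphabets
    if len(seq) % 3 != 0:
        raise ValueError("coding sequence length must be a multiple of 3")
    return [seq[i:i + 3] for i in range(0, len(seq), 3)]

print(split_into_codons("AUGGCGUAA"))  # ['ATG', 'GCG', 'TAA']
# 2,046 codons x 3 nt = 6,138 nt, Encodon's maximum context
```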
In this quickstart, we will use a pretrained Encodon and Random Forest Regressor model to predict mRFP expression using the Sanofi mRFP expression dataset. We will use the EncodonInference wrapper for embedding extraction.
Setup
Google Colab
Before starting, connect to the T4 GPU runtime hosted for free by Google Colab using the dropdown menu in the upper right-hand corner of this notebook. Please note that Google Colab continuously updates its CUDA/PyTorch versions; the following installation was tested on T4 GPUs in October 2025 using Python 3.12.
Local Environment
The installation procedure below can work on local machines, but for guaranteed reproducibility we suggest building and running inside a container using the following Dockerfile: https://github.com/NVIDIA-Digital-Bio/CodonFM/blob/main/Dockerfile
Step 1: Clone repository and install dependencies
!git clone --branch colab-environment https://github.com/NVIDIA-Digital-Bio/CodonFM
Output:
Cloning into 'CodonFM'...
remote: Enumerating objects: 151, done.
remote: Counting objects: 100% (151/151), done.
remote: Compressing objects: 100% (132/132), done.
remote: Total 151 (delta 26), reused 137 (delta 14), pack-reused 0 (from 0)
Receiving objects: 100% (151/151), 218.22 KiB | 16.79 MiB/s, done.
Resolving deltas: 100% (26/26), done.
%cd CodonFM
Output:
/content/CodonFM
!uv pip install -r requirements.txt
Output:
Using Python 3.12.12 environment at: /usr
Resolved 162 packages in 1.32s
Prepared 40 packages in 2.06s
Uninstalled 1 package in 61ms
Installed 40 packages in 51ms
+ anndata==0.12.3
+ array-api-compat==1.12.0
+ autopage==0.5.2
+ biopython==1.85
+ cfgv==3.4.0
+ cliff==4.11.0
+ cmaes==0.12.0
+ cmd2==2.7.0
+ colorlog==6.10.1
+ crc32c==2.8
+ distlib==0.4.0
+ donfig==0.8.1.post1
+ hydra-colorlog==1.2.0
+ hydra-core==1.3.2
+ hydra-optuna-sweeper==1.2.0
+ identify==2.6.15
+ ipython-autotime==0.3.2
+ jedi==0.19.2
+ legacy-api-wrap==1.4.1
+ lightning==2.5.2
+ lightning-utilities==0.15.2
+ lru-dict==1.3.0
+ ninja==1.11.1.1
+ nodeenv==1.9.1
+ numcodecs==0.16.3
+ optuna==2.10.1
+ pre-commit==4.0.1
+ pyfaidx==0.8.1.4
+ pytorch-lightning==2.5.5
+ rich-argparse==1.7.1
+ rootutils==1.0.7
+ scanpy==1.11.3
- scipy==1.16.2
+ scipy==1.15.3
+ session-info2==0.2.3
+ sh==2.2.2
+ stevedore==5.5.0
+ torchmetrics==1.8.2
+ virtualenv==20.35.3
+ xformers==0.0.32.post2
+ zarr==3.1.3
We pin transformers and tokenizers to known-good versions in a separate `--no-deps` command so the install cannot modify other dependencies or introduce unexpected behavior.
!uv pip install --system --no-deps transformers==4.54.1 tokenizers==0.21
Output:
Using Python 3.12.12 environment at: /usr
Resolved 2 packages in 3ms
Prepared 2 packages in 462ms
Uninstalled 2 packages in 679ms
Installed 2 packages in 46ms
- tokenizers==0.22.1
+ tokenizers==0.21.0
- transformers==4.57.1
+ transformers==4.54.1
!uv pip check
Output:
Using Python 3.12.12 environment at: /usr
Checked 712 packages in 88ms
All installed packages are compatible
Step 2: Import libraries and Encodon modules
Next, import the necessary libraries and the Encodon modules. We also check that PyTorch and CUDA are working.
import os
import sys
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
# ML libraries
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split, GridSearchCV
# Visualization
import matplotlib.pyplot as plt
# Add project paths
sys.path.append('..')
# Import Encodon modules
from src.inference.encodon import EncodonInference
from src.inference.task_types import TaskTypes
from src.data.metadata import MetadataFields
# Import additional modules for dataset handling
from src.data.codon_bert_dataset import CodonBertDataset
from src.data.preprocess.codon_sequence import process_item
from torch.utils.data import DataLoader
# Fix random seed
torch.manual_seed(42)
np.random.seed(42)
print("✅ Libraries imported successfully!")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if "+cu126" not in torch.__version__:
print(f"⚠️ Warning: Torch build ({torch.__version__}) is not cu126, indicating the Colab environment has changed"
" since this was last tested in October 2025. Functionality might be affected.")Output:
✅ Libraries imported successfully!
PyTorch: 2.8.0+cu126
CUDA available: True
Step 3: Download mRFP dataset
# Download mRFP Expression dataset if it doesn't exist
# NOTE: This assumes the notebook was launched from the CodonFM source directory;
# otherwise, change the path in the `subprocess` call so it points to data_scripts correctly.
import subprocess
data_path = "mRFP_Expression.csv"
root_path = "."
if not os.path.exists(data_path):
    print("📥 Downloading mRFP Expression dataset...")
    try:
        subprocess.run([
            "python", "data_scripts/download_preprocess_codonbert_bench.py",
            "--dataset", "mRFP_Expression.csv",
            "--output-dir", root_path
        ], check=True)
        print("✅ Dataset downloaded and preprocessed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error downloading dataset: {e}")
        print("Please ensure the data_scripts are available and run manually if needed.")
else:
    print("✅ Dataset already exists!")
Output:
📥 Downloading mRFP Expression dataset...
✅ Dataset downloaded and preprocessed successfully!
Step 4: Load pretrained Encodon model
!git clone https://huggingface.co/nvidia/NV-CodonFM-Encodon-1B-v1
# Define checkpoint paths
checkpoint_paths = [
"NV-CodonFM-Encodon-1B-v1/NV-CodonFM-Encodon-1B-v1.safetensors",
]
checkpoint_path = checkpoint_paths[0]
model_loaded = False
if os.path.exists(checkpoint_path):
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Create EncodonInference wrapper
        encodon_model = EncodonInference(
            model_path=checkpoint_path,
            task_type=TaskTypes.EMBEDDING_PREDICTION,
        )
        # Configure model
        encodon_model.configure_model()
        encodon_model.to(device)
        encodon_model.eval()
        print(f"✅ Model loaded from: {checkpoint_path}")
        print(f"Device: {device}")
        print(f"Parameters: {sum(p.numel() for p in encodon_model.model.parameters()):,}")
        model_loaded = True
    except Exception as e:
        print(f"Failed to load {checkpoint_path}: {e}")
if not model_loaded:
    print("❌ Could not load any model. Please check checkpoint paths.")
Output:
✅ Model loaded from: NV-CodonFM-Encodon-80M-v1/NV-CodonFM-Encodon-80M-v1.ckpt
Device: cuda
Parameters: 76,833,861
Step 5: Load dataset
# Load mRFP Expression dataset
data_loaded = False
if os.path.exists(data_path):
    try:
        data = pd.read_csv(data_path)
        print(f"✅ Loaded {len(data)} samples from: {data_path}")
        print(f"Columns: {list(data.columns)}")
        if 'split' in data.columns:
            print(f"Data splits: {data['split'].value_counts().to_dict()}")
        print(f"Target range: [{data['value'].min():.3f}, {data['value'].max():.3f}]")
        data_loaded = True
    except Exception as e:
        print(f"Failed to load {data_path}: {e}")
if not data_loaded:
    print("❌ Could not load mRFP data")
Output:
✅ Loaded 1459 samples from: mRFP_Expression.csv
Columns: ['id', 'ref_seq', 'value', 'dataset', 'split']
Data splits: {'train': 1021, 'val': 219, 'test': 219}
Target range: [7.361, 11.379]
Step 6: Preprocess data
This step takes an estimated 30 minutes using a T4 GPU.
batch_size = 16
if data_loaded and model_loaded:
    print("=== DATA PREPROCESSING ===")
    # Create dataset
    dataset = CodonBertDataset(
        data_path=data_path,
        tokenizer=encodon_model.tokenizer,
        process_item=lambda seq, tokenizer: process_item(
            seq,
            context_length=encodon_model.model.hparams.max_position_embeddings,
            tokenizer=tokenizer
        )
    )
    print(f"Processing {len(dataset)} sequences")
    print(f"Target range: [{dataset.data['value'].min():.3f}, {dataset.data['value'].max():.3f}]")
    # Create data loader for batch processing
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Extract embeddings using the dataset
    print("\nExtracting embeddings...")
    all_embeddings = []
    all_labels = []
    for batch in tqdm(dataloader):
        batch_input = {
            MetadataFields.INPUT_IDS: batch[MetadataFields.INPUT_IDS].to(encodon_model.device),
            MetadataFields.ATTENTION_MASK: batch[MetadataFields.ATTENTION_MASK].to(encodon_model.device),
        }
        # Extract embeddings
        output = encodon_model.extract_embeddings(batch_input)
        all_embeddings.append(output.embeddings)
        all_labels.append(batch[MetadataFields.LABELS].numpy())
    # Combine all embeddings and labels
    embeddings = np.vstack(all_embeddings)
    targets = np.concatenate(all_labels)
    print(f"\n✅ Extracted embeddings: {embeddings.shape}")
else:
    print("❌ Skipping preprocessing")
Output:
=== DATA PREPROCESSING ===
Processing 1459 sequences
Target range: [7.361, 11.379]
Extracting embeddings...
100%|██████████| 92/92 [01:24<00:00, 1.09it/s]
✅ Extracted embeddings: (1459, 1024)
Step 7: Train random forest
We use the model's zero-shot sequence embeddings to train a random forest that predicts mRFP expression. We then calculate metrics on the training, validation, and test sets so we can plot them in the following step.
This step takes around 50 minutes using a T4 GPU.
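The cell below builds its fixed train/validation split for GridSearchCV by hand from index lists; scikit-learn's `PredefinedSplit` expresses the same idea more compactly. A toy sketch (random stand-in data with hypothetical shapes, not the notebook's embeddings):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# Stand-in data: 20 "train" rows followed by 10 "validation" rows
rng = np.random.default_rng(0)
X = rng.normal(size=(30, 8))
y = rng.normal(size=30)

# -1 marks rows that are always in training; 0 marks the validation fold
test_fold = np.r_[np.full(20, -1), np.zeros(10, dtype=int)]
cv = PredefinedSplit(test_fold)  # yields exactly one train/val split

grid = GridSearchCV(
    RandomForestRegressor(random_state=42),
    param_grid={"n_estimators": [10], "max_depth": [3]},
    cv=cv,
    scoring="r2",
)
grid.fit(X, y)
print(grid.best_params_)
```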
if 'embeddings' in locals():
    print("=== TRAINING RANDOM FOREST ===")
    # Split data based on the dataset splits
    train_mask = dataset.data['split'] == 'train'
    val_mask = dataset.data['split'] == 'val'
    test_mask = dataset.data['split'] == 'test'
    X_train = embeddings[train_mask]
    X_val = embeddings[val_mask]
    X_test = embeddings[test_mask]
    y_train = targets[train_mask]
    y_val = targets[val_mask]
    y_test = targets[test_mask]
    print(f"Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}")
    # Combine train and validation for GridSearchCV
    X_train_val = np.vstack([X_train, X_val])
    y_train_val = np.concatenate([y_train, y_val])
    # Create validation indices for GridSearchCV
    # Train indices: 0 to len(X_train)-1
    # Val indices: len(X_train) to len(X_train_val)-1
    train_indices = list(range(len(X_train)))
    val_indices = list(range(len(X_train), len(X_train_val)))
    cv_splits = [(train_indices, val_indices)]
    # Define hyperparameter grid
    param_grid = {
        'n_estimators': [1000],
        'max_depth': [10],
        'min_samples_split': [25],
        'min_samples_leaf': [2],
    }
    # Create base model
    rf_base = RandomForestRegressor(random_state=42, n_jobs=-1)
    # Grid search with validation split
    print("Performing hyperparameter tuning...")
    grid_search = GridSearchCV(
        estimator=rf_base,
        param_grid=param_grid,
        cv=cv_splits,
        scoring='r2',
        n_jobs=-1,
        verbose=1
    )
    # Fit grid search
    grid_search.fit(X_train_val, y_train_val)
    # Get best model
    rf = grid_search.best_estimator_
    print(f"\n=== BEST PARAMETERS ===")
    for param, value in grid_search.best_params_.items():
        print(f"{param}: {value}")
    print(f"Best validation R²: {grid_search.best_score_:.4f}")
    # Train final model on train set only
    rf.fit(X_train, y_train)
    # Predictions on all splits
    y_pred_train = rf.predict(X_train)
    y_pred_val = rf.predict(X_val)
    y_pred_test = rf.predict(X_test)
    # Calculate metrics for all splits
    train_r2 = r2_score(y_train, y_pred_train)
    val_r2 = r2_score(y_val, y_pred_val)
    test_r2 = r2_score(y_test, y_pred_test)
    train_spearmanr, _ = spearmanr(y_train, y_pred_train)
    val_spearmanr, _ = spearmanr(y_val, y_pred_val)
    test_spearmanr, _ = spearmanr(y_test, y_pred_test)
    print(f"\n=== FINAL RESULTS ===")
    print(f"Train R²: {train_r2:.4f} | Spearman r: {train_spearmanr:.4f}")
    print(f"Val R²: {val_r2:.4f} | Spearman r: {val_spearmanr:.4f}")
    print(f"Test R²: {test_r2:.4f} | Spearman r: {test_spearmanr:.4f}")
else:
    print("❌ Cannot train - missing data")
Output:
=== TRAINING RANDOM FOREST ===
Train: 1021, Val: 219, Test: 219
Performing hyperparameter tuning...
Fitting 1 folds for each of 1 candidates, totalling 1 fits
=== BEST PARAMETERS ===
max_depth: 10
min_samples_leaf: 2
min_samples_split: 25
n_estimators: 1000
Best validation R²: 0.2756
=== FINAL RESULTS ===
Train R²: 0.7091 | Spearman r: 0.8698
Val R²: 0.2756 | Spearman r: 0.5400
Test R²: 0.3018 | Spearman r: 0.6521
Step 8: Plot results
if 'y_test' in locals():
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('mRFP Expression Prediction Results', fontsize=16)
    # Predicted vs True for all splits
    splits = [('Train', y_train, y_pred_train, train_r2),
              ('Validation', y_val, y_pred_val, val_r2),
              ('Test', y_test, y_pred_test, test_r2)]
    for i, (split_name, y_true, y_pred, r2) in enumerate(splits):
        axes[0, i].scatter(y_true, y_pred, alpha=0.6)
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        axes[0, i].plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2)
        axes[0, i].set_xlabel('True mRFP Expression')
        axes[0, i].set_ylabel('Predicted mRFP Expression')
        axes[0, i].set_title(f'{split_name}\nR² = {r2:.3f}')
        axes[0, i].grid(True, alpha=0.3)
    # Performance comparison
    r2_scores = [train_r2, val_r2, test_r2]
    spearmanr_scores = [train_spearmanr, val_spearmanr, test_spearmanr]
    x_pos = np.arange(len(splits))
    width = 0.35
    axes[1, 0].bar(x_pos - width/2, r2_scores, width, label='R²', alpha=0.7)
    axes[1, 0].bar(x_pos + width/2, spearmanr_scores, width, label='Spearman r', alpha=0.7)
    axes[1, 0].set_xlabel('Dataset Split')
    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_title('Performance Comparison')
    axes[1, 0].set_xticks(x_pos)
    axes[1, 0].set_xticklabels(['Train', 'Val', 'Test'])
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    # Target distribution across splits
    axes[1, 1].hist([y_train, y_val, y_test], bins=15, alpha=0.7,
                    label=['Train', 'Val', 'Test'], edgecolor='black')
    axes[1, 1].set_xlabel('mRFP Expression')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_title('Target Distribution by Split')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    # Feature importance
    top_features = np.argsort(rf.feature_importances_)[-10:]
    axes[1, 2].barh(range(10), rf.feature_importances_[top_features])
    axes[1, 2].set_xlabel('Importance')
    axes[1, 2].set_title('Top 10 Feature Importances')
    axes[1, 2].grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("❌ No results to plot")
Model Outputs
- splits: predicted vs true mRFP expression plots for the training, validation, and test splits
- <SET_NAME>_r2 and <SET_NAME>_spearmanr: R² and Spearman r correlation for all dataset splits
- top_features: feature importances
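To reuse these outputs outside the notebook (for example, to score new sequences without retraining), the fitted regressor can be persisted with joblib. A sketch with a stand-in model; in practice you would dump the `rf` fitted in Step 7, and the filename here is illustrative:

```python
import numpy as np
import joblib
from sklearn.ensemble import RandomForestRegressor

# Stand-in model; replace with the `rf` fitted in Step 7
rng = np.random.default_rng(42)
X = rng.normal(size=(20, 4))
y = rng.normal(size=20)
rf = RandomForestRegressor(n_estimators=10, random_state=42).fit(X, y)

# Persist the fitted regressor, then reload and verify predictions match
joblib.dump(rf, "rf_mrfp.joblib")  # filename is illustrative
rf_restored = joblib.load("rf_mrfp.joblib")
assert np.allclose(rf.predict(X), rf_restored.predict(X))
```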
Troubleshooting & Optimization Tips
Common Issues and Solutions:
1. Model Loading Issues
- Problem: Checkpoint not found
- Solution: Update checkpoint paths in Step 4
- Check: Verify checkpoint files exist and are accessible
2. Data Loading Issues
- Problem: Dataset not found
- Solution: Update data paths in Step 5
- Check: Ensure CSV files have required columns (id, ref_seq, value)
3. Memory Issues
- Problem: CUDA out of memory
- Solution: Reduce batch_size in data preprocessing section
- Alternative: Use CPU by setting device='cpu'
4. Performance Issues
- Problem: Low R² scores
- Solutions:
- Try larger models (600M or 1B parameters)
- Implement fine-tuning instead of just embeddings
- Tune Random Forest hyperparameters
- Check data quality and preprocessing
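For the CUDA out-of-memory case above, a common pattern (a sketch, not part of the CodonFM API) is to halve the batch size and retry; `torch.cuda.OutOfMemoryError` is the exception PyTorch raises:

```python
import torch

def run_with_batch_fallback(fn, batch_size, min_batch_size=1):
    """Call fn(batch_size), halving the batch on CUDA OOM until it fits."""
    while batch_size >= min_batch_size:
        try:
            return fn(batch_size)
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release cached blocks before retrying
            batch_size //= 2
    raise RuntimeError("could not fit even the minimum batch size")

# Stand-in extraction function that only "fits" at batch size <= 4
def fake_extract(bs):
    if bs > 4:
        raise torch.cuda.OutOfMemoryError("simulated OOM")
    return bs

print(run_with_batch_fallback(fake_extract, 16))  # 4
```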
Optimization Strategies:
1. Model Architecture
- 80M model: Fast, good for initial experiments
- 600M model: Better performance, moderate cost
- 1B model: Best performance, highest computational cost
2. Hyperparameter Tuning
# Try these Random Forest parameters:
rf_params = {
    'n_estimators': [100, 200, 500],
    'max_depth': [10, 15, 20, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}
Contact and Acknowledgments
For issues with this quickstart please contact Timur Rvachov at trvachov@nvidia.com.
We thank the NVIDIA Corporation team for developing this suite of models.
Responsible Use
We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.