Rbio Demo: LLM Post-Training using soft verification
Estimated Time to Complete: 20 minutes with an A100 GPU.
Google Colab Note: It is strongly recommended to run this notebook with an A100 GPU, which is only included with Google Colab Pro or Enterprise paid services. Alternatively, a "pay as you go" option is available to purchase premium GPUs. See Colab Service Plans for details.
Learning Goals
This notebook demonstrates how to post-train a Large Language Model (LLM) using Rbio, a method that distills knowledge from a simplified biology model.
We will use perturbation data from the PertQA dataset to improve an LLM's ability to answer questions about gene expression. Instead of relying on scarce and costly "hard" ground truth experimental labels, we'll use a Multi-Layer Perceptron (MLP) as a "soft verifier" to generate a reward signal for training the LLM.
What you'll learn
- Understand Rbio's approach for LLM post-training using soft-verification from a simplified Virtual Cell Model (VCM).
- Use an MLP to generate probabilistic "soft labels" for gene perturbation effects.
- Implement a reward mechanism that uses the MLP's signal to guide LLM training.
- Fine-tune an LLM to improve its ability to answer questions about differential gene expression.
Pre-requisites
- GPU Access: A GPU (e.g., T4, A100) is required for model training. This notebook uses one of the open-source Qwen series of models as the base LLM. The ability to perform this post-training procedure will depend on your available GPU resources; for prototyping on smaller GPUs (e.g., T4), you can use smaller models in the series.
- Python Libraries: The following libraries are required:
numpy
pandas
torch
datasets
transformers
trl
accelerate
tiktoken
Introduction
Approach
In many scientific domains, like biology, verifying predictions often requires slow, expensive, and unscalable lab experiments. This work introduces Rbio, a reasoning model for biology that is post-trained from a Large Language Model (LLM) using a novel "soft verification" mechanism. Instead of relying on new experimental data, Rbio uses existing biological models as approximate oracles to simulate biological knowledge, providing a reward signal for training. This approach demonstrates that predictions from bio-models can be used to train powerful reasoning systems, using simulations rather than new experiments as a training paradigm.
The goal is to distill the knowledge from complex biological foundation models into a more flexible, conversational LLM. This makes the underlying science more accessible and allows users to explore hypotheses and ask questions in natural language.
Demo Model
Below, we use a simplified "VCM" consisting of a Multi-Layer Perceptron (MLP) trained to answer perturbation questions. It exposes an interface that, when prompted with two gene names, returns the probability that a knockout of gene_a has an effect on the expression of gene_b.
We use this signal as a soft verification signal within our reward mechanism to post-train our LLM. This improves the LLM's ability to answer questions of the form "Is a knockdown of gene_a in cell_line cells likely to result in differential expression of gene_b?"
Example Dataset
The training data for this notebook is derived from the PerturbQA benchmark. The dataset contains single-gene perturbation knockout data on cancer cell lines, which we use to fine-tune the LLM's ability to answer questions about differential gene expression.
Notes
- This code is a different implementation from the one used to train the models discussed in our paper "Rbio: ...."
- If you are interested only in using the perturbation data employed in this demo, please refer to the original repository https://github.com/genentech/PerturbQA and cite the work from our colleagues at Genentech accordingly.
Setup
Setup Google Colab
To run this quickstart using Google Colab, you will need to choose the 'A100' GPU runtime from the "Connect" dropdown menu in the upper-right corner of this notebook. Note that this runtime configuration is not available in the free Colab version. To access premium GPUs, you will need to purchase additional compute units. The current quickstart was tested in Colab Enterprise using the following runtime configuration:
- Machine type: a2-highgpu-1g
- GPU type: NVIDIA_TESLA_A100 x 1
- Data disk type: 100 GB Standard Disk (pd-standard)
It is also possible to run this quickstart on a free-tier T4 instance by using a smaller base LLM.
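For example, you could later point the MODEL_NAME variable (defined in the configuration cell below) at a smaller member of the Qwen series; the specific model shown here is only an illustration:
# Optional: a smaller base model for limited GPU memory (e.g., a free-tier T4)
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"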
Download necessary assets
To run this tutorial, we have made the following implementations and assets available:
- Reward Functions
- Utility Functions
- MLP Model Checkpoint
- ESM Embeddings
- A dataset from PerturbQA
#!pip install gdown  # install gdown if not already available; enables download of demo assets from Google Drive
Download Example Model Checkpoint and Embedding file
# Download embeddings
!gdown "https://drive.google.com/uc?id=1hBhnanbna2t5TBZlhPr6Y8sPZbOX_BV9"
# Download demo model checkpoint
!gdown "https://drive.google.com/uc?id=1o6FvLUGlFz8f-vYrhc1shAYVxtGNQEkh"
Download Example Data
!gdown "https://drive.google.com/uc?id=16WR4a4bdqiWXd72HToFvAJ66e1jxsp-k"
Download Utility Scripts and reward functions
# Utils
!gdown "https://drive.google.com/uc?id=138cIxmAPB8__0zBQsslHgKvaYIRQSlsK"
# Reward Functions
!gdown "https://drive.google.com/uc?id=18NY_yfyo0R_TQ8DUG_c2NCkzi_nwYT9h"
Set up Directory Structure
!mkdir checkpoints
!ls
Output:
checkpoints mlp_model.pt utils.py
esm_embedding_dictionary_filled.pkl rewards.py
k562-train-v0.3.0.csv sample_data
Install dependencies
# Disable WandB
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "true"
!pip install uv
!uv pip install --system --quiet numpy pandas==2.2.3 torch==2.6.0 datasets==3.5.0 transformers==4.51.3 trl==0.16.1 accelerate tiktoken protobuf==3.20.3
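As a quick, optional sanity check before continuing, you can confirm that a GPU is visible to PyTorch (the device name reported will depend on your runtime):
import torch
# Confirm that a CUDA device is available before starting training
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))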
Imports, global variables, random seeds
import warnings
# Filter for the Qwen caching message
warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*")
# Filter for the WANDB_DISABLED deprecation warning
warnings.filterwarnings("ignore", message=".*Using the `WANDB_DISABLED` environment variable is deprecated.*")
from typing import List
import numpy as np
import pandas as pd
import torch
from torch import nn
from datasets import Dataset
from trl import GRPOTrainer
from rewards import *
from utils import (
set_random_seeds,
load_mlp_classifier,
setup_model_and_tokenizer,
create_training_config,
mlp_classifier_inference
)
# Training configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
N_STEPS = 100 #100000
BATCH_SIZE = 2
NUM_GENERATIONS = 2
SAVE_EVERY = 50 #10000
OUTPUT_DIR = "./checkpoints"
# Global step counter
STEP_COUNT = 0
# MLP classifier configuration (mlp was not trained on k562 cells)
MLP_MODEL_PATH = "./mlp_model.pt" #"./MLP-NO-k562-esm_emb/mlp_model.pt"
EMBEDDING_FILE = "./esm_embedding_dictionary_filled.pkl" #"./MLP-NO-k562-esm_emb/esm_embedding_dictionary_filled.pkl"
# Dataset paths
DATASET_PATHS = [
"./k562-train-v0.3.0.csv",
]
# Set seeds globally
set_random_seeds(42)
Simplified Virtual Cell Model (VCM)
Our simplified Virtual Cell Model (VCM) is a Multi-Layer Perceptron (MLP) Classifier. Its purpose is to act as a "soft labeler" for our reward strategy during the LLM's training.
Instead of providing a hard "yes" or "no" label, this model takes a pair of genes and returns a continuous probability (e.g., 0.75) that perturbing the first gene will affect the expression of the second. This probabilistic or "soft" label is a more nuanced signal that we can use to reward the LLM.
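As a concrete illustration, the snippet below turns a single VCM probability into the binarized "hard" label and the class-confidence string used later in this notebook; the probability value is made up, and the logic mirrors the dataset generator defined further down:
# Illustrative only: convert one MLP probability into the "hard" label and the
# class-confidence string that the reward function consumes
mlp_probability = 0.75
predicted_label = 1 if mlp_probability > 0.5 else 0
class_confidences = f"{1.0 - mlp_probability:.3f}|{mlp_probability:.3f}"
print(predicted_label, class_confidences)  # prints: 1 0.250|0.750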
Model Architecture
The MLP is designed specifically for this gene-pair task. Its structure is simple yet effective:
- Input Layer: The model takes the embeddings for two genes (a perturbed gene and a monitored gene). These two vectors are concatenated, which is why the input layer's size is input_dim * 2.
- Hidden Layer: A standard linear layer followed by a ReLU activation function introduces non-linearity, allowing the model to learn more complex relationships between the gene embeddings.
- Output Layer: The final linear layer collapses the features into a single raw output value, known as a logit. This logit is later passed through a sigmoid function (see utils.py) to generate the final probability score.
class MLPClassifier(nn.Module):
"""Simple MLP classifier for gene pair classification"""
def __init__(self, input_dim: int, hidden_dim: int = 64):
super().__init__()
self.model = nn.Sequential(
nn.Linear(input_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
result = self.model(x)
return result
# Global variables for MLP model and embeddings
mlp_model = None
embeddings_dict = None
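To make the architecture concrete, here is a standalone forward pass through the classifier using random stand-in vectors. The embedding dimension is an assumption for illustration; the real gene embeddings come from the ESM embedding dictionary loaded via utils.py.
# Illustrative only: a forward pass with random stand-in "embeddings" to show
# how a gene-pair probability is produced (320 is an assumed dimension)
emb_dim = 320
demo_clf = MLPClassifier(input_dim=emb_dim)
gene_a_emb = torch.randn(emb_dim)   # stand-in for the perturbed gene's embedding
gene_b_emb = torch.randn(emb_dim)   # stand-in for the monitored gene's embedding
pair = torch.cat([gene_a_emb, gene_b_emb]).unsqueeze(0)  # shape (1, emb_dim * 2)
with torch.no_grad():
    probability = torch.sigmoid(demo_clf(pair)).item()
print(f"P(knockout of gene_a affects gene_b) = {probability:.3f}")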
Dataset
Our training process is handled by two main functions:
load_and_prepare_dataset
This is a straightforward helper function that loads the raw training data from CSV files into a pandas DataFrame. This DataFrame is then used by the data generator.
create_mlp_labeled_dataset_generator
This function is the core of our data pipeline. It acts as a generator, yielding one fully-prepared sample at a time directly to the trainer. This on-the-fly process is memory-efficient and performs the following steps for each sample:
- Samples a Row: It picks a row from the initial DataFrame.
- Gets a Soft Label: It calls our mlp_classifier_inference function to get a precise probability (e.g., 0.81) from the VCM, which serves as the "soft label."
- Constructs the Training Example: It builds a dictionary containing everything the trainer will need, including a binarized "hard" label and the class_confidences string (e.g., "0.190|0.810") with the raw probabilities for the reward calculation.
- Formats the Prompt: It uses the tokenizer's apply_chat_template method to format the system and user prompts into the exact string the LLM expects.
- Yields the Result: It provides the final, complete training example to the GRPOTrainer.
def load_and_prepare_dataset(dataset_paths: List[str], balance_pos_neg: bool = True) -> pd.DataFrame:
"""Load CSV datasets and combine them into a single DataFrame"""
if len(dataset_paths) == 1:
dataset_df = pd.read_csv(dataset_paths[0])
else:
dataset_list = []
for path in dataset_paths:
dataset_list.append(pd.read_csv(path))
dataset_df = pd.concat(dataset_list, ignore_index=True)
print(f"Loaded dataset with {len(dataset_df)} rows")
return dataset_df
def create_mlp_labeled_dataset_generator(dataset_df: pd.DataFrame, tokenizer, balance_pos_neg: bool = True):
"""Generate training examples with MLP-based labeling"""
if balance_pos_neg:
# Use 2x the dataset length to ensure enough samples for training
dataset_length = len(dataset_df) * 2
else:
dataset_length = len(dataset_df)
for i in range(dataset_length):
# Sample from dataset (with replacement for longer training)
sample_idx = i % len(dataset_df)
row = dataset_df.iloc[sample_idx]
# Prepare sample data for MLP classification
sample_data = {
"system_prompt": row["system_prompt"],
"user_prompt": row["user_prompt"],
"keywords": row["keywords"]
}
# Get MLP prediction
mlp_probability = mlp_classifier_inference(sample_data)
# Determine label based on MLP probability
predicted_label = 1 if mlp_probability > 0.5 else 0
# Prepare sample with MLP-generated label
sample = {
"system_prompt": row["system_prompt"],
"user_prompt": row["user_prompt"],
"label": predicted_label,
"classes": "no|yes",
"class_confidences": f"{1.0-mlp_probability:.3f}|{mlp_probability:.3f}",
"keywords": row["keywords"],
"task": row["task"],
"mlp_probability": mlp_probability
}
# Format messages for chat template
messages = [
{"role": "system", "content": sample["system_prompt"]},
{"role": "user", "content": sample["user_prompt"]},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
yield {
"prompt": prompt,
"label": sample["label"],
"classes": sample["classes"],
"class_confidences": sample["class_confidences"],
"keywords": sample["keywords"],
"task": sample["task"],
"system_prompt": sample["system_prompt"],
"user_prompt": sample["user_prompt"],
}
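Once dataset_df, the tokenizer, and the MLP classifier have been loaded (this happens in the Training section below), you can sanity-check the pipeline by pulling a single example from the generator. The sketch is left commented out so the notebook still runs top to bottom:
# Sketch: inspect one generated training example. Uncomment and run only after
# dataset_df, tokenizer, and the MLP classifier have been loaded (see Training).
# example = next(create_mlp_labeled_dataset_generator(dataset_df, tokenizer))
# print(example["prompt"])
# print(example["label"], example["class_confidences"])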
Rewards Strategy
The core of this training approach is a multi-objective reward system. We don't just want the model to give the right answer; we also want it to follow a structured, interpretable reasoning process. To achieve this, the total reward given to the LLM for any given output is a weighted combination of three distinct components.
The main function, compute_simple_reward, orchestrates this by calling three specialized sub-functions for each generated text.
1. Answer Reward
Handled by reward_answer_against_label, this is the most important component. It determines the correctness of the final answer.
- How it works: This function first extracts the binary "yes" or "no" from the LLM's <answer> tag. Then, instead of giving a simple +1 for a correct answer, it gives a soft reward corresponding to the VCM's confidence score for that answer. For example, if the VCM was 85% sure the answer was "yes" and the LLM says "yes", the reward is 0.85. This provides a much more nuanced signal.
2. Formatting Reward
Handled by composite_formatting_reward, this component ensures the LLM's output follows the strict chain-of-thought structure we expect (e.g., <think>...</think><answer>...</answer>).
- How it works: This function runs a series of checks on the generated text, such as starts_with_think, ends_with_answer, all_tags_properly_closed, and no_nested_tags. The final score is the fraction of checks the output passes, encouraging clean, predictable formatting.
3. Mention Reward
Handled by keywords_mentioned_in_think, this component incentivizes the model to reason about the correct entities.
- How it works: It checks whether the specific keywords from the prompt (i.e., the gene names) are mentioned within the <think> block. This discourages the model from hallucinating or providing generic reasoning that isn't grounded in the specific question.
Combining the Rewards
The compute_simple_reward function brings everything together. For each completion, it calculates the three reward components and combines them into a final score. Crucially, it applies a specific weighting:
total_score = format_reward + 2.0 * answer_reward + mention_reward
By doubling the weight of the answer reward, we signal to the model that while formatting and reasoning are important, getting the biologically correct answer (as determined by our VCM) is the primary goal.
def reward_answer_against_label(completion: str, classes: str, class_confidence: str) -> float:
"""Compute reward based on whether answer matches expected label"""
answer = extract_binary_answer(completion)
if answer is None:
return 0.0
answer = "yes" if answer else "no"
possible_classes = classes.split("|")
confidences = [float(c) for c in class_confidence.split("|")]
for label, conf in zip(possible_classes, confidences):
if answer == label.strip().lower():
return conf
return 0.0
def composite_formatting_reward(text: str, use_go: bool = False) -> float:
"""Compute composite formatting reward based on multiple checks"""
at_least_one_think = has_at_least_one_think(text)
has_tags = has_any_tag(text)
checks = [
at_least_one_think,
low_untagged_ratio(text),
is_not_too_long(text),
has_one_answer(text),
answer_after_thinks(text),
thinks_have_text(text) * at_least_one_think,
no_nested_tags(text) * has_tags,
has_limited_thinks(text) * at_least_one_think,
all_tags_properly_closed(text) * has_tags,
ends_with_answer(text),
starts_with_think(text),
]
# Remove start_with_think dependency if using GO ontology
if use_go:
checks = checks[:-1]
return sum(checks) / len(checks)
def keywords_mentioned_in_think(text: str, keywords: str) -> float:
"""Check how many keywords are mentioned in think sections"""
keyword_list = [k for k in keywords.split("|") if k]
if not keyword_list:
return 1.0
think_contents = extract_think(text)
if not think_contents:
return 0.0
found_keywords = 0
for keyword in keyword_list:
if keyword in think_contents:
found_keywords += 1
return found_keywords / len(keyword_list)
def compute_simple_reward(
completions: List[str],
label: List[int],
classes: List[str],
class_confidences: List[str],
keywords: List[str],
**kwargs
) -> List[float]:
"""Compute rewards for model completions using format, mention, and answer rewards"""
scores = []
global STEP_COUNT
for completion, lbl, class_list, confidences, keyword_list in zip(
completions, label, classes, class_confidences, keywords
):
# Format reward: checks proper tag structure
format_reward = composite_formatting_reward(completion, use_go=False)
# Mention reward: checks if keywords are mentioned in think sections
mention_reward = keywords_mentioned_in_think(completion, keyword_list)
# Answer reward: checks if answer matches expected label
answer_reward = reward_answer_against_label(completion, class_list, confidences)
# Combine rewards (answer reward gets 2x weight as it's most important)
total_score = format_reward + 2.0 * answer_reward + mention_reward
scores.append(total_score)
# Debug prints every 100 steps to monitor model outputs
if STEP_COUNT % 100 == 0:
print("\n" + "="*80)
print(f"DEBUG: Sample {STEP_COUNT} - Step {STEP_COUNT}")
print("="*80)
# Print the completion to see what the model generated
print(f"MODEL OUTPUT:")
print(f"{completion}")
print()
# Print reward breakdown
print(f"REWARD BREAKDOWN:")
print(f" Format reward: {format_reward:.3f}")
print(f" Mention reward: {mention_reward:.3f}")
print(f" Answer reward: {answer_reward:.3f}")
print(f" Total score: {total_score:.3f}")
print()
# Print expected vs predicted
print(f"EXPECTED vs PREDICTED:")
print(f" VCM binarized label: {lbl}")
print(f" Possible classes: {class_list}")
print(f" Label VCM confidences: {confidences}")
print(f" Keywords: {keyword_list}")
print()
# Print reward details
print(f"REWARD DETAILS:")
print(f" Answer extraction: {extract_binary_answer(completion)}")
print(f" Think content: {extract_think(completion)[:100]}...")
print("="*80 + "\n")
STEP_COUNT += 1
return scores
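To see the reward arithmetic on a concrete case, you can score a hand-written completion with compute_simple_reward. This assumes the rewards.py helpers downloaded above are importable; the gene names, reasoning text, and confidence values are made up for illustration.
# Sketch: score one toy completion. With the 2x answer weighting, we expect
# roughly 1.0 (format) + 2 * 0.850 (answer) + 1.0 (keyword mentions) = 3.7
toy_completion = (
    "<think>Knocking out GENE_A is reported to regulate GENE_B, "
    "so differential expression is plausible.</think>"
    "<answer>yes</answer>"
)
toy_scores = compute_simple_reward(
    completions=[toy_completion],
    label=[1],
    classes=["no|yes"],
    class_confidences=["0.150|0.850"],
    keywords=["GENE_A|GENE_B"],
)
print(toy_scores)
# Reset the global debug counter so training below starts from step 0
STEP_COUNT = 0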
Training
This final section brings all our components together: the dataset, the Virtual Cell Model (VCM), the reward functions, and the base LLM. The following code cell configures and launches the training process using the Hugging Face TRL (Transformer Reinforcement Learning) library.
The process follows these key steps:
- Load Core Assets: The script begins by loading the three essential components into memory: the raw training data (dataset_df), the pre-trained MLP classifier that will act as our VCM (mlp_model), and the base Qwen2.5 LLM and its tokenizer (model, tokenizer).
- Create a Streaming Dataset: The create_mlp_labeled_dataset_generator function is wrapped in a Hugging Face Dataset.from_generator. This creates a dataset object that the trainer can pull from, so our "on-the-fly" MLP labeling is seamlessly integrated into the TRL ecosystem.
- Configure the Trainer: All hyperparameters for the training run, such as batch size, learning rate, and the number of steps, are defined in a GRPOConfig object. GRPO (Group Relative Policy Optimization) is the specific reinforcement learning algorithm we are using from TRL.
- Instantiate the GRPOTrainer: This is the final setup step, where all the pieces are assembled. We provide the trainer with:
  - The model we want to train.
  - Our compute_simple_reward function to score the model's generations.
  - The training_config with our hyperparameters.
  - The streaming train_dataset.
- Launch Training: The single command trainer.train() kicks off the entire process. Behind the scenes, the trainer begins a loop of generating text from the model, using our reward function to score the outputs, and updating the model's weights to maximize the reward. You will see the debug output from the reward function printed periodically as training progresses.
print("Starting RBIO training with streaming MLP labeling...")
# Load and prepare dataset
print("Loading dataset...")
dataset_df = load_and_prepare_dataset(DATASET_PATHS)
# Load MLP classifier
print("Loading MLP classifier...")
load_mlp_classifier(MLP_MODEL_PATH, EMBEDDING_FILE, MLPClassifier)
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(MODEL_NAME)
# Create streaming dataset generator
print("Creating streaming dataset generator...")
dataset = Dataset.from_generator(
create_mlp_labeled_dataset_generator,
gen_kwargs={
"dataset_df": dataset_df,
"tokenizer": tokenizer,
"balance_pos_neg": True,
},
)
# Create training configuration
print("Setting up training configuration...")
training_config = create_training_config(
output_dir=OUTPUT_DIR,
batch_size=BATCH_SIZE,
num_generations=NUM_GENERATIONS,
max_steps=N_STEPS,
save_every=SAVE_EVERY
)
# Create trainer
print("Creating GRPO trainer...")
trainer = GRPOTrainer(
model=model,
reward_funcs=compute_simple_reward,
args=training_config,
train_dataset=dataset,
)
# Start training
print(f"Starting training for {N_STEPS} steps...")
trainer.train()
print("Training completed!")
Contact & Feedback
For issues or feedback about this tutorial please contact virtualcellmodels@chanzuckerberg.com.
Responsible Use
We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.
Should you have any security or privacy issues or questions related to the services, please reach out to our team at security@chanzuckerberg.com or privacy@chanzuckerberg.com respectively.