scLDM.CD4 Quickstart Inference Tutorial
Learning Goals
Perform inference and generate scRNA-seq perturbation data using a pre-trained checkpoint.
Prerequisites
- Python >=3.11
- Compute resources: Inference with the released pre-trained checkpoint has been tested on NVIDIA A100, H100, and A6000 GPUs. CPU-only inference is not currently supported, so at least one GPU is required.
- General requirements: This tutorial assumes that you have already downloaded the pre-trained checkpoint and associated config file from Hugging Face, set up a virtual environment, and updated the paths in the YAML config files as described in the repository's README. To set up Python and a virtual environment, first clone the GitHub repository, then run ./init.sh (see steps 1–2 in the GitHub README).
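You can script a quick sanity check of the Python and GPU prerequisites. The sketch below uses a hypothetical helper, `check_prerequisites`, which is not part of the scLDM.CD4 repository. It treats the presence of the `nvidia-smi` utility on PATH as a rough proxy for a visible NVIDIA GPU. If PyTorch is installed in your environment, `torch.cuda.is_available()` gives a more direct check.

```python
import shutil
import sys


def check_prerequisites(min_python=(3, 11)):
    """Return a dict describing whether basic prerequisites appear to be met.

    Assumption (heuristic): an NVIDIA GPU is treated as "visible" if the
    `nvidia-smi` utility is on PATH. This does not prove that CUDA works.
    """
    python_ok = sys.version_info[:2] >= tuple(min_python)
    nvidia_smi_found = shutil.which("nvidia-smi") is not None
    return {
        "python_version": ".".join(map(str, sys.version_info[:3])),
        "python_ok": python_ok,
        "nvidia_smi_found": nvidia_smi_found,
    }


if __name__ == "__main__":
    for key, value in check_prerequisites().items():
        print(f"{key}: {value}")
```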
Setup
Installation
First, import the necessary modules.
import os
import sys
from pathlib import Path
from notebook_inference import NotebookInference, inference
# For displaying results
import anndata as ad
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
Define the required paths.
# Paths for inference
INFERENCE_CONFIG_PATH = "../experiments/config"
INFERENCE_CONFIG_NAME = "inference_fm"
CHECKPOINT_PATH = "../model/last.ckpt" # modify as necessary
OUTPUT_DIR = "../inference_outputs" # modify as necessary
# Paths for visualization
TEST_ADATA_PATH = "../data/test_hvg/adata_1_1k.h5ad" # modify as necessary
Run Model Inference
Running inference generates a new AnnData object (adata), which is also saved to the output directory.
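The paths above are relative, so a typo may only surface once inference starts. A short pre-flight check can catch this earlier. The `config_path`/`config_name`/`overrides` arguments suggest a Hydra-style setup, so this sketch assumes the config lives at `<config_path>/<config_name>.yaml`. `check_inference_paths` is a hypothetical helper, not part of the repository. It resolves relative paths against the current working directory, which may differ from how the inference code itself resolves `config_path`.

```python
from pathlib import Path


def check_inference_paths(config_dir, config_name, checkpoint_path, output_dir):
    """Return a list of human-readable problems with the inference paths.

    Assumption: the config is a YAML file named `<config_name>.yaml` inside
    `config_dir`; adjust if your setup differs. An empty list means no
    problems were found.
    """
    problems = []
    config_file = Path(config_dir) / f"{config_name}.yaml"
    if not config_file.is_file():
        problems.append(f"Config file not found: {config_file}")
    if not Path(checkpoint_path).is_file():
        problems.append(f"Checkpoint not found: {checkpoint_path}")
    out = Path(output_dir)
    # The output directory may not exist yet; only flag it if it is a file.
    if out.exists() and not out.is_dir():
        problems.append(f"Output path exists but is not a directory: {out}")
    return problems


# Usage in the notebook (after defining the paths above):
# for problem in check_inference_paths(INFERENCE_CONFIG_PATH, INFERENCE_CONFIG_NAME,
#                                      CHECKPOINT_PATH, OUTPUT_DIR):
#     print(problem)
```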
generated_adata = inference(
config_path=INFERENCE_CONFIG_PATH,
config_name=INFERENCE_CONFIG_NAME,
checkpoint_path=CHECKPOINT_PATH,
output_dir=OUTPUT_DIR,
dataset_generation_idx=0,
seed=42,
batch_size=32,
device="cuda",
overrides=["model.batch_size=32"]
)
print(f"Generated {generated_adata.n_obs} cells")
print(f"Features: {generated_adata.n_vars}")
Note: if you run out of GPU memory, try reducing batch_size, along with the matching model.batch_size value in overrides.
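If out-of-memory errors recur, you can automate the retry. The sketch below halves the batch size after each out-of-memory failure. `run_with_batch_backoff` is a hypothetical helper. It assumes an OOM error is raised as a `RuntimeError` whose message contains "out of memory". In recent PyTorch versions, `torch.cuda.OutOfMemoryError` subclasses `RuntimeError` and its message reads "CUDA out of memory".

```python
def run_with_batch_backoff(run_fn, batch_size, min_batch_size=1):
    """Call run_fn(batch_size), halving batch_size after out-of-memory errors.

    Returns (result, batch_size_used). Assumes an out-of-memory failure is a
    RuntimeError whose message contains "out of memory"; any other error, or
    an OOM already at min_batch_size, is re-raised unchanged.
    """
    while True:
        try:
            return run_fn(batch_size), batch_size
        except RuntimeError as exc:
            is_oom = "out of memory" in str(exc).lower()
            if not is_oom or batch_size <= min_batch_size:
                raise
            batch_size = max(min_batch_size, batch_size // 2)
            print(f"Out of GPU memory; retrying with batch_size={batch_size}")


# Hypothetical usage with the tutorial's inference() call (not executed here):
# generated_adata, used_batch_size = run_with_batch_backoff(
#     lambda bs: inference(
#         config_path=INFERENCE_CONFIG_PATH,
#         config_name=INFERENCE_CONFIG_NAME,
#         checkpoint_path=CHECKPOINT_PATH,
#         output_dir=OUTPUT_DIR,
#         dataset_generation_idx=0,
#         seed=42,
#         batch_size=bs,
#         device="cuda",
#         overrides=[f"model.batch_size={bs}"],
#     ),
#     batch_size=32,
# )
```

In practice, you may also want to call `torch.cuda.empty_cache()` between attempts so that memory held by the failed run can be released.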
Run time: generating 2k cells takes about 4 minutes on a single H100 GPU and about 17 minutes on a single A6000 GPU.
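To budget time for a different number of cells, you can scale the reported timings. From the numbers above, throughput is roughly 500 cells per minute on an H100 and roughly 118 cells per minute on an A6000. The sketch below assumes run time grows linearly with cell count. That ignores fixed costs such as loading the checkpoint, so treat its output as a rough estimate only.

```python
# Timings reported in this tutorial for generating 2,000 cells.
REFERENCE_CELLS = 2_000
REFERENCE_MINUTES = {"H100": 4.0, "A6000": 17.0}


def estimate_minutes(n_cells, gpu):
    """Rough run-time estimate, assuming time scales linearly with cell count."""
    if gpu not in REFERENCE_MINUTES:
        raise ValueError(f"No reference timing for GPU {gpu!r}")
    return REFERENCE_MINUTES[gpu] * n_cells / REFERENCE_CELLS


for gpu in REFERENCE_MINUTES:
    print(f"{gpu}: ~{estimate_minutes(10_000, gpu):.0f} min for 10k cells")
```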
Model Outputs and Visualization
Examine the generated adata.
generated_adata
Note that the generated adata contains two datasets, generated_unconditional and generated_conditional, labeled in the "dataset" column of .obs. When comparing generated data with true perturbation data, we want to consider only the generated_conditional cells, because their generation is conditioned on each cell's perturbation, donor, and timepoint labels.
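To confirm which groups are present before subsetting, you can count cells per label in the "dataset" column. The helper below, `summarize_datasets`, is hypothetical and not part of the repository. It works on any AnnData-style `.obs` DataFrame, such as `generated_adata.obs`. In the notebook, `generated_adata.obs["dataset"].value_counts()` gives the same information directly.

```python
import pandas as pd


def summarize_datasets(obs: pd.DataFrame, column: str = "dataset") -> dict:
    """Return {label: number_of_cells} for the given column of an .obs table."""
    counts = obs[column].value_counts()
    return {str(label): int(n) for label, n in counts.items()}


# Toy example; in the tutorial you would pass generated_adata.obs instead.
toy_obs = pd.DataFrame(
    {"dataset": ["generated_conditional"] * 3 + ["generated_unconditional"] * 2}
)
print(summarize_datasets(toy_obs))
```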
generated_cond_adata = generated_adata[generated_adata.obs["dataset"] == "generated_conditional"].copy()
Load test data to compare with generated data.
test_adata = sc.read_h5ad(TEST_ADATA_PATH)
Add a "dataset" column to the test data so we can compare it with the generated data.
test_adata.obs["dataset"] = "test"
Concatenate the two adatas so we can visualize the true and generated cells together.
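With join="outer", ad.concat keeps the union of genes from both objects. According to the anndata documentation for `concat`, features missing from one object are padded by default: with zeros for sparse matrices and with missing values for dense ones. A quick overlap check before concatenating shows whether that padding will occur. `compare_features` is a hypothetical helper. In the notebook, you would pass `generated_cond_adata.var_names` and `test_adata.var_names`.

```python
def compare_features(generated_vars, test_vars):
    """Count shared and unshared feature (gene) names between two datasets."""
    generated, test = set(generated_vars), set(test_vars)
    return {
        "shared": len(generated & test),
        "only_generated": len(generated - test),
        "only_test": len(test - generated),
    }


# Toy gene lists for illustration; in the tutorial, pass
# generated_cond_adata.var_names and test_adata.var_names.
print(compare_features(["CD4", "IL7R", "GZMB"], ["CD4", "IL7R", "FOXP3", "CCR7"]))
```

If both `only_*` counts are zero, the outer join introduces no padding.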
adata = ad.concat([generated_cond_adata, test_adata], join="outer")  # conditional cells only, as noted above
Run standard processing steps.
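The first two processing steps work as follows. `sc.pp.normalize_total(adata, target_sum=10_000)` rescales each cell so its values sum to 10,000. `sc.pp.log1p` then applies log(1 + x), using the natural log by default. The sketch below shows the per-cell arithmetic on a plain Python list. It is an illustration only, not scanpy's implementation, which operates on whole (possibly sparse) matrices. The helper name `normalize_and_log1p` is hypothetical.

```python
import math


def normalize_and_log1p(counts, target_sum=10_000):
    """Scale one cell's counts to sum to target_sum, then apply natural log1p.

    Illustrates, for a single cell, what normalize_total followed by log1p
    computes. This sketch returns zeros for a cell with no counts.
    """
    total = sum(counts)
    if total == 0:
        return [0.0 for _ in counts]
    return [math.log1p(c * target_sum / total) for c in counts]


print(normalize_and_log1p([1, 3]))
```

For a cell with counts [1, 3], the total is 4, so the scaled values are 2,500 and 7,500 before the log transform.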
sc.pp.normalize_total(adata, target_sum=10_000)
sc.pp.log1p(adata)
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
Plot PCA, colored by dataset.
sc.pl.pca(adata, color="dataset")
Plot PCA, colored by experimental time point.
sc.pl.pca(adata, color="experimental_perturbation_time_point")
Plot UMAP, colored by dataset.
sc.pl.umap(adata, color="dataset")
Plot UMAP, colored by experimental time point.
sc.pl.umap(adata, color="experimental_perturbation_time_point")
Plot UMAP, colored by donor.
sc.pl.umap(adata, color="donor_id")
Contact and Acknowledgements
For issues with this tutorial, please contact Mei Knudson at knudsonm@uchicago.edu.
Special thanks to Kavita Kulkarni and Jason Perera for their consultation on this quickstart.
Responsible Use
We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.
Should you have any security or privacy issues or questions related to the services, please reach out to our team at security@chanzuckerberg.com or privacy@chanzuckerberg.com respectively.