Quickstart: scPRINT
Estimated time to complete: 15 minutes
Learning Goals
- Set up data schemas and the data loader
- Learn about model inputs and outputs
- Download a dataset from CZ CELLxGENE
- Denoise the dataset using scPRINT (run inference)
Prerequisites
- T4 GPU
- Python version 3.11
Introduction
scPRINT is a cell foundation model, also called a large cell model (LCM), trained on more than 50M human and mouse cells available through CZ CELLxGENE. It is based on the transformer architecture and is fully open source and reproducible with multiple model sizes available from 2M to 100M parameters. scPRINT has various zero-shot capabilities, such as cell classification, expression imputation, cell embedding, and genome-wide cell-specific gene-network inference. It can be fine-tuned for other tasks.
In this quickstart, we will use the associated scDataLoader to show how to set up the required metadata, load AnnData objects, run preprocessing, and iterate over sets of AnnData objects. Note the following:
- scDataLoader accepts an AnnData object as input (or a collection of AnnData objects when using LaminDB).
- scPRINT uses a trained model checkpoint and an AnnData object as input.
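At a high level, the pieces introduced below fit together like this (a sketch only; every call is detailed in the following steps):
# High-level flow of this quickstart (all names are introduced in later steps)
# adata = Preprocessor(...)(adata)                    # scDataLoader preprocessing
# datamodule = DataModule(collection_name=...)        # iterate over AnnData collections
# model = scPrint.load_from_checkpoint("small.ckpt")  # trained checkpoint
# metrics, indices, nadata = Denoiser(...)(model=model, adata=adata)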
Setup
The steps below will guide you through the following:
- Installation
- Creating a LaminDB instance
- Loading ontologies (metadata)
- Loading a random dataset
- Creating a data loader
- Downloading an scPRINT checkpoint
- Setting up a denoiser
- Denoising the AnnData
Setup Google Colab
Before starting, connect to the T4 GPU runtime hosted for free by Google Colab using the dropdown menu in the upper right-hand corner of this notebook.
Setup Local Environment
This quickstart has only been tested with NVIDIA CUDA versions 11.7 through 12.2.
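You can check which CUDA version your PyTorch build targets (a quick check, assuming PyTorch is already installed):
import torch

# CUDA toolkit version this PyTorch build was compiled against (None for CPU-only builds)
print(torch.version.cuda)
# Whether a compatible GPU is actually visible
print(torch.cuda.is_available())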
Step 1: Install scPRINT
Installation may take up to 10 minutes.
! pip install scprint
# You might be prompted to restart the session to apply all the updates
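To confirm the install, you can check the package metadata:
! pip show scprint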
Step 2: Setup LaminDB
LaminDB is required to handle collections of AnnData objects, the connection with CZ CELLxGENE and other databases, and the setup of metadata. Create an instance (i.e., a workspace) to set up LaminDB:
# Uninstall the torchao library preinstalled in Colab
# The preinstalled version expects a more recent pytorch than the one we are using
! pip uninstall -y torchao
Restart your Google Colab session for all changes to take effect.
# After restarting the session, verify the numpy version.
# It should be 1.26 (lamindb won't work with version 2.0 or later)
!pip list | grep numpy
Output:
numpy 1.26.0
numpy-groupies 0.11.3
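If the check instead shows numpy 2.x, you can pin it back before continuing (another session restart may be needed afterwards):
! pip install "numpy<2"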
import lamindb_setup as ln_setup
ln_setup.init(storage="./testdb", name="test", modules="bionty")
Output:
! using anonymous user (to identify, call: lamin login)
→ initialized lamindb: anonymous/test
Step 3: Download required data
Download ontologies
Download ontologies, which are useful for classification and other tasks. First pin torchmetrics and verify the installed version:
! pip install torchmetrics==1.3.0
! pip list | grep torchmetrics
Output:
torchmetrics 1.3.0
Then import the packages and populate the ontologies:
import lamindb as ln
import pandas as pd
import numpy as np
from scdataloader import DataModule, Preprocessor, utils
from scdataloader.preprocess import additional_postprocess, additional_preprocess

utils.populate_my_ontology()
Download a file from CZ CELLxGENE's collection
We will load and preprocess a scRNA-seq dataset to fit the scDataLoader and scPRINT requirements (this ensures we work with the same set of genes and the same data format).
Load a spatial AnnData object from CZ CELLxGENE (CxG):
cxg = ln.Collection.using(instance="laminlabs/cellxgene").filter(name='cellxgene-census').last()
cxg
Output:
! the database (1.2.0) is ahead of your installed lamindb package (1.0.4)
→ please update lamindb: pip install "lamindb>=1.2,<1.3"
Collection(uid='dMyEX3NTfKOEYXyMKDD7', version='2024-07-01', is_latest=True, key='cellxgene-census', hash='nI8Ag-HANeOpZOz-8CSn', created_by_id=1, space_id=1, run_id=27, created_at=2024-07-16 12:14:38 UTC)
file = cxg.artifacts.filter(description__contains="spatial").first()
file
Output:
Artifact(uid='nnSyG0O54y0MJ58m2jYr', version='2024-07-01', is_latest=True, key='cell-census/2024-07-01/h5ads/9fddb063-056d-4202-8b8a-4b0ee531d3ce.h5ad', description='A single-cell and spatially-resolved atlas of human breast cancers', suffix='.h5ad', kind='dataset', otype='AnnData', size=839092209, hash='bpBJEd5AOqxfFUNOxmVrrA', n_observations=100064, space_id=1, storage_id=2, run_id=27, created_by_id=1, created_at=2024-07-12 12:34:09 UTC)
adata = file.load()
adata
Output:
! the database (1.2.0) is ahead of your installed lamindb package (1.0.4)
→ please update lamindb: pip install "lamindb>=1.2,<1.3"
! run input wasn't tracked, call `ln.track()` and re-run
AnnData object with n_obs × n_vars = 100064 × 29067
obs: 'donor_id', 'percent_mito', 'nCount_RNA', 'nFeature_RNA', 'celltype_major', 'celltype_minor', 'celltype_subset', 'subtype', 'gene_module', 'calls', 'normal_cell_call', 'CNA_value', 'batch_run', 'multiplexed', 'cryo_state', 'development_stage_ontology_term_id', 'cancer_type', 'ER', 'PR', 'HER2_IHC', 'HER2_ISH', 'HER2_ISH_ratio', 'Ki67', 'subtype_by_IHC', 'treatment_status', 'treatment_details', 'assay_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'disease_ontology_term_id', 'grade', 'cell_type_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length'
uns: 'citation', 'schema_reference', 'schema_version', 'title'
obsm: 'X_Three-D', 'X_umap'
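Optionally (this is not part of the original flow), you can sanity-check that `adata.X` holds raw integer counts, which the preprocessor verifies below:
import numpy as np

# Sanity check: raw count matrices should contain integers only
x = adata.X[:100].toarray() if hasattr(adata.X, "toarray") else adata.X[:100]
print(np.allclose(x, np.round(x)))  # expect True for true count data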
Preprocess the file to make sure it is ready for scPRINT:
- Add missing genes
- Check for true count data
- Run scanpy's preprocessing pipeline
preprocessor = Preprocessor(
    do_postp=False,  # skip postprocessing; PCA takes too much RAM in Colab
    # additional_postprocess=additional_postprocess,
    additional_preprocess=additional_preprocess,
    force_preprocess=True,
)
adata = preprocessor(adata)
Output:
Dropping layers: KeysView(Layers with keys: )
checking raw counts
removed 0 non primary cells, 100064 renamining
/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:202: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.
adata.obs["nnz"] = np.array(np.sum(adata.X != 0, axis=1).flatten())[0]
filtered out 0 cells, 100064 renamining
Removed 0 genes.
validating
/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:304: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`
adata, organism=adata.obs.organism_ontology_term_id[0], need_all=False
startin QC
Seeing 36967 outliers (36.94% of total dataset):
/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:411: ImplicitModificationWarning: Trying to modify index of attribute `.obs` of view, initializing view as actual.
adata.obs.index = [str(uuid4()) for _ in range(adata.shape[0])]
done
AnnData object with n_obs × n_vars = 100064 × 70611
obs: 'donor_id', 'percent_mito', 'nCount_RNA', 'nFeature_RNA', 'celltype_major', 'celltype_minor', 'celltype_subset', 'subtype', 'gene_module', 'calls', 'normal_cell_call', 'CNA_value', 'batch_run', 'multiplexed', 'cryo_state', 'development_stage_ontology_term_id', 'cancer_type', 'ER', 'PR', 'HER2_IHC', 'HER2_ISH', 'HER2_ISH_ratio', 'Ki67', 'subtype_by_IHC', 'treatment_status', 'treatment_details', 'assay_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'disease_ontology_term_id', 'grade', 'cell_type_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'cell_culture', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'unseen_genes'
Step 4: Save the downloaded file as a collection of files
Note: This collection contains only a single file, but it could contain many more.
Here we show how scDataLoader can run on a collection of AnnData files (H5AD format).
# We make a collection from this single dataset, but we could use multiple;
# that is the point of a collection.
file = ln.Artifact.from_anndata(adata, description="test anndata")
file.save()
col = ln.Collection(file, name="test")
col.save()
Output:
! no run & transform got linked, call `ln.track()` & re-run
! no run & transform got linked, call `ln.track()` & re-run
! run input wasn't tracked, call `ln.track()` and re-run
<ipython-input-7-eee8e2dce101>:5: FutureWarning: argument `name` will be removed, please pass test to `key` instead
col = ln.Collection(file, name="test")
Collection(uid='fCAFhmRnI34a2XRN0000', is_latest=True, key='test', hash='ATO5zsOJCJsY6p0crfgCAw', created_by_id=1, space_id=1, created_at=2025-04-29 19:40:50 UTC)
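If you need to retrieve the collection later (for example, in a new session), you can look it up by key. A small sketch; the exact filter fields depend on your lamindb version:
# Look up the collection created above by its key
col = ln.Collection.filter(key="test").first()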
Step 5: Set up and run scDataLoader
Instantiate the data loader
- Randomly sample 300 genes per cell from the set of expressed genes.
- Use a batch size of 64.
- Perform weighted random sampling, weighted by organism and cell type.
datamodule = DataModule(
    collection_name="test",
    organisms=["NCBITaxon:9606"],  # organism that we will work on
    how="random expr",  # for the collator (a random set of expressed genes will be selected)
    max_len=300,  # only 300 genes will be shown
    batch_size=64,
    do_gene_pos=False,
    num_workers=1,
    use_default_col=True,
    clss_to_weight=["organism_ontology_term_id", "cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
    validation_split=0.1,
    test_split=0,
)
datamodule.setup()
Output:
! no run & transform got linked, call `ln.track()` & re-run
! run input wasn't tracked, call `ln.track()` and re-run
won't do any check but we recommend to have your dataset coming from local storage
100.0% are aligned
[]
Iterate over it
for i in datamodule.train_dataloader():
    print(i)
    break
Output:
{'x': tensor([[ 1., 1., 1., ..., 0., 0., 0.],
[ 1., 7., 1., ..., 0., 0., 0.],
[12., 33., 1., ..., 0., 0., 0.],
...,
[ 2., 1., 3., ..., 0., 0., 0.],
[ 1., 1., 1., ..., 0., 0., 0.],
[ 1., 1., 2., ..., 0., 0., 0.]]), 'genes': tensor([[ 2866, 13109, 606, ..., 63275, 9846, 34549],
[15509, 5231, 16872, ..., 52986, 65306, 44274],
[19966, 3509, 9943, ..., 67885, 37694, 55855],
...,
[ 4531, 11803, 13468, ..., 3892, 42968, 66839],
[15151, 3406, 9856, ..., 63901, 69323, 41856],
[ 3079, 10725, 19916, ..., 10777, 32488, 45297]], dtype=torch.int32), 'class': tensor([[0],
[0],
...,
[0]], dtype=torch.int32), 'tp': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'depth': tensor([ 3198., 2234., 3411., 3374., 3115., 2553., 1672., 4038., 2915.,
16055., 2785., 5265., 4052., 3281., 781., 3114., 2980., 7107.,
16746., 3888., 1904., 5607., 5051., 440., 2923., 12885., 1721.,
6297., 5291., 28179., 1758., 5562., 753., 1597., 1244., 2268.,
1462., 468., 601., 1492., 5470., 3043., 5146., 4292., 15679.,
8726., 1972., 1248., 6502., 19086., 3898., 11207., 1890., 1383.,
44462., 10375., 18041., 2193., 11054., 6857., 2339., 33694., 5128.,
1395.])}
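To see what the collator produces without printing full tensors, you can inspect just the shapes (a small sketch using only the keys shown above):
# Inspect one training batch: key names match the printout above
batch = next(iter(datamodule.train_dataloader()))
for key, value in batch.items():
    print(key, tuple(value.shape), value.dtype)
# 'x' and 'genes' are (batch_size, max_len) = (64, 300);
# 'class', 'tp', and 'depth' are per-cell tensors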
Step 6: Run model inference
Now let's run an example inference task: imputing spatial transcriptomics data. For this example, we will:
- download a model
- use a denoiser tool (made for scPRINT) that:
  - takes a dataset
  - removes a fraction of the counts
  - uses scPRINT to recover/impute the removed counts
  - compares the pre-removal data with the prediction
- make sure that, post-denoising, we improve the correlation to the ground truth
import os
import urllib.request
import torch
import scanpy as sc
from scprint import scPrint
from scprint.tasks import Denoiser, Embedder, GNInfer
Load the model
ckpt_path = "small.ckpt"
if not os.path.exists(ckpt_path):
url = "https://huggingface.co/jkobject/scPRINT/resolve/main/small.ckpt"
urllib.request.urlretrieve(url, ckpt_path)
model = scPrint.load_from_checkpoint(
ckpt_path,
precpt_gene_emb=None,
# triton gets installed so it must think it has cuda enabled
# transformer="normal", #else normal, without flashattention
)
Output:
RuntimeError caught: scPrint is not attached to a `Trainer`.
Setup and run the denoiser
For denoising, we will:
- denoise a random set of 1000 cells from this AnnData
- load a batch of 16 cells at a time with 8 workers (CPUs)
- only denoise 4000 genes (the same most variable genes across all cells)
- add noise (downsample transcript counts by 70%) before denoising, so that we have a ground truth expression profile to score against; this is only used to assess the quality of the denoising and shouldn't be done when the goal is to increase the library size (i.e., total counts) of a dataset
- increase the library size of this artificially downsampled dataset to 5x the input library size
dn = Denoiser(
    max_cells=1000,  # number of cells that will be processed
    batch_size=16,
    num_workers=8,
    how="most var",  # how to select the max_len genes
    max_len=4000,  # we will work on 4000 genes (input and output)
    downsample=0.7,  # we are removing 70% of the counts;
    # can be modified to make the data look more like spatial transcriptomics
    predict_depth_mult=5,  # how much to increase expression
)
metrics, random_indices, nadata = dn(
    model=model,
    adata=adata,
)
print('increased expression correlation by: ' + str(metrics["reco2full"] - metrics["noisy2full"]))
Output:
100%|██████████| 63/63 [00:10<00:00, 6.08it/s]
AnnData object with n_obs × n_vars = 1000 × 44756
obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
obsm: 'scprint_emb'
layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
increased expression correlation by: 0.4251738060380997
Model Outputs
See output definitions below:
- Metrics: a dictionary with the denoising score (when we do the noising process within the denoiser by setting `downsample` to a value between 0 and 1)
- Random_indices: the set of cell indices on which we ran denoising, when `max_cells` is smaller than the total number of cells in the dataset
- nadata: the AnnData with the denoised expression profile on the set of `max_len` genes selected with `how`
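The layers shown above (`scprint_mu`, `scprint_theta`, `scprint_pi`) look like the parameters of a zero-inflated negative binomial; assuming that parameterization, a simple point estimate of the denoised expression could be computed as follows (a sketch, not part of the official API):
import numpy as np

# Assumption: mu = mean, theta = dispersion, pi = zero-inflation probability,
# stored as dense arrays. The ZINB expectation is (1 - pi) * mu.
mu = np.asarray(nadata.layers["scprint_mu"])
pi = np.asarray(nadata.layers["scprint_pi"])
nadata.layers["denoised_expected"] = (1 - pi) * mu  # hypothetical layer name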
Learn More
Check out the Denoiser script to see how the AnnData gets "noised" and how scPRINT is used for denoising. You can look at the Collator and the model.forward() function to see how data goes from the mapped file to the format received by scPRINT. Read the scPRINT manuscript to learn more about the model. Finally, have a look at Embedder to get classifications and embeddings of cells, and at GNInfer to see how to generate gene networks from groups of cells.
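As a starting point, the task classes imported earlier appear to follow the same call pattern as Denoiser; this is a hedged sketch, and the exact constructor arguments and return values may differ between scprint versions:
# Hedged sketch (not verified): Embedder and GNInfer are assumed to follow the
# same model + adata call pattern as Denoiser; check the scprint docs for exact signatures.
emb = Embedder(batch_size=16, num_workers=8)
embedded_adata, emb_metrics = emb(model=model, adata=adata)

gn = GNInfer(batch_size=16, num_workers=8)
grn = gn(model=model, adata=adata)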
References
Kalfon, J., Samaran, J., Peyré, G. et al. scPRINT: pre-training on 50 million cells allows robust gene network predictions. Nat Commun 16, 3607 (2025). https://doi.org/10.1038/s41467-025-58699-1
Contact and Acknowledgments
For issues with this quickstart, please contact Jérémie Kalfon via a pull request message on GitHub.
The project leading to this manuscript has received funding from the Inception program (Investissement d'Avenir grant ANR-16-CONV-0005, to L.C.) and the European Union (ERC StG MULTIview-CELL, 101115618, to L.C.). We acknowledge the help of the HPC Core Facility of the Institut Pasteur. The work of G. Peyré was supported by the French government under the management of the Agence Nationale de la Recherche as part of the 'Investissements d'avenir' program, reference ANR-19-P3IA-0001 (PRAIRIE 3IA Institute).
Responsible Use
We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.