Quickstart: scPRINT

Estimated time to complete: 15 minutes

Learning Goals

  • Set up data schemas and the data loader
  • Learn about model inputs and outputs
  • Download a dataset from CZ CELLxGENE
  • Denoise the dataset using scPRINT (run inference)

Prerequisites

  • T4 GPU
  • Python version 3.11

Introduction

scPRINT is a cell foundation model, also called a large cell model (LCM), trained on more than 50M human and mouse cells available through CZ CELLxGENE. It is based on the transformer architecture, is fully open source and reproducible, and comes in multiple model sizes ranging from 2M to 100M parameters. scPRINT has various zero-shot capabilities, such as cell classification, expression imputation, cell embedding, and genome-wide cell-specific gene-network inference, and it can be fine-tuned for other tasks.

In this quickstart, we will use the associated scDataLoader to show how to set up the required metadata, load AnnData objects, run preprocessing, and iterate over sets of AnnData objects. Note the following:

  • scDataLoader accepts an AnnData object as input (or a collection of AnnData objects when using LaminDB).
  • scPRINT uses a trained model checkpoint and an AnnData object as input.
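
At a glance, the pieces used in this quickstart fit together as follows (a comment-only sketch; each call is shown concretely in the steps below):

# Sketch of the pipeline assembled in this quickstart:
#   adata = Preprocessor(...)(adata)              # normalize genes and metadata
#   ln.Collection(...)                            # register the data in LaminDB
#   datamodule = DataModule(collection_name=...)  # batch and iterate over cells
#   model = scPrint.load_from_checkpoint(...)     # load pretrained weights
#   Denoiser(...)(model=model, adata=adata)       # run a zero-shot inference task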

Setup

The steps below will guide you through the following:

  1. Installation
  2. Creating a LaminDB instance
  3. Loading ontologies (metadata)
  4. Loading a random dataset
  5. Creating a data loader
  6. Downloading an scPRINT checkpoint
  7. Setting up a denoiser
  8. Denoising the AnnData

Setup Google Colab

Before starting, connect to the T4 GPU runtime hosted for free by Google Colab using the dropdown menu in the upper right-hand corner of this notebook.

Setup Local Environment

This quickstart has only been tested with NVIDIA CUDA versions 11.7 through 12.2.

Step 1: Install scPRINT

Installation may take up to 10 minutes.

! pip install scprint
# You might be prompted to restart the session to apply all the updates

Step 2: Setup LaminDB

LaminDB is required to handle collections of AnnData objects, the connection with CZ CELLxGENE and other databases, and the setup of metadata. Before creating an instance (i.e., a workspace), remove the torchao library preinstalled in Colab, since it expects a more recent PyTorch version than the one we use:

# Uninstall the preinstalled torchao library in Colab.
# It expects a PyTorch version that is more recent than what we are using.
! pip uninstall -y torchao

Restart your Google Colab session for all changes to take effect.

# After restarting the session, verify the numpy version.
# It should be 1.26 (lamindb won't work with version 2.0 or later)
!pip list | grep numpy
Output:

numpy                                 1.26.0
numpy-groupies                        0.11.3
Now create the instance:

import lamindb_setup as ln_setup
ln_setup.init(storage="./testdb", name="test", modules="bionty")
Output:

! using anonymous user (to identify, call: lamin login)
→ initialized lamindb: anonymous/test
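
You can confirm which instance is active (a minimal check; lamindb_setup exposes its current settings object):

# Inspect the active instance settings; this should reference "anonymous/test".
print(ln_setup.settings.instance)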

Step 3: Download required data

Download ontologies

First pin torchmetrics to the version this notebook was tested with, then download the ontologies, which are useful for classification and other tasks:

! pip install torchmetrics==1.3.0
! pip list | grep torchmetrics
Output:

torchmetrics                          1.3.0
import lamindb as ln
import pandas as pd
import numpy as np
from scdataloader import DataModule, Preprocessor, utils
from scdataloader.preprocess import additional_postprocess, additional_preprocess

utils.populate_my_ontology()
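
Populating the ontologies can take a few minutes. Once done, the records live in your instance's bionty registries and can be queried like any lamindb registry (a minimal sketch; registry names follow the bionty module):

import bionty as bt

# Count a couple of the populated ontology registries.
print(bt.CellType.filter().count(), "cell type records")
print(bt.Tissue.filter().count(), "tissue records")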

Download a file from CZ CELLxGENE's collection

We will load and preprocess an scRNA-seq dataset so that it fits the scDataLoader and scPRINT requirements (this ensures we work with the same set of genes and the same data format).

Load a random spatial AnnData object from CZ CELLxGENE (CxG):

cxg = ln.Collection.using(instance="laminlabs/cellxgene").filter(name='cellxgene-census').last()
cxg
Output:

! the database (1.2.0) is ahead of your installed lamindb package (1.0.4)
→ please update lamindb: pip install "lamindb>=1.2,<1.3"

Collection(uid='dMyEX3NTfKOEYXyMKDD7', version='2024-07-01', is_latest=True, key='cellxgene-census', hash='nI8Ag-HANeOpZOz-8CSn', created_by_id=1, space_id=1, run_id=27, created_at=2024-07-16 12:14:38 UTC)
file = cxg.artifacts.filter(description__contains="spatial").first()
file
Output:

Artifact(uid='nnSyG0O54y0MJ58m2jYr', version='2024-07-01', is_latest=True, key='cell-census/2024-07-01/h5ads/9fddb063-056d-4202-8b8a-4b0ee531d3ce.h5ad', description='A single-cell and spatially-resolved atlas of human breast cancers', suffix='.h5ad', kind='dataset', otype='AnnData', size=839092209, hash='bpBJEd5AOqxfFUNOxmVrrA', n_observations=100064, space_id=1, storage_id=2, run_id=27, created_by_id=1, created_at=2024-07-12 12:34:09 UTC)
adata = file.load()
adata
Output:

! the database (1.2.0) is ahead of your installed lamindb package (1.0.4)
→ please update lamindb: pip install "lamindb>=1.2,<1.3"
! run input wasn't tracked, call `ln.track()` and re-run

AnnData object with n_obs × n_vars = 100064 × 29067
    obs: 'donor_id', 'percent_mito', 'nCount_RNA', 'nFeature_RNA', 'celltype_major', 'celltype_minor', 'celltype_subset', 'subtype', 'gene_module', 'calls', 'normal_cell_call', 'CNA_value', 'batch_run', 'multiplexed', 'cryo_state', 'development_stage_ontology_term_id', 'cancer_type', 'ER', 'PR', 'HER2_IHC', 'HER2_ISH', 'HER2_ISH_ratio', 'Ki67', 'subtype_by_IHC', 'treatment_status', 'treatment_details', 'assay_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'disease_ontology_term_id', 'grade', 'cell_type_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length'
    uns: 'citation', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_Three-D', 'X_umap'

Preprocess the file to make sure it is ready for scPRINT:

  • Add missing genes
  • Check for true count data
  • Run scanpy's preprocessing pipeline
preprocessor = Preprocessor(
    do_postp=False,  # PCA in postprocessing takes too much RAM in Colab
    # additional_postprocess=additional_postprocess,
    additional_preprocess=additional_preprocess,
    force_preprocess=True,
)
adata = preprocessor(adata)
Output:

Dropping layers:  KeysView(Layers with keys: )
checking raw counts
removed 0 non primary cells, 100064 renamining

/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:202: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.
  adata.obs["nnz"] = np.array(np.sum(adata.X != 0, axis=1).flatten())[0]

filtered out 0 cells, 100064 renamining
Removed 0 genes.
validating

/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:304: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`
  adata, organism=adata.obs.organism_ontology_term_id[0], need_all=False

startin QC
Seeing 36967 outliers (36.94% of total dataset):

/usr/local/lib/python3.11/dist-packages/scdataloader/preprocess.py:411: ImplicitModificationWarning: Trying to modify index of attribute `.obs` of view, initializing view as actual.
  adata.obs.index = [str(uuid4()) for _ in range(adata.shape[0])]

done
AnnData object with n_obs × n_vars = 100064 × 70611
    obs: 'donor_id', 'percent_mito', 'nCount_RNA', 'nFeature_RNA', 'celltype_major', 'celltype_minor', 'celltype_subset', 'subtype', 'gene_module', 'calls', 'normal_cell_call', 'CNA_value', 'batch_run', 'multiplexed', 'cryo_state', 'development_stage_ontology_term_id', 'cancer_type', 'ER', 'PR', 'HER2_IHC', 'HER2_ISH', 'HER2_ISH_ratio', 'Ki67', 'subtype_by_IHC', 'treatment_status', 'treatment_details', 'assay_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'disease_ontology_term_id', 'grade', 'cell_type_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'cell_culture', 'nnz', 'n_genes', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'uid', 'symbol', 'ncbi_gene_ids', 'biotype', 'synonyms', 'description', 'organism_id', 'mt', 'ribo', 'hb', 'organism', 'ensembl_gene_id', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'unseen_genes'
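
Before moving on, it can help to sanity-check the preprocessed object, since scPRINT expects raw integer counts (a minimal sketch using the QC columns shown in the output above):

import numpy as np
from scipy.sparse import issparse

# scPRINT works on raw counts; spot-check that values are still integers.
vals = adata.X.data if issparse(adata.X) else np.asarray(adata.X).ravel()
assert (vals[:1000] % 1 == 0).all(), "expected raw integer counts"
# QC columns added during preprocessing:
print(adata.obs[["total_counts", "pct_counts_mt"]].describe())
print(int(adata.obs["outlier"].sum()), "cells flagged as outliers")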

Step 4: Save the downloaded file as a collection of files

Note: This collection contains only a single file, but it could contain many more.

Here we show how scDataLoader can run on a collection of AnnData files (H5AD format).

# We build the collection from this single dataset, but it could include
# multiple datasets; that is the point of a collection.
file = ln.Artifact.from_anndata(adata, description="test anndata")
file.save()
col = ln.Collection(file, name="test")
col.save()
Output:

! no run & transform got linked, call `ln.track()` & re-run
! no run & transform got linked, call `ln.track()` & re-run
! run input wasn't tracked, call `ln.track()` and re-run
<ipython-input-7-eee8e2dce101>:5: FutureWarning: argument `name` will be removed, please pass test to `key` instead
  col = ln.Collection(file, name="test")
Collection(uid='fCAFhmRnI34a2XRN0000', is_latest=True, key='test', hash='ATO5zsOJCJsY6p0crfgCAw', created_by_id=1, space_id=1, created_at=2025-04-29 19:40:50 UTC)
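
The collection is now registered in the local instance and can be retrieved by its key later on (a minimal sketch using lamindb's query interface):

# Retrieve the collection we just saved; "test" matches the key used above.
col = ln.Collection.filter(key="test").first()
print(col)
print(col.artifacts.count(), "artifact(s) in the collection")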

Step 5: Set up and run scDataLoader

Instantiate the data loader

  • Randomly sample 300 genes per cell from the set of expressed genes.
  • Use a batch size of 64.
  • Perform random sampling of cells, weighted by organism and cell type.
datamodule = DataModule(
    collection_name="test",
    organisms=["NCBITaxon:9606"],  # organism that we will work on
    how="random expr",  # for the collator (random set of expr genes will be selected)
    max_len=300,  # only 300 genes will be shown
    batch_size=64,
    do_gene_pos=False,
    num_workers=1,
    use_default_col=True,
    clss_to_weight=["organism_ontology_term_id", "cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
    validation_split=0.1,
    test_split=0,
)
datamodule.setup()
Output:

! no run & transform got linked, call `ln.track()` & re-run
! run input wasn't tracked, call `ln.track()` and re-run
won't do any check but we recommend to have your dataset coming from local storage
100.0% are aligned

[]

Iterate over it

for i in datamodule.train_dataloader():
    print(i)
    break
Output:

{'x': tensor([[ 1.,  1.,  1.,  ...,  0.,  0.,  0.],
        [ 1.,  7.,  1.,  ...,  0.,  0.,  0.],
        [12., 33.,  1.,  ...,  0.,  0.,  0.],
        ...,
        [ 2.,  1.,  3.,  ...,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  ...,  0.,  0.,  0.],
        [ 1.,  1.,  2.,  ...,  0.,  0.,  0.]]), 'genes': tensor([[ 2866, 13109,   606,  ..., 63275,  9846, 34549],
        [15509,  5231, 16872,  ..., 52986, 65306, 44274],
        [19966,  3509,  9943,  ..., 67885, 37694, 55855],
        ...,
        [ 4531, 11803, 13468,  ...,  3892, 42968, 66839],
        [15151,  3406,  9856,  ..., 63901, 69323, 41856],
        [ 3079, 10725, 19916,  ..., 10777, 32488, 45297]], dtype=torch.int32), 'class': tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]], dtype=torch.int32), 'tp': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'depth': tensor([ 3198.,  2234.,  3411.,  3374.,  3115.,  2553.,  1672.,  4038.,  2915.,
        16055.,  2785.,  5265.,  4052.,  3281.,   781.,  3114.,  2980.,  7107.,
        16746.,  3888.,  1904.,  5607.,  5051.,   440.,  2923., 12885.,  1721.,
         6297.,  5291., 28179.,  1758.,  5562.,   753.,  1597.,  1244.,  2268.,
         1462.,   468.,   601.,  1492.,  5470.,  3043.,  5146.,  4292., 15679.,
         8726.,  1972.,  1248.,  6502., 19086.,  3898., 11207.,  1890.,  1383.,
        44462., 10375., 18041.,  2193., 11054.,  6857.,  2339., 33694.,  5128.,
         1395.])}
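
Each batch is a dictionary of tensors: 'x' holds the expression values of the 300 sampled genes per cell, 'genes' the indices of those genes, 'class' the encoded labels for the classes we weighted on, 'tp' what appears to be a per-cell time covariate, and 'depth' the per-cell library size (total counts). A quick way to check the shapes (a minimal sketch):

batch = next(iter(datamodule.train_dataloader()))
for key, value in batch.items():
    # e.g., x: (64, 300) expression values; genes: (64, 300) gene indices
    print(key, tuple(value.shape), value.dtype)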

Step 6: Run model inference

Now let's run an example inference: imputing spatial transcriptomics data. For this example, we will:

  • download a model
  • use a denoiser tool (made for scPRINT) that:
    • takes a dataset
    • removes a large fraction of the counts
    • uses scPRINT to recover/impute the removed counts
    • compares the pre-removal counts with the prediction
  • verify that denoising improves the correlation with the ground truth
import os
import urllib.request
import torch
import scanpy as sc

from scprint import scPrint
from scprint.tasks import Denoiser, Embedder, GNInfer

Load the model

ckpt_path = "small.ckpt"
if not os.path.exists(ckpt_path):
    url = "https://huggingface.co/jkobject/scPRINT/resolve/main/small.ckpt"
    urllib.request.urlretrieve(url, ckpt_path)
model = scPrint.load_from_checkpoint(
    ckpt_path,
    precpt_gene_emb=None,
    # triton gets installed, so the model assumes CUDA is available
    # transformer="normal",  # use standard attention instead of FlashAttention
)
Output:

RuntimeError caught: scPrint is not attached to a `Trainer`.
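
The Trainer message above is printed during checkpoint loading and does not prevent inference (the denoiser runs fine below). If your runtime has a GPU, you can move the model there explicitly (a minimal sketch; scPrint is a PyTorch Lightning module, so standard torch device handling applies):

# Move the model to the GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # we only run inference below
print(f"model on {device}")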

Setup and run the denoiser

For denoising, we will:

  • denoise a random set of 1000 cells from this AnnData
  • load a batch of 16 cells at a time with 8 workers (CPUs)
  • only denoise 4000 genes (the same most-variable genes across all cells)
  • add noise (downsample transcript counts by 70%) before denoising so that we have a ground-truth expression profile to score against. This is only used to assess the quality of the denoising and shouldn't be done when you simply want to increase the library size (i.e., total counts) of a dataset
  • increase the library size of this artificially downsampled dataset to 5x the input library size.
dn = Denoiser(
    max_cells=1000,  # number of cells that will be processed
    batch_size=16,
    num_workers=8,
    how="most var",  # how to select the max_len genes
    max_len=4000,  # we will work on 4000 genes (input and output)
    downsample=0.7,  # we are removing 70% of the counts;
    # tune this to make the data look more like spatial transcriptomics
    predict_depth_mult=5,  # how much to increase expression
)
metrics, random_indices, nadata = dn(
    model=model,
    adata=adata,
)
print('increased expression correlation by: ' + str(metrics["reco2full"] - metrics["noisy2full"]))
Output:

100%|██████████| 63/63 [00:10<00:00,  6.08it/s]

AnnData object with n_obs × n_vars = 1000 × 44756
    obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
increased expression correlation by: 0.4251738060380997

Model Outputs

See output definitions below:

  • metrics: a dictionary of denoising scores, computed when the denoiser performs the noising itself (i.e., when downsample is set to a value between 0 and 1)
  • random_indices: the set of cell indices on which denoising was run, used when max_cells is smaller than the total number of cells in the dataset
  • nadata: the AnnData with the denoised expression profile over the set of max_len genes selected with how
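
To carry the denoised values forward, you can promote the predicted mean expression to the main matrix (a hedged sketch: we assume the scprint_mu layer holds the predicted mean of the model's count distribution, with scprint_theta and scprint_pi as its dispersion and zero-inflation parameters):

# Hedged sketch: treat 'scprint_mu' as the denoised expression estimate.
denoised = nadata.copy()
denoised.X = denoised.layers["scprint_mu"]
print(denoised)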

Learn More

Check out the Denoiser script to see how the AnnData gets "noised" and how scPRINT is used for denoising. Look at the Collator and the model.forward() function to see how data flows from the mapped file to the format received by scPRINT. Read the scPRINT manuscript to learn more about the model. Finally, have a look at Embedder to get classifications and embeddings of cells, and at GNInfer to see how to generate gene networks from groups of cells, as sketched below.
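
As a starting point, the task classes follow the same pattern as Denoiser: instantiate with options, then call with a model and an AnnData. A hedged sketch (the constructor options and return values shown here are assumptions; check the Embedder docstring for the exact interface):

from scprint.tasks import Embedder

# Assumed call pattern, mirroring Denoiser; verify against the Embedder docstring.
embedder = Embedder(batch_size=16, num_workers=8)
embedded_adata, emb_metrics = embedder(model, adata)
print(embedded_adata)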

References

Kalfon, J., Samaran, J., Peyré, G. et al. scPRINT: pre-training on 50 million cells allows robust gene network predictions. Nat Commun 16, 3607 (2025). https://doi.org/10.1038/s41467-025-58699-1

Contact and Acknowledgments

For issues with this quickstart, please contact Jérémie Kalfon via a pull request message on GitHub.

The project leading to this manuscript has received funding from the Inception program (Investissement d'Avenir grant ANR-16-CONV-0005; L.C.) and the European Union (ERC StG MULTIview-CELL, 101115618; L.C.). We acknowledge the help of the HPC Core Facility of the Institut Pasteur. The work of G. Peyré was supported by the French government under the management of the Agence Nationale de la Recherche as part of the 'Investissements d'avenir' program, reference ANR-19-P3IA-0001 (PRAIRIE 3IA Institute).

Responsible Use

We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.