Quick Start: Out-of-Distribution Cross-species Cell Type Classification with TranscriptFormer

Estimated time to complete: Under 15 minutes with an A100 GPU.

Google Colab Note: It is strongly recommended to run this notebook with an A100 GPU, which is available only with the paid Google Colab Pro or Enterprise services. Alternatively, a "pay as you go" option allows purchasing premium GPUs. See Colab Service Plans for details.

Learning Goals

This tutorial walks through how TranscriptFormer embeddings can be used to transfer cell-type annotations between species.

We will replicate part of the analysis from Figure 3 of the TranscriptFormer publication, using the spermatogenesis scRNA-seq data from Murat et al.

What you'll learn

  1. Embed cells with TranscriptFormer (TF-Exemplar) into a shared, species-agnostic latent space.
  2. Train a classifier on annotated cells from one species.
  3. Evaluate cross-species accuracy by predicting cell types in another species.
  4. Visualize results to understand how well biological structure is conserved.

Pre-requisites

  • A100 GPU or equivalent (Note: it is possible to run this tutorial on a T4 GPU by setting the inference batch size to 1, as shown in the inference section below; however, this will increase the inference time)
  • Python 3.11

Introduction

Model

TranscriptFormer is a family of generative foundation models representing a cross-species generative cell atlas trained on up to 112 million cells spanning 1.53 billion years of evolution across 12 species. The models include three distinct versions of TranscriptFormer, trained on different collections of data:

  • TF-Metazoa (12 species)
  • TF-Exemplar (5 model organisms)
  • TF-Sapiens (human)

TranscriptFormer is designed to learn rich, context-aware representations of single-cell transcriptomes while jointly modeling genes and transcripts using a novel generative architecture. It employs a generative autoregressive joint model over genes and their expression levels per cell across species, with a transformer-based architecture that includes a novel coupling between gene and transcript heads, expression-aware multi-head self-attention, causal masking, and a count likelihood to capture transcript-level variability. More details on the architecture of TranscriptFormer can be found in the “Model Details” section. TranscriptFormer demonstrates robust zero-shot performance for cell type classification across species, disease state identification in human cells, and prediction of cell-type-specific transcription factors and gene-gene regulatory relationships. This work establishes a powerful framework for integrating and interrogating cellular diversity across species and offers a foundation for in-silico experimentation with a generative single-cell atlas model.
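
To make the autoregressive factorization concrete, here is a toy, purely illustrative sketch of a joint likelihood over (gene, count) pairs. The uniform gene head and Poisson count head below are hypothetical stand-ins, not the actual transformer heads used by TranscriptFormer.

import numpy as np
from scipy.stats import poisson

# Toy sketch (NOT the real model): the joint likelihood of a cell factorizes
# autoregressively over its (gene, count) tokens:
#   log p(g_1..g_n, x_1..x_n) = sum_i [ log p(g_i | ctx_<i) + log p(x_i | g_i, ctx_<i) ]

def toy_gene_logprob(gene, ctx, vocab_size=100):
    # placeholder gene head: uniform over the gene vocabulary
    return -np.log(vocab_size)

def toy_count_logprob(count, gene, ctx, rate=2.0):
    # placeholder transcript head: Poisson count likelihood
    return poisson.logpmf(count, mu=rate)

def cell_log_likelihood(genes, counts):
    total = 0.0
    for i, (g, x) in enumerate(zip(genes, counts)):
        ctx = list(zip(genes[:i], counts[:i]))  # causal (masked) context
        total += toy_gene_logprob(g, ctx) + toy_count_logprob(x, g, ctx)
    return total

print(cell_log_likelihood(genes=[3, 17, 42], counts=[1, 4, 2]))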

Example Dataset

The Macaque and Marmoset datasets used in this demonstration were retrieved from bgee.org.

Setup

Setup Google Colab

To run this quickstart using Google Colab, you will need to choose the 'A100' GPU runtime from the "Connect" dropdown menu in the upper-right corner of this notebook. Note that this runtime configuration is not available in the free Colab version. To access premium GPUs, you will need to purchase additional compute units. The current quickstart was tested in Colab Enterprise using the following runtime configuration:

  • Machine type: a2-highgpu-1g
  • GPU type: NVIDIA_TESLA_A100 x 1
  • Data disk type: 100 GB Standard Disk (pd-standard)

It is also possible to run this quick start on a free-tier T4 instance by setting the inference batch size to 1 (see below in the inference section).

Setup Local Environment

Install Conda once & auto-restart

Below we use condacolab to bootstrap mamba, create an isolated env, install matching PyTorch + TorchVision, and finally install the TranscriptFormer repo (which lives on GitHub).

!python --version
!pip install -q condacolab
import condacolab
condacolab.install()  # installs conda/mamba and automatically restarts the Colab kernel

# After the automatic restart, verify the installation
import condacolab
condacolab.check()
!conda install -y pip

Clone the TranscriptFormer repository and install dependencies.

Below we clone the transcriptformer repo and install it from source. The package can also be installed from its PyPI distribution, as shown after the source install.

!git clone https://github.com/czi-ai/transcriptformer.git
%cd transcriptformer

Install TranscriptFormer from source:

!uv pip install --system . torchvision
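
Alternatively, to use the PyPI distribution mentioned above instead of building from source (this assumes the package is published under the name transcriptformer):

!pip install -q transcriptformer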

Import dependencies

First, we'll import some useful libraries.

%load_ext autoreload
%autoreload 2

import json
import logging
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from omegaconf import OmegaConf
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from transcriptformer.datasets import bgee_testis_evolution, download_all_embeddings
from transcriptformer.model.inference import run_inference

Download TranscriptFormer Model Weights and ESM2 Embeddings for Out-of-Distribution Species

We need to download the TranscriptFormer model weights; you can use the Python script download_artifacts.py to do so. We will be using the tf-exemplar model, trained on 5 model organisms (H. sapiens, M. musculus, D. rerio, D. melanogaster, C. elegans). The model checkpoints and associated artifacts total ~4.2 GB for this variant.

Since the species featured in this notebook, Rhesus macaque and Marmoset, are out-of-distribution for the model (it was not trained on scRNA-seq data from either species), we also need to download the ESM2 gene embeddings. We provide a convenience function, download_all_embeddings, which downloads and unzips all embeddings. The tar file is ~4.8 GB and uncompresses to ~5 GB.

Note: This download will take a few minutes.

!python ./download_artifacts.py tf-exemplar
download_all_embeddings(path="./embeddings")
!ls -la ./embeddings/all_embeddings/
Output:

    total 5216076
    drwxr-xr-x 2  999  999      4096 Apr  9 20:30 .
    drwxr-xr-x 3 root root      4096 Apr 28 19:41 ..
    -rw-r--r-- 1  999  999 212696350 Apr  6 20:36 caenorhabditis_elegans_gene.h5
    -rw-r--r-- 1  999  999 235371868 Apr  6 20:36 callithrix_jacchus_gene.h5
    -rw-r--r-- 1  999  999 324231520 Apr  6 20:49 danio_rerio_gene.h5
    -rw-r--r-- 1  999  999 148914470 Apr  6 20:36 drosophila_melanogaster_gene.h5
    -rw-r--r-- 1  999  999 180101052 Apr  6 20:36 gallus_gallus_gene.h5
    -rw-r--r-- 1  999  999 232352100 Apr  6 20:36 gorilla_gorilla_gene.h5
    -rw-r--r-- 1  999  999 247467360 Apr  6 20:36 heterocephalus_glaber_gene.h5
    -rw-r--r-- 1  999  999 253848593 Apr  9 20:30 homo_sapiens_gene.h5
    -rw-r--r-- 1  999  999 241397301 Apr  6 20:36 lytechinus_variegatus_gene.h5
    -rw-r--r-- 1  999  999 232003346 Apr  6 20:36 macaca_mulatta_gene.h5
    -rw-r--r-- 1  999  999 227994768 Apr  6 20:36 monodelphis_domestica_gene.h5
    -rw-r--r-- 1  999  999 236434532 Apr  6 20:36 mus_musculus_gene.h5
    -rw-r--r-- 1  999  999 185841076 Apr  6 20:36 ornithorhynchus_anatinus_gene.h5
    -rw-r--r-- 1  999  999 219788872 Apr  6 20:36 oryctolagus_cuniculus_gene.h5
    -rw-r--r-- 1  999  999 250848188 Apr  6 20:36 pan_troglodytes_gene.h5
    -rw-r--r-- 1  999  999 110824302 Apr  6 20:36 petromyzon_marinus_gene.h5
    -rw-r--r-- 1  999  999  55976107 Apr  6 20:36 plasmodium_falciparum_gene.h5
    -rw-r--r-- 1  999  999 246155562 Apr  6 20:36 rattus_norvegicus_gene.h5
    -rw-r--r-- 1  999  999  70036928 Apr  6 20:36 saccharomyces_cerevisiae_gene.h5
    -rw-r--r-- 1  999  999 324815366 Apr  6 20:36 spongilla_lacustris_gene.h5
    -rw-r--r-- 1  999  999 264162244 Apr  6 20:36 stylophora_pistillata_gene.h5
    -rw-r--r-- 1  999  999 235211854 Apr  6 20:36 sus_scrofa_gene.h5
    -rw-r--r-- 1  999  999 369537366 Apr  6 20:36 xenopus_laevis_gene.h5
    -rw-r--r-- 1  999  999 235184246 Apr  6 20:36 xenopus_tropicalis_gene.h5

Next, we'll read the configuration files—both the model-variant and inference configs inside /conf—and merge them before running inference. Be sure to update the path to the pre-computed ESM-2 embeddings for your organism of interest.

We'll also download the data. A utility function, bgee_testis_evolution, fetches the processed dataset; it takes the species name and the destination file path as arguments. In this tutorial we'll train a cell-type classifier on Macaque cells and transfer the labels to Marmoset. Neither species was used during model training, underscoring TranscriptFormer's ability to generalize across the tree of life.

# download the data
adata = bgee_testis_evolution(organism="macaque")

# Load the inference config
cfg = OmegaConf.load("./conf/inference_config.yaml")
logging.debug(OmegaConf.to_yaml(cfg))

# Change the checkpoint path to the model variant of interest, here tf_exemplar
cfg.model.checkpoint_path = "./checkpoints/tf_exemplar"

# Load the model specific configs
config_path = os.path.join(cfg.model.checkpoint_path, "config.json")
with open(config_path) as f:
    config_dict = json.load(f)
model_cfg = OmegaConf.create(config_dict)

# Merge the model-specific configs with the inference config
cfg = OmegaConf.merge(model_cfg, cfg)

# Set the checkpoint paths based on the unified checkpoint_path
cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt")
cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs")

# Change the precomputed ESM2 embeddings path
# (absolute path for the Colab layout used here; adjust if running elsewhere)
cfg.model.inference_config.pretrained_embedding = "/content/transcriptformer/embeddings/all_embeddings/macaca_mulatta_gene.h5"
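
Before launching inference, it can be useful to verify that the resolved artifact paths actually exist. An optional sanity check over the config fields set above:

# Optional sanity check: confirm the resolved artifact paths exist
for path in (
    cfg.model.inference_config.load_checkpoint,
    cfg.model.data_config.aux_vocab_path,
    cfg.model.inference_config.pretrained_embedding,
):
    print(path, "->", "found" if os.path.exists(path) else "MISSING")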

We can then run inference on the data.

Note: While this tutorial was originally run on an instance with an A100 GPU, it is possible to run it on a T4 GPU by setting the inference batch size to 1.

# You can modify the inference batch size via cfg.model.inference_config.batch_size.
# The default value is 8.
# On an A100 or H100 you can increase it to 16; on a T4, set it to 1 to avoid
# running into memory errors.
# cfg.model.inference_config.batch_size = 1
# Set logging level to ERROR to reduce verbosity
logging.getLogger().setLevel(logging.ERROR)

adata_output = run_inference(cfg, data_files=[adata])
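
run_inference returns an AnnData object with one embedding per cell stored in .obsm["embeddings"], which we use throughout the rest of the tutorial. A quick inspection confirms this:

# Inspect the output: cell embeddings live in .obsm["embeddings"]
print(adata_output)
print(adata_output.obsm["embeddings"].shape)  # (n_cells, embedding_dim)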

We'll now embed the dataset to which we want to map cell types, in this case, the Marmoset dataset. First, we'll download the data, then run the model in inference mode to generate the embeddings.

adata_map = bgee_testis_evolution(organism="marmoset")
cfg.model.inference_config.pretrained_embedding = "/content/transcriptformer/embeddings/all_embeddings/callithrix_jacchus_gene.h5"
adata_map_output = run_inference(cfg, data_files=[adata_map])

Next, we'll train a classifier to predict cell-type labels. We'll train the classifier on the first dataset (Macaque) and then apply it to the second dataset (Marmoset) to infer those labels.

pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("knn", KNeighborsClassifier()),
    ]
)

embeddings, labels = adata_output.obsm["embeddings"], adata_output.obs["cell_type"]
embeddings_map = adata_map_output.obsm["embeddings"]
pipeline.fit(embeddings, labels)
classes = pipeline.predict(embeddings_map)
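
Because the Marmoset AnnData also carries curated cell_type labels, we can quantify the label transfer directly before visualizing it. A minimal check with scikit-learn, assuming the two species share a common cell-type vocabulary (as they do in this dataset):

from sklearn.metrics import accuracy_score, classification_report

# Compare transferred labels against the Marmoset ground truth
acc = accuracy_score(adata_map.obs["cell_type"], classes)
print(f"Cross-species label-transfer accuracy: {acc:.3f}")
print(classification_report(adata_map.obs["cell_type"], classes, zero_division=0))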

The k-NN classifier is trained on the Macaque data and then used to predict cell types in the Marmoset data. The predictions are saved in the predicted_cell_type column of the obs attribute of the AnnData object. The UMAP layout is computed from the transcriptformer_embedding representation.

adata_map.obs['predicted_cell_type'] = pd.Categorical(classes)
adata_map.obsm["transcriptformer_embedding"] = embeddings_map
sc.pp.neighbors(adata_map, use_rep="transcriptformer_embedding")
sc.tl.umap(adata_map)

We can now plot the predicted cell-type labels for the second dataset on the UMAP computed from the TranscriptFormer embeddings, alongside the true labels. The k-NN classifier appears to have transferred the cell types accurately from one species to the other.

sc.pl.umap(
    adata_map,
    color=['predicted_cell_type', "cell_type"],
    ncols=2,
    frameon=False,
    wspace=0.3,  # Add space between plots
    title=['Predicted Cell Type', 'True Cell Type'],  # Add descriptive titles
)
UMAP plot of predicted cell types

We can also visualize this by plotting a confusion matrix that compares the predicted and true cell types.

true_labels = adata_map.obs['cell_type']
pred_labels = adata_map.obs['predicted_cell_type']

# Get unique labels from both true and predicted
all_labels = np.unique(np.concatenate([true_labels.cat.categories, pred_labels.cat.categories]))

# Compute the confusion matrix
cm = confusion_matrix(true_labels, pred_labels, labels=all_labels)

# Normalize the confusion matrix by row (true labels)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Create a heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(cm_normalized, annot=True, fmt='.1f', cmap='Blues',
            xticklabels=all_labels, yticklabels=all_labels)
plt.xlabel('Predicted Cell Type')
plt.ylabel('True Cell Type')
plt.title('Confusion Matrix: True vs Predicted Cell Types')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
Confusion matrix of predicted vs true cell types

Finally, we can plot the embeddings of both datasets within the same UMAP space.

# Concatenate the AnnData objects
adata_joint = ad.concat(
    [adata_output, adata_map_output],
    label="dataset",
    keys=["macaque", "marmoset"],
)

# Perform UMAP and plot
sc.pp.neighbors(adata_joint, use_rep="embeddings")
sc.tl.umap(adata_joint)
sc.pl.umap(
    adata_joint,
    color=["cell_type", "dataset"],
    ncols=2,
    frameon=False,
    wspace=0.3,  # Add space between plots
    title=['Cell Type', 'Dataset'],  # Add descriptive titles
)
UMAP plot of embeddings from both datasets
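
As an optional, rough check of how well the two species mix in the joint embedding space, we can compute, for each cell, the fraction of its nearest neighbors drawn from the other dataset. This simple mixing statistic (a sketch below, not part of the original analysis) should approach the other dataset's overall share of cells when the embeddings integrate well.

from sklearn.neighbors import NearestNeighbors

# For each cell, fraction of its k nearest neighbors from the other dataset;
# well-mixed embeddings give values near the other dataset's share of cells.
k = 30
nn = NearestNeighbors(n_neighbors=k + 1).fit(adata_joint.obsm["embeddings"])
_, idx = nn.kneighbors(adata_joint.obsm["embeddings"])
is_marmoset = (adata_joint.obs["dataset"] == "marmoset").to_numpy()
cross = is_marmoset[idx[:, 1:]] != is_marmoset[:, None]  # skip self-neighbor
print("Mean cross-dataset neighbor fraction:", cross.mean().round(3))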