Quick Start: Out-of-Distribution Cross-Species Cell Type Classification with TranscriptFormer
Estimated time to complete: under 15 minutes with an A100 GPU.
Google Colab Note: It is strongly recommended to run this notebook with an A100 GPU, which is only included with Google Colab Pro or Enterprise paid services. Alternatively, a "pay as you go" option is available to purchase premium GPUs. See Colab Service Plans for details.
Learning Goals
This tutorial walks through how TranscriptFormer embeddings can be used to transfer cell-type annotations between species.
We will replicate part of the analysis from Figure 3 of the TranscriptFormer publication, using the spermatogenesis scRNA-seq data from Murat et al.
What you'll learn
- Embed cells with TranscriptFormer (TF-Exemplar) into a shared, species-agnostic latent space.
- Train a classifier on annotated cells from one species.
- Evaluate cross-species accuracy by predicting cell types in another species.
- Visualize results to understand how well biological structure is conserved.
Pre-requisites
- A100 GPU or equivalent (Note: it is possible to run this tutorial on a T4 GPU by setting the inference batch size to 1, as shown in the inference section below, though this will increase the inference time)
- Python 3.11
Introduction
Model
TranscriptFormer is a family of generative foundation models representing a cross-species generative cell atlas trained on up to 112 million cells spanning 1.53 billion years of evolution across 12 species. The models include three distinct versions of TranscriptFormer, trained on different collections of data:
- TF-Metazoa (12 species)
- TF-Exemplar (5 model organisms)
- TF-Sapiens (human)
TranscriptFormer is designed to learn rich, context-aware representations of single-cell transcriptomes while jointly modeling genes and transcripts using a novel generative architecture. It employs a generative autoregressive joint model over genes and their expression levels per cell across species. The transformer-based architecture includes a novel coupling between the gene and transcript heads, expression-aware multi-head self-attention, causal masking, and a count likelihood to capture transcript-level variability. More details on the architecture of TranscriptFormer can be found in the "Model Details" section. TranscriptFormer demonstrates robust zero-shot performance for cell type classification across species, disease state identification in human cells, and prediction of cell-type-specific transcription factors and gene-gene regulatory relationships. This work establishes a powerful framework for integrating and interrogating cellular diversity across species, and offers a foundation for in-silico experimentation with a generative single-cell atlas model.
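As a rough sketch of what this autoregressive joint factorization means (simplified notation, not the paper's exact formulation): for a cell with gene tokens $g_1, \dots, g_n$ and transcript counts $x_1, \dots, x_n$, the model factorizes the joint likelihood as

$$p(g_{1:n}, x_{1:n}) = \prod_{i=1}^{n} p(g_i \mid g_{<i}, x_{<i}) \, p(x_i \mid g_{\leq i}, x_{<i}),$$

with the gene head modeling the next gene token and the coupled transcript head modeling its count under a count likelihood.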
Example Dataset
The Macaque and Marmoset datasets used in this demonstration were retrieved from bgee.org.
Setup
Setup Google Colab
To run this quickstart using Google Colab, you will need to choose the 'A100' GPU runtime from the "Connect" dropdown menu in the upper-right corner of this notebook. Note that this runtime configuration is not available in the free Colab version. To access premium GPUs, you will need to purchase additional compute units. The current quickstart was tested in Colab Enterprise using the following runtime configuration:
- Machine type: a2-highgpu-1g
- GPU type: NVIDIA_TESLA_A100 x 1
- Data disk type: 100 GB Standard Disk (pd-standard)
It is also possible to run this quick start on a free-tier T4 instance by setting the inference batch size to 1 (see below in the inference section).
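Once connected, you can confirm which GPU the runtime has attached:
!nvidia-smi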
Setup Local Environment
Install Conda once & auto-restart
Below we use condacolab to bootstrap mamba, create an isolated environment, install matching PyTorch + TorchVision, and finally install the TranscriptFormer repo (which lives on GitHub).
!python --version
!pip install -q condacolab
import condacolab
condacolab.install()
import condacolab
condacolab.check()
!conda install -y pip
Clone the TranscriptFormer repository and install dependencies.
Below we clone the transcriptformer repo and install it from source. Alternatively, the package can be installed via its PyPI distribution (see the example after the install command below).
!git clone https://github.com/czi-ai/transcriptformer.git
%cd transcriptformer
Install Transcriptformer from source:
!uv pip install --system . torchvision
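To install from PyPI instead of from source (assuming the distribution is published under the name transcriptformer), you can run the command below. Note that this tutorial still uses the download_artifacts.py script and the conf/ directory from the cloned repo, so keep the clone around either way.
!uv pip install --system transcriptformer torchvision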
Import dependencies
First, we'll import some useful libraries.
%load_ext autoreload
%autoreload 2
import json
import logging
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from omegaconf import OmegaConf
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from transcriptformer.datasets import bgee_testis_evolution, download_all_embeddings
from transcriptformer.model.inference import run_inference
Download TranscriptFormer Model Weights and ESM2 embeddings for out-of-distribution species
We need to download the TranscriptFormer model weights. You can use the Python script download_artifacts.py to download the weights. We will be using the tf-exemplar model, trained on 5 model organisms (H. sapiens, M. musculus, D. rerio, D. melanogaster, C. elegans). The model checkpoints and associated artifacts are ~4.2GB in total for this model variant.
Since the species featured in this notebook, Rhesus macaque and Marmoset, are out-of-distribution for the model (it was not trained on scRNA-seq data from either species), we will also have to download the ESM2 gene embeddings. We provide a convenience function download_all_embeddings which will download and unzip all embeddings. The tar file is ~4.8GB and uncompressed is ~5GB.
Note: This download will take a few minutes.
!python ./download_artifacts.py tf-exemplar
download_all_embeddings(path="./embeddings")
!ls -la ./embeddings/all_embeddings/
Output:
total 5216076
drwxr-xr-x 2 999 999 4096 Apr 9 20:30 .
drwxr-xr-x 3 root root 4096 Apr 28 19:41 ..
-rw-r--r-- 1 999 999 212696350 Apr 6 20:36 caenorhabditis_elegans_gene.h5
-rw-r--r-- 1 999 999 235371868 Apr 6 20:36 callithrix_jacchus_gene.h5
-rw-r--r-- 1 999 999 324231520 Apr 6 20:49 danio_rerio_gene.h5
-rw-r--r-- 1 999 999 148914470 Apr 6 20:36 drosophila_melanogaster_gene.h5
-rw-r--r-- 1 999 999 180101052 Apr 6 20:36 gallus_gallus_gene.h5
-rw-r--r-- 1 999 999 232352100 Apr 6 20:36 gorilla_gorilla_gene.h5
-rw-r--r-- 1 999 999 247467360 Apr 6 20:36 heterocephalus_glaber_gene.h5
-rw-r--r-- 1 999 999 253848593 Apr 9 20:30 homo_sapiens_gene.h5
-rw-r--r-- 1 999 999 241397301 Apr 6 20:36 lytechinus_variegatus_gene.h5
-rw-r--r-- 1 999 999 232003346 Apr 6 20:36 macaca_mulatta_gene.h5
-rw-r--r-- 1 999 999 227994768 Apr 6 20:36 monodelphis_domestica_gene.h5
-rw-r--r-- 1 999 999 236434532 Apr 6 20:36 mus_musculus_gene.h5
-rw-r--r-- 1 999 999 185841076 Apr 6 20:36 ornithorhynchus_anatinus_gene.h5
-rw-r--r-- 1 999 999 219788872 Apr 6 20:36 oryctolagus_cuniculus_gene.h5
-rw-r--r-- 1 999 999 250848188 Apr 6 20:36 pan_troglodytes_gene.h5
-rw-r--r-- 1 999 999 110824302 Apr 6 20:36 petromyzon_marinus_gene.h5
-rw-r--r-- 1 999 999 55976107 Apr 6 20:36 plasmodium_falciparum_gene.h5
-rw-r--r-- 1 999 999 246155562 Apr 6 20:36 rattus_norvegicus_gene.h5
-rw-r--r-- 1 999 999 70036928 Apr 6 20:36 saccharomyces_cerevisiae_gene.h5
-rw-r--r-- 1 999 999 324815366 Apr 6 20:36 spongilla_lacustris_gene.h5
-rw-r--r-- 1 999 999 264162244 Apr 6 20:36 stylophora_pistillata_gene.h5
-rw-r--r-- 1 999 999 235211854 Apr 6 20:36 sus_scrofa_gene.h5
-rw-r--r-- 1 999 999 369537366 Apr 6 20:36 xenopus_laevis_gene.h5
-rw-r--r-- 1 999 999 235184246 Apr 6 20:36 xenopus_tropicalis_gene.h5
Next, we'll read the configuration files (both the model-variant and inference configs inside /conf) and merge them before running inference. Be sure to update the path to the pre-computed ESM2 embeddings for your organism of interest.
We'll also download the data. A utility function, bgee_testis_evolution, fetches the processed dataset; it takes the species name and the destination file path as arguments. In this tutorial we'll train a cell-type classifier on Macaque cells and transfer the labels to Marmoset. Neither species was used during model training, underscoring TranscriptFormer's ability to generalize across the tree of life.
# download the data
adata = bgee_testis_evolution(organism="macaque")
# Load the inference config
cfg = OmegaConf.load("./conf/inference_config.yaml")
logging.debug(OmegaConf.to_yaml(cfg))
# Change the checkpoint path to the model variant of interest, here tf_exemplar
cfg.model.checkpoint_path = "./checkpoints/tf_exemplar"
# Load the model specific configs
config_path = os.path.join(cfg.model.checkpoint_path, "config.json")
with open(config_path) as f:
    config_dict = json.load(f)
model_cfg = OmegaConf.create(config_dict)
# Merge the model-specific configs with the inference config
cfg = OmegaConf.merge(model_cfg, cfg)
# Set the checkpoint paths based on the unified checkpoint_path
cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt")
cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
# Change the precomputed ESM2 embeddings path
cfg.model.inference_config.pretrained_embedding = "/content/transcriptformer/embeddings/all_embeddings/macaca_mulatta_gene.h5"
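Before launching inference, an optional sanity check (a minimal sketch) confirms that the artifact paths set above actually resolve on disk:
# Optional: verify that the checkpoint, vocab, and ESM2 embedding paths exist
for path in [
    cfg.model.inference_config.load_checkpoint,
    cfg.model.data_config.aux_vocab_path,
    cfg.model.inference_config.pretrained_embedding,
]:
    assert os.path.exists(path), f"Missing artifact: {path}"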
We can then run inference on the data.
Note: While this tutorial was originally run on an instance with an A100 GPU, it is possible to run it on a T4 GPU by setting the inference batch size to 1.
# You can modify the inference batch size by specifying cfg.model.inference_config.batch_size.
# The default value is 8
# If you are running on an A100 or H100, you can set it to 16
# For a T4 instance, you can avoid running into memory errors by setting it to 1
#cfg.model.inference_config.batch_size = 1
# Set logging level to ERROR to reduce verbosity
logging.getLogger().setLevel(logging.ERROR)
adata_output = run_inference(cfg, data_files=[adata])
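run_inference returns an AnnData object with one embedding vector per cell stored in .obsm["embeddings"], which we'll use below for classification. A quick look at the output:
# Inspect the returned AnnData and the shape of the embedding matrix
print(adata_output)
print(adata_output.obsm["embeddings"].shape)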
We'll now embed the dataset to which we want to map cell types, in this case, the Marmoset dataset. First, we'll download the data, then run the model in inference mode to generate the embeddings.
adata_map = bgee_testis_evolution(organism="marmoset")
cfg.model.inference_config.pretrained_embedding = "/content/transcriptformer/embeddings/all_embeddings/callithrix_jacchus_gene.h5"
adata_map_output = run_inference(cfg, data_files=[adata_map])
Next, we'll train a classifier to predict cell-type labels. We'll train the classifier on the first dataset (Macaque) and then apply it to the second dataset (Marmoset) to infer those labels.
pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("knn", KNeighborsClassifier()),
    ]
)
embeddings, labels = adata_output.obsm["embeddings"], adata_output.obs["cell_type"]
embeddings_map = adata_map_output.obsm["embeddings"]
pipeline.fit(embeddings, labels)
classes = pipeline.predict(embeddings_map)
The k-NN classifier is trained on the Macaque data and then used to predict cell types in the Marmoset data. The predictions are saved in the predicted_cell_type column of the obs attribute of the AnnData object. The UMAP layout is computed from the transcriptformer_embedding representation.
adata_map.obs['predicted_cell_type'] = pd.Categorical(classes)
adata_map.obsm["transcriptformer_embedding"] = embeddings_map
sc.pp.neighbors(adata_map, use_rep="transcriptformer_embedding")
sc.tl.umap(adata_map)
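Because the Marmoset dataset ships with its own curated annotations in the cell_type column, we can quantify the transfer before plotting. A simple overall accuracy (a minimal check; the confusion matrix below gives the per-class picture):
from sklearn.metrics import accuracy_score

# Fraction of Marmoset cells whose predicted label matches the curated annotation
acc = accuracy_score(adata_map.obs["cell_type"], adata_map.obs["predicted_cell_type"])
print(f"Cross-species label-transfer accuracy: {acc:.3f}")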
We can now plot the predicted cell-type labels for the second dataset on the UMAP computed above from the TranscriptFormer embeddings, side by side with the curated labels. The k-NN classifier appears to have transferred the cell types accurately from one species to the other.
sc.pl.umap(
    adata_map,
    color=['predicted_cell_type', 'cell_type'],
    ncols=2,
    frameon=False,
    wspace=0.3,  # Add space between plots
    title=['Predicted Cell Type', 'True Cell Type'],  # Add descriptive titles
)

We can also visualize this by plotting a confusion matrix that compares the predicted and true cell types.
true_labels = adata_map.obs['cell_type']
pred_labels = adata_map.obs['predicted_cell_type']
# Get unique labels from both true and predicted
all_labels = np.unique(np.concatenate([true_labels.cat.categories, pred_labels.cat.categories]))
# Compute the confusion matrix
cm = confusion_matrix(true_labels, pred_labels, labels=all_labels)
# Normalize the confusion matrix by row (true labels); guard against rows with
# zero true instances (labels that appear only in predictions) to avoid NaNs
row_sums = cm.sum(axis=1, keepdims=True)
cm_normalized = cm.astype('float') / np.maximum(row_sums, 1)
# Create a heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(cm_normalized, annot=True, fmt='.1f', cmap='Blues',
            xticklabels=all_labels, yticklabels=all_labels)
plt.xlabel('Predicted Cell Type')
plt.ylabel('True Cell Type')
plt.title('Confusion Matrix: True vs Predicted Cell Types')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
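For a per-class numeric summary to complement the heatmap, scikit-learn's classification_report prints precision, recall, and F1 for each cell type:
from sklearn.metrics import classification_report

# Per-class precision/recall/F1; zero_division=0 silences warnings for labels
# that appear only in the predictions or only in the true annotations
print(classification_report(true_labels, pred_labels, labels=all_labels, zero_division=0))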

Finally, we can plot the embeddings of both datasets within the same UMAP space.
# Concatenate the AnnData objects
adata_joint = ad.concat(
    [adata_output, adata_map_output],
    label="dataset",
    keys=["macaque", "marmoset"],
)
# Perform UMAP and plot
sc.pp.neighbors(adata_joint, use_rep="embeddings")
sc.tl.umap(adata_joint)
sc.pl.umap(
    adata_joint,
    color=["cell_type", "dataset"],
    ncols=2,
    frameon=False,
    wspace=0.3,  # Add space between plots
    title=['Cell Type', 'Dataset'],  # Add descriptive titles
)

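If you want to persist the annotated objects for downstream analysis, AnnData's write_h5ad will serialize them (the output filenames here are just examples):
# Save the label-transferred Marmoset object and the joint embedding space
adata_map.write_h5ad("marmoset_predicted.h5ad")
adata_joint.write_h5ad("macaque_marmoset_joint.h5ad")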