Quickstart: GREmLN

Estimated time to complete: 10 minutes

Google Colab Note: It is strongly recommended to run this notebook with an A100 GPU, which is only included with Google Colab Pro or Enterprise paid services. Alternatively, a "pay as you go" option is available to purchase premium GPUs. See Colab Service Plans for details.

Learning Goals

  1. Understand what makes GREmLN novel, and what inputs and dependencies are required to use it effectively.
  2. Learn how to generate cell-level and gene-level embeddings from single-cell RNA-seq data using GREmLN's inference API.
  3. Appreciate the unique biological insights captured by GREmLN embeddings, including how they reflect regulatory structure, expression dynamics, and functional roles of genes and cells.

Prerequisites

Running GREmLN requires an advanced GPU (e.g., an NVIDIA A100) and a Python environment with the following packages installed:

numpy
pandas
scanpy
scipy
scikit-learn
pyarrow
loralib
viper-in-python
python-igraph
leidenalg
louvain
torch
torch-geometric
lightning
flash-attn

For detailed setup instructions, including environment configuration and package installation tips, see the Setup section of this tutorial.

Introduction

GREmLN (Gene Regulatory Embedding-based Large Neural model) is a foundation model designed to process single-cell RNA sequencing (scRNA-seq) data using the structure of gene regulatory networks (GRNs). Unlike conventional transformer models that treat gene expression data as a sequence of tokens, GREmLN integrates biologically meaningful relationships by embedding GRN structure directly into the model's attention mechanism. This allows it to capture long-range regulatory dependencies and produce high-quality cell and gene embeddings that are useful for tasks such as cell type classification and regulatory network inference.

Inputs and Outputs

(Inputs) There are two inputs to GREmLN:

  1. A preprocessed scRNA-seq gene expression matrix (cells × genes)
  2. A gene regulatory network (GRN), constructed for each cell type or cell lineage

(Outputs) For each cell, GREmLN produces a G x 512 embedding matrix, where G is the number of expressed genes in that cell and:

  1. Row i represents a 512-dimensional embedding vector for gene i.
  2. Column j represents a learned hidden feature dimension shared across genes.

Demo Datasets

To help you get started, we provide two small datasets along with their corresponding GRNs. These are designed to demonstrate use of GREmLN's inference API and illustrate key properties of its embeddings.

1. Human Immune Cells

A sample of 2,473 cells from both human bone marrow and peripheral blood, covering 9 distinct immune cell types. This dataset was provided by scGPT and was not used during GREmLN's pretraining. The relevant files are:

  • GREmLN_Tutorial/data/human_immune_cells.h5ad -- preprocessed gene expression matrix

  • GREmLN_Tutorial/networks/{cell_type}/network.tsv -- a GRN for each immune cell type in the dataset, inferred using ARACNe

2. Epithelial Cells

A sample of 1,000 epithelial cells from the CELLxGENE project. These data were explicitly held out during GREmLN's pretraining to allow for evaluation on unseen cell types. The relevant files are:

  • GREmLN_Tutorial/data/epithelial_cells.h5ad -- preprocessed gene expression matrix

  • GREmLN_Tutorial/networks/epithelial_cell/network.tsv -- a GRN for epithelial cells, inferred using ARACNe

Setup

GREmLN requires an NVIDIA GPU with compatible CUDA drivers (e.g., CUDA 11 or 12) for optimal performance. Its dependencies include a stack of high-performance libraries that are sensitive to your system's CUDA version, GPU hardware, and operating system (especially differences between Linux, macOS, and Windows).

To avoid compatibility issues, we recommend installing critical dependencies such as PyTorch, PyTorch Lightning, and FlashAttention manually, following instructions specific to your hardware and OS.

GREmLN's setup.py assumes that torch, torch-geometric, and flash-attn have already been installed and are compatible with your system. Once these are installed, clone the GREmLN repository, which contains the scGraphLLM package, and install it by running pip install . from the repository root.

Setup Google Colab


If you are using Google Colab with an NVIDIA A100 GPU and CUDA 12.1, you can install the required dependencies with the following commands:

Note: After installation, you may need to restart the runtime for changes to take effect.

Note: This tutorial requires an Ampere-class (or newer) GPU such as the A100, because FlashAttention does not currently support Turing-class GPUs such as the T4.
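Before installing flash-attn, you may want to confirm that the runtime actually has an Ampere-or-newer GPU. The sketch below assumes that "Ampere or newer" means CUDA compute capability 8.0 or higher. For reference, the A100 reports 8.0 and the T4 reports 7.5. The helper supports_flash_attn is hypothetical. torch.cuda.is_available, torch.cuda.get_device_name, and torch.cuda.get_device_capability are standard PyTorch calls.

```python
def supports_flash_attn(major: int, minor: int) -> bool:
    """Hypothetical check: treat compute capability >= 8.0 (Ampere or newer) as supported."""
    return (major, minor) >= (8, 0)


def report_gpu() -> None:
    """Print the detected GPU and whether it meets the Ampere-or-newer assumption."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed yet; install it first, then re-run this check.")
        return
    if not torch.cuda.is_available():
        print("No CUDA GPU detected; switch the Colab runtime to an A100 GPU.")
        return
    # Compute capability of the current CUDA device, e.g. (8, 0) for an A100
    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    status = "OK" if supports_flash_attn(major, minor) else "NOT supported by flash-attn"
    print(f"{name}: compute capability {major}.{minor} -> {status}")


report_gpu()
```

On a T4 runtime, this reports compute capability 7.5 and flags it as unsupported. That outcome matches the note above.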

# Install PyTorch and torchvision etc. for CUDA 12.1
!pip install torch=='2.4.1+cu121' torchvision=='0.19.1+cu121' torchaudio=='2.4.1+cu121' --index-url https://download.pytorch.org/whl/cu121

# Install PyG packages (torch-geometric, torch-scatter etc) for CUDA 12.1
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.1+cu121.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.1+cu121.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.1+cu121.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.1+cu121.html
!pip install torch-geometric
!pip install lightning

# Install flash-attn from PyPI if available for CUDA 12.1
!pip install flash-attn --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu121

You can clone the repo for GREmLN here: https://github.com/czi-ai/GREmLN.git

!git clone https://github.com/czi-ai/GREmLN.git
# Install scGraphLLM with pip via setup.py once you have downloaded the repository
%cd GREmLN
!pip install .

Download the model weights, ARACNe networks, and demo data for the tutorial via gdown (you can also find those assets here).

!pip install gdown #gdown will enable us to download the tutorial assets from google drive
# Return to /content (not needed if the runtime was restarted, since a restart resets the working directory)
%cd /content
Output:

/content
!ls
Output:

GREmLN	sample_data
!gdown --folder https://drive.google.com/drive/folders/1cMR9HoAC22i6sKSWgfQUEQRf0UP_w3_m?usp=sharing
Output:

Retrieving folder contents
Processing file 19RMsAKriM-rY6Y4u4SXQ_9nH3RASEVqT model.ckpt
Processing file 1SBy8C0BOeQV5C68V63tDXeaQW6Y3zLGc .DS_Store
Retrieving folder 1xgN0H3EKagBNm2_J1LpVzOCW-6sIr5O0 networks
Retrieving folder 1lKZwd6S2qnG6EPLMc-AnQf1XKKLjczI6 cd20_b_cells
Processing file 1HRm4S_HaXdDjyTt-JLz71-Ha2c4quaA6 network.tsv
Retrieving folder 17XelZp5_htLx58zaPqGvlu-FQSYN7z-2 erythrocytes
Processing file 1wRfvbn_v3NiyVetb-mO9ytXM3pAYBT5P network.tsv
Retrieving folder 191saUqqZKkX6Eyw35F3332kQdV0pWc8I cd16_monocytes
Processing file 17AQh02OmCTo0_hJp6eY_1UM6IEn7zVBc network.tsv
Retrieving folder 1acUBPLtTlZlCo_023yakTM4qISMu0es6 epithelial_cell
Processing file 1JTV_sEQvIEpCiBZBMRWwVMiiWhmp57lc network.tsv
Retrieving folder 1JB_EahNh3NkieVeMHtWnt-9AQP7psar7 nkt_cells
Processing file 1BmSvnG9ejWfFboYkBmcYscrNk6Yhgk8a network.tsv
Retrieving folder 1eQhWN6_jNx5TTQooSt4SoaNfTZHTyOCl cd4_t_cells
Processing file 1AeQHOfpHlT0sA6IlnN7uv7lx2Y_YmMn5 network.tsv
Retrieving folder 17f52EdtuA1n8L1CbPwWKx3hdSIZp7xGa monocyte-derived_dendritic_cells
Processing file 1enQZiDGvDJ403xsTRG6IwYTgqnUKvJib network.tsv
Retrieving folder 1J39eDyuJg-aJ8AOQ5Ia3g0sQYAr7z1Rm nk_cells
Processing file 1xjIcMDB_iyAxdabu40PTWN6dOe97JJ0e network.tsv
Retrieving folder 1eV294cXL247a7x8LDzLsz1NZcN-J7Cup cd8_t_cells
Processing file 1La5qAJcz1jAuPY_l3wnfjSMm2awwd3PG network.tsv
Retrieving folder 12aLFifTSUoAMeAHm4k-OlKc3hWU_0gIJ cd14_monocytes
Processing file 1eBb8LMwDNw8xRwFB7s75EzbtjZqlV-_o network.tsv
Retrieving folder 1mh59dKR03FD3yJ9UY0NaWeq5D9tWFgll data
Processing file 1ijOpjqw0hMDPdeac0sPcMJZHSoXB2VGn human_immune_cells.h5ad
Processing file 1T6CQCUf7sY28qfHpXDK7yzKRXD9i78Ou epithelial_cells.h5ad
Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From (original): https://drive.google.com/uc?id=19RMsAKriM-rY6Y4u4SXQ_9nH3RASEVqT
From (redirected): https://drive.google.com/uc?id=19RMsAKriM-rY6Y4u4SXQ_9nH3RASEVqT&confirm=t&uuid=a18cb287-bb55-4b68-acc0-72a72915b19f
To: /content/GREmLN_Tutorial/model.ckpt
100% 120M/120M [00:01<00:00, 73.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1SBy8C0BOeQV5C68V63tDXeaQW6Y3zLGc
To: /content/GREmLN_Tutorial/.DS_Store
100% 6.15k/6.15k [00:00<00:00, 30.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=1HRm4S_HaXdDjyTt-JLz71-Ha2c4quaA6
To: /content/GREmLN_Tutorial/networks/cd20_b_cells/network.tsv
100% 67.2k/67.2k [00:00<00:00, 2.56MB/s]
Downloading...
From: https://drive.google.com/uc?id=1wRfvbn_v3NiyVetb-mO9ytXM3pAYBT5P
To: /content/GREmLN_Tutorial/networks/erythrocytes/network.tsv
100% 1.22M/1.22M [00:00<00:00, 10.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=17AQh02OmCTo0_hJp6eY_1UM6IEn7zVBc
To: /content/GREmLN_Tutorial/networks/cd16_monocytes/network.tsv
100% 78.7k/78.7k [00:00<00:00, 2.26MB/s]
Downloading...
From: https://drive.google.com/uc?id=1JTV_sEQvIEpCiBZBMRWwVMiiWhmp57lc
To: /content/GREmLN_Tutorial/networks/epithelial_cell/network.tsv
100% 10.8M/10.8M [00:00<00:00, 54.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1BmSvnG9ejWfFboYkBmcYscrNk6Yhgk8a
To: /content/GREmLN_Tutorial/networks/nkt_cells/network.tsv
100% 150k/150k [00:00<00:00, 3.67MB/s]
Downloading...
From: https://drive.google.com/uc?id=1AeQHOfpHlT0sA6IlnN7uv7lx2Y_YmMn5
To: /content/GREmLN_Tutorial/networks/cd4_t_cells/network.tsv
100% 88.1k/88.1k [00:00<00:00, 2.98MB/s]
Downloading...
From: https://drive.google.com/uc?id=1enQZiDGvDJ403xsTRG6IwYTgqnUKvJib
To: /content/GREmLN_Tutorial/networks/monocyte-derived_dendritic_cells/network.tsv
100% 334k/334k [00:00<00:00, 4.81MB/s]
Downloading...
From: https://drive.google.com/uc?id=1xjIcMDB_iyAxdabu40PTWN6dOe97JJ0e
To: /content/GREmLN_Tutorial/networks/nk_cells/network.tsv
100% 23.8k/23.8k [00:00<00:00, 32.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1La5qAJcz1jAuPY_l3wnfjSMm2awwd3PG
To: /content/GREmLN_Tutorial/networks/cd8_t_cells/network.tsv
100% 203k/203k [00:00<00:00, 3.95MB/s]
Downloading...
From: https://drive.google.com/uc?id=1eBb8LMwDNw8xRwFB7s75EzbtjZqlV-_o
To: /content/GREmLN_Tutorial/networks/cd14_monocytes/network.tsv
100% 120k/120k [00:00<00:00, 3.45MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1ijOpjqw0hMDPdeac0sPcMJZHSoXB2VGn
From (redirected): https://drive.google.com/uc?id=1ijOpjqw0hMDPdeac0sPcMJZHSoXB2VGn&confirm=t&uuid=eec30c26-b7de-40da-89e2-6318ff0a083a
To: /content/GREmLN_Tutorial/data/human_immune_cells.h5ad
100% 179M/179M [00:01<00:00, 90.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1T6CQCUf7sY28qfHpXDK7yzKRXD9i78Ou
To: /content/GREmLN_Tutorial/data/epithelial_cells.h5ad
100% 34.5M/34.5M [00:00<00:00, 75.8MB/s]
Download completed
# Change the working directory to GREmLN_Tutorial, which contains the demo data
import os
os.chdir("/content/GREmLN_Tutorial")
print("Current working directory:", os.getcwd())
Output:

Current working directory: /content/GREmLN_Tutorial

Run Model Inference

To run model inference and generate single-cell embeddings with our trained model:

  1. Preprocess your data

Normalize raw counts to counts per million (CPM), then apply a log1p transformation, i.e., log(1 + CPM).

  2. Generate a gene regulatory network (GRN)

You can use any GRN inference method, but we recommend ARACNe for consistent and biologically meaningful results. Note: networks should be cell type specific.

  3. Prepare inputs with our API

Use the provided RegulatoryNetwork, GraphTokenizer, and InferenceDataset classes to convert your processed data and GRN into the model's required input format (demonstrated below).

  4. Perform inference with our API

Generate cell or gene embeddings using the inference module (demonstrated below).
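The normalization in step 1 can be sketched in a few lines of NumPy. This minimal illustration assumes a dense cells x genes array of raw counts. For an AnnData object, scanpy's sc.pp.normalize_total(adata, target_sum=1e6) followed by sc.pp.log1p(adata) performs the same transformation.

```python
import numpy as np


def cpm_log1p(counts: np.ndarray) -> np.ndarray:
    """Normalize raw counts (cells x genes) to log(1 + CPM)."""
    counts = np.asarray(counts, dtype=np.float64)
    # Total counts per cell (row sums); guard against empty cells to avoid dividing by zero
    library_size = counts.sum(axis=1, keepdims=True)
    library_size[library_size == 0] = 1.0
    # Scale each cell so that its counts sum to one million (counts per million)
    cpm = counts / library_size * 1e6
    # Natural log with a pseudocount of 1, i.e., log(1 + CPM)
    return np.log1p(cpm)


raw = np.array([[1, 3, 0], [10, 0, 10]])
print(cpm_log1p(raw).round(3))
```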

The scGraphLLM package can be used to prepare inputs and generate embeddings as follows:

import scanpy as sc
from scGraphLLM import RegulatoryNetwork, GraphTokenizer, InferenceDataset
from scGraphLLM.vocab import GeneVocab
from scGraphLLM.models import GDTransformer
from scGraphLLM.config import graph_kernel_attn_3L_4096
from scGraphLLM.inference import get_cell_embeddings

# Load your preprocessed single-cell expression data (.h5ad) as a cells x genes DataFrame
data = sc.read_h5ad("your_data.h5ad").to_df()

# Load your cell type specific regulatory network
network = RegulatoryNetwork.from_csv("your_network_file.tsv", sep="\t")

# Load trained model checkpoint
model = GDTransformer.load_from_checkpoint("path_to_model.ckpt", config=graph_kernel_attn_3L_4096)

# Load default gene vocabulary
vocab = GeneVocab.load_default()

# Create tokenizer from vocab and network
tokenizer = GraphTokenizer(vocab=vocab, network=network)

# Create dataset for inference
dataset = InferenceDataset(expression=data, tokenizer=tokenizer)

# Run inference and get pooled cell embeddings (one row per cell, 512 columns)
embeddings_df = get_cell_embeddings(dataset, model, vocab)
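The demo networks are keyed by Ensembl gene IDs, and the model has a fixed gene vocabulary. Before running inference, it may therefore be worth checking how many of your dataset's genes appear in both. The helper below is hypothetical and not part of scGraphLLM. It assumes, as a working hypothesis, that only genes present in both the vocabulary and the network contribute to the tokenized input. It also assumes you can obtain each gene set as a plain collection of IDs; how to extract these from GeneVocab and RegulatoryNetwork is not shown in this tutorial. The toy IDs below are placeholders, not real Ensembl IDs.

```python
from typing import Iterable


def gene_overlap_report(dataset_genes: Iterable[str],
                        vocab_genes: Iterable[str],
                        network_genes: Iterable[str]) -> dict:
    """Count dataset genes that also appear in the vocabulary and in the network."""
    data = set(dataset_genes)
    in_vocab = data & set(vocab_genes)
    in_both = in_vocab & set(network_genes)
    return {
        "dataset_genes": len(data),
        "in_vocab": len(in_vocab),
        "in_vocab_and_network": len(in_both),
        "fraction_usable": len(in_both) / len(data) if data else 0.0,
    }


report = gene_overlap_report(
    dataset_genes=["ENSG01", "ENSG02", "ENSG03", "ENSG04"],
    vocab_genes=["ENSG01", "ENSG02", "ENSG03"],
    network_genes=["ENSG02", "ENSG03", "ENSG99"],
)
print(report)
```

With the snippet above, the dataset genes would be data.columns. A low usable fraction often points to a gene identifier mismatch, such as gene symbols versus Ensembl IDs.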

Generating Cell Embeddings for Human Immune Cells with GREmLN

Estimated Run Time: 5 Minutes

GREmLN cell embeddings reflect important cell properties and functions by summarizing the gene expression landscape of a cell in the context of its underlying regulatory network. The embeddings serve as compact, biologically informed representations that can be used for downstream tasks such as cell type classification, trajectory inference, or response prediction.

In this example, we apply GREmLN to human immune cells to extract embeddings that distinguish between different immune cell states and subtypes.
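The sketch below illustrates the cell type classification use case mentioned above by fitting a simple nearest-centroid classifier on cell embeddings. It is not part of GREmLN's API. It assumes a pandas DataFrame of embeddings with one row per cell (such as the output of get_cell_embeddings) and a Series of labels on the same index. Cosine similarity to each class centroid is a deliberately simple stand-in for a real classifier.

```python
import numpy as np
import pandas as pd


def _unit_rows(x: np.ndarray) -> np.ndarray:
    """L2-normalize each row so that dot products become cosine similarities."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.where(norms == 0, 1.0, norms)


def fit_centroids(emb: pd.DataFrame, labels: pd.Series) -> pd.DataFrame:
    """Average the embeddings of each label to get one centroid per class."""
    return emb.groupby(labels.loc[emb.index]).mean()


def predict(emb: pd.DataFrame, centroids: pd.DataFrame) -> pd.Series:
    """Assign each cell to the centroid with the highest cosine similarity."""
    sims = _unit_rows(emb.to_numpy()) @ _unit_rows(centroids.to_numpy()).T
    return pd.Series(centroids.index.to_numpy()[sims.argmax(axis=1)], index=emb.index)


# Toy embeddings: two well-separated groups of 20 "cells" in 8 dimensions
rng = np.random.default_rng(0)
emb = pd.DataFrame(np.vstack([rng.normal(5.0, 1.0, (20, 8)),
                              rng.normal(-5.0, 1.0, (20, 8))]))
labels = pd.Series(["cell_type_A"] * 20 + ["cell_type_B"] * 20, index=emb.index)
centroids = fit_centroids(emb, labels)
accuracy = (predict(emb, centroids) == labels).mean()
print(f"training accuracy: {accuracy:.2f}")
```

This toy example scores the cells it was fit on. For a meaningful estimate on real embeddings, fit the centroids on one subset of cells and evaluate on held-out cells.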

# Load Human Immune Cells
import scanpy as sc
immune_cells = sc.read_h5ad("data/human_immune_cells.h5ad")
immune_cells
Output:

AnnData object with n_obs × n_vars = 2473 × 11971
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'final_annotation', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'sample_id', 'cluster', 'set'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    obsm: 'X_pca'
    layers: 'counts'
# Load Gene Vocabulary used during model pretraining
from scGraphLLM.vocab import GeneVocab
vocab = GeneVocab.load_default()
vocab
Output:

GeneVocab with 19,247 genes
Special tokens present: ['<CLS>', '<PAD>', '<MASK>']
# Load pretrained GREmLN model
from scGraphLLM.models import GDTransformer
from scGraphLLM.config import graph_kernel_attn_3L_4096
model = GDTransformer.load_from_checkpoint("model.ckpt", config=graph_kernel_attn_3L_4096)
# Get GREmLN embeddings for all immune cells (3 minutes)
from scGraphLLM import RegulatoryNetwork, GraphTokenizer, InferenceDataset
from scGraphLLM.inference import get_cell_embeddings

emb_list = []
for cell_type in immune_cells.obs["cell_type"].unique():
    # select cells in cell_type
    cells = immune_cells[immune_cells.obs["cell_type"] == cell_type]
    print(f"Producing GREmLN embeddings for {len(cells):,} {cell_type} cells...")

    # load cell type network
    network = RegulatoryNetwork.from_csv(f"networks/{cell_type}/network.tsv", sep="\t")

    # Initialize Graph Tokenizer to tokenize single cell data
    tokenizer = GraphTokenizer(vocab=vocab, network=network)

    # initialize dataset for inference
    dataset = InferenceDataset(expression=cells.to_df(), tokenizer=tokenizer)

    # get cell embeddings via forward pass using `get_cell_embeddings` from inference api
    x = get_cell_embeddings(dataset, model, vocab, cls_policy="include", batch_size=64)

    # join with metadata
    emb = sc.AnnData(x.values, obs=cells.obs.loc[x.index])
    emb_list.append(emb)
Output:

Producing GREmLN embeddings for 291 cd14_monocytes cells...
Cache Directory: None
Observation Count: 291
Forward Pass: 100%|██████████| 5/5 [00:22<00:00,  4.56s/it]
Producing GREmLN embeddings for 253 cd16_monocytes cells...
Cache Directory: None
Observation Count: 253
Forward Pass: 100%|██████████| 4/4 [00:19<00:00,  4.80s/it]
Producing GREmLN embeddings for 300 cd20_b_cells cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:22<00:00,  4.50s/it]
Producing GREmLN embeddings for 300 cd4_t_cells cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:22<00:00,  4.52s/it]
Producing GREmLN embeddings for 300 cd8_t_cells cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:23<00:00,  4.64s/it]
Producing GREmLN embeddings for 300 erythrocytes cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:24<00:00,  5.00s/it]
Producing GREmLN embeddings for 129 monocyte-derived_dendritic_cells cells...
Cache Directory: None
Observation Count: 129
Forward Pass: 100%|██████████| 3/3 [00:10<00:00,  3.49s/it]
Producing GREmLN embeddings for 300 nk_cells cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:22<00:00,  4.47s/it]
Producing GREmLN embeddings for 300 nkt_cells cells...
Cache Directory: None
Observation Count: 300
Forward Pass: 100%|██████████| 5/5 [00:22<00:00,  4.57s/it]
# concatenate embeddings from all cell types
from scGraphLLM.preprocess import concatenate_partitions
cell_emb = concatenate_partitions(emb_list, require_matching_metadata=True)
# project embeddings onto principal components
sc.tl.pca(cell_emb, svd_solver='arpack')
# visualizing GREmLN Embeddings for Human Immune Cells
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 8))
sc.pl.pca(cell_emb, color="cell_type", ax=ax, size=120, alpha=0.6, title="GREmLN Embeddings for Immune Cells")
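sc.tl.pca projects each embedding onto the directions of largest variance. The minimal NumPy sketch below shows what that projection does, using mean-centered data and a singular value decomposition. It illustrates the technique only. It is not scanpy's implementation, which with svd_solver='arpack' computes only the leading components with an iterative solver.

```python
import numpy as np


def pca_scores(x: np.ndarray, n_components: int = 2) -> tuple[np.ndarray, np.ndarray]:
    """Return PC scores and explained-variance ratios via SVD of the centered data."""
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    scores = centered @ vt[:n_components].T
    explained = (s ** 2) / np.sum(s ** 2)
    return scores, explained[:n_components]


rng = np.random.default_rng(0)
emb = rng.normal(size=(100, 512))  # stand-in for a cells x 512 embedding matrix
emb[:50, 0] += 10.0                # shift half of the "cells" strongly along one axis
scores, ratio = pca_scores(emb, n_components=2)
print(scores.shape, ratio.round(3))
```

In the toy data, PC1 picks up the injected shift, so the two halves land on opposite sides of zero along PC1. Distinct immune cell types separate in the plot above in the same way.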

Generating Gene Embeddings for Epithelial Cells with GREmLN

Estimated Run Time: 5 Minutes

GREmLN gene embeddings reflect important properties of a gene's regulatory role and biological function by integrating expression patterns with graph-structured information from gene regulatory networks. These embeddings encode functional and regulatory distinctions, for example separating broadly active housekeeping genes from genes with specialized, context-specific roles.

In this example, we apply GREmLN to epithelial cells to reveal how gene regulatory architecture shapes the organization of gene activity in this specific cellular context.
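One way to explore gene embeddings is to look up the genes whose embeddings are most similar to a gene of interest. The sketch below does this with cosine similarity on a genes x 512 DataFrame, such as the output of get_gene_embeddings. The helper name and the choice of cosine similarity are assumptions made for illustration and are not part of GREmLN's API.

```python
import numpy as np
import pandas as pd


def most_similar_genes(gene_emb: pd.DataFrame, gene: str, k: int = 5) -> pd.Series:
    """Return the k genes whose embeddings have the highest cosine similarity to `gene`."""
    x = gene_emb.to_numpy(dtype=float)
    norms = np.linalg.norm(x, axis=1)
    norms[norms == 0] = 1.0
    x = x / norms[:, None]
    query = x[gene_emb.index.get_loc(gene)]
    sims = pd.Series(x @ query, index=gene_emb.index)
    # Exclude the query gene itself, then keep the top-k matches
    return sims.drop(gene).nlargest(k)


toy = pd.DataFrame(
    [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    index=["GENE_A", "GENE_B", "GENE_C", "GENE_D"],
)
print(most_similar_genes(toy, "GENE_A", k=2))
```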

# Load epithelial cells
epithelial_cells = sc.read_h5ad("data/epithelial_cells.h5ad")
epithelial_cells
Output:

AnnData object with n_obs × n_vars = 1000 × 19221
    obs: 'index', 'soma_joinid', 'dataset_id', 'assay', 'assay_ontology_term_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'observation_joinid', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_ontology_term_id', 'tissue_type', 'tissue_general', 'tissue_general_ontology_term_id', 'raw_sum', 'nnz', 'raw_mean_nnz', 'raw_variance_nnz', 'n_measured_vars', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_counts'
    var: 'index', 'soma_joinid', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs', 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'log1p'
# Load regulatory network for epithelial cells
epithelial_network = RegulatoryNetwork.from_csv("networks/epithelial_cell/network.tsv", sep="\t")
epithelial_network
Output:

RegulatoryNetwork with 183,247 edges between 14,628 genes.
Number of regulons: 1,926
Median targets per regulon: 54
Top regulators (by out-degree): ENSG00000164104 (981), ENSG00000124766 (670), ENSG00000059728 (594)
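A summary like the one printed above can be computed from a plain edge list with pandas: number of regulons, median targets per regulon, and top regulators by out-degree. This sketch assumes a DataFrame with one row per regulator-target edge. The column names regulator and target are hypothetical; the actual columns in network.tsv and RegulatoryNetwork.df may differ, and this tutorial refers to the regulator column as reg_name.

```python
import pandas as pd


def summarize_regulons(edges: pd.DataFrame, reg_col: str = "regulator",
                       tgt_col: str = "target", top_n: int = 3) -> dict:
    """Summarize an edge list with one row per regulator-target edge."""
    # A regulon is the set of targets of one regulator, so its size is that regulator's out-degree
    regulon_size = edges.groupby(reg_col).size().sort_values(ascending=False)
    genes = pd.unique(edges[[reg_col, tgt_col]].to_numpy().ravel())
    return {
        "n_edges": len(edges),
        "n_genes": len(genes),
        "n_regulons": len(regulon_size),
        "median_targets": float(regulon_size.median()),
        "top_regulators": regulon_size.head(top_n).to_dict(),
    }


edges = pd.DataFrame({
    "regulator": ["TF1", "TF1", "TF1", "TF2", "TF2", "TF3"],
    "target":    ["G1", "G2", "G3", "G1", "G4", "G5"],
})
print(summarize_regulons(edges))
```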
# Initialize inference dataset for epithelial cells and get gene embeddings (2 minutes)
epithelial_dataset = InferenceDataset(
    expression=epithelial_cells.to_df(),
    tokenizer=GraphTokenizer(vocab=vocab, network=epithelial_network)
)

from scGraphLLM.inference import get_gene_embeddings
x_gene = get_gene_embeddings(epithelial_dataset, model=model, vocab=vocab)

gene_emb = sc.AnnData(x_gene.values, obs=epithelial_cells.var.loc[x_gene.index])
gene_emb
Output:

Cache Directory: None
Observation Count: 1,000
Forward Pass: 100%|██████████| 4/4 [02:28<00:00, 37.23s/it]

AnnData object with n_obs × n_vars = 13706 × 512
    obs: 'index', 'soma_joinid', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs', 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
# Project gene embeddings onto principal components
sc.tl.pca(gene_emb, svd_solver='arpack')
# analyze gene regulon size
import numpy as np
import pandas as pd

regulon_size = epithelial_network.df.groupby(epithelial_network.reg_name).size().rename("regulon_size")
gene_emb.obs = gene_emb.obs.join(regulon_size)

gene_emb.obs["regulon_label"] = pd.cut(
    gene_emb.obs["regulon_size"],
    bins=[1, 99, np.inf],
    labels=["<100 Targets", "100+ Targets"],
    include_lowest=True
)
# Plot gene embeddings colored by regulon size
fig, ax = plt.subplots(figsize=(10, 8))
sc.pl.pca(gene_emb, color="regulon_label", ax=ax, size=80, alpha=0.6, title="GREmLN Gene Embeddings for Epithelial Cells")

As shown above, gene embeddings reflect regulatory influence. Genes with large regulons (100 or more regulatory targets) are distinct from genes with small regulons (fewer than 100 targets) and from genes with no regulatory role at all (gray).
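The regulon labels come from pd.cut with bins=[1, 99, np.inf] and include_lowest=True. Regulons with 1 to 99 targets are labeled "<100 Targets", and regulons with 100 or more targets are labeled "100+ Targets". Genes with no regulon have a missing regulon_size after the join, so they stay unlabeled and are drawn in gray. The small check below runs the same binning on made-up sizes to confirm these boundaries.

```python
import numpy as np
import pandas as pd

# Made-up regulon sizes, including the bin edges and a gene with no regulon (NaN)
sizes = pd.Series([1, 54, 99, 100, 981, np.nan])
labels = pd.cut(
    sizes,
    bins=[1, 99, np.inf],
    labels=["<100 Targets", "100+ Targets"],
    include_lowest=True,
)
print(pd.DataFrame({"regulon_size": sizes, "regulon_label": labels}))
```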

# assign mean counts to quartiles
gene_emb.obs["mean_counts_quartile"] = pd.qcut(
    gene_emb.obs["mean_counts"],
    q=4, labels=["Q1", "Q2", "Q3", "Q4"]
)
# plot gene embeddings colored by mean UMI counts quartile
fig, ax = plt.subplots(figsize=(10, 8))
sc.pl.pca(
    adata=sc.pp.subsample(gene_emb, n_obs=5000, copy=True),
    color="mean_counts_quartile",
    ax=ax, size=60, alpha=0.4,
    palette=[(.0, .4, .8), (.8, 0, .1), (1.0, 0.6, 0.0), (.2, .5, 0.3)],
    title="GREmLN Gene Embeddings for Epithelial Cells",
    show=False
)
ax.get_legend().set_title("Mean UMI Count Quartile")

Gene embeddings also convey biological function. As shown above, the first principal component separates genes by their mean UMI count, a rough measure of how commonly a gene is expressed. Commonly expressed genes tend to have housekeeping roles, while rarely expressed genes tend to define a cell's specialized, niche-specific state.
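The relationship between PC1 and mean UMI counts can also be checked numerically rather than by eye. This minimal sketch assumes you pass the PC1 scores (in this tutorial, gene_emb.obsm["X_pca"][:, 0]) and the matching gene_emb.obs["mean_counts"] values. It computes a Spearman rank correlation by ranking both vectors and taking the Pearson correlation of the ranks, which avoids a SciPy dependency. The data below are synthetic.

```python
import numpy as np
import pandas as pd


def spearman(a, b) -> float:
    """Spearman rank correlation: Pearson correlation of average ranks."""
    ra = pd.Series(np.asarray(a, dtype=float)).rank()
    rb = pd.Series(np.asarray(b, dtype=float)).rank()
    return float(ra.corr(rb))


rng = np.random.default_rng(0)
# Toy stand-in for gene_emb.obs["mean_counts"]
mean_counts = rng.lognormal(mean=0.0, sigma=1.0, size=200)
# Toy PC1 scores that decrease with expression level, plus a little noise
pc1 = -np.log(mean_counts) + rng.normal(scale=0.1, size=200)
print(round(spearman(pc1, mean_counts), 3))
```

The sign of a principal component is arbitrary, so look at the magnitude. A value near +1 or -1 means the ordering along PC1 closely follows expression level.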

Model Outputs

GREmLN is a pretrained foundation model. Given processed gene expression data and a regulatory network, its raw output for each cell is a 512-dimensional embedding vector for every gene expressed in that cell, i.e., a G x 512 matrix, where G is the number of expressed genes in the cell.

Cell Embeddings

GREmLN's inference API supports reducing GREmLN's raw output to cell-level embeddings via the get_cell_embeddings function. For a dataset of C cells, get_cell_embeddings produces a dataframe with C rows and 512 columns, where row i represents the embedding for cell i.
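The sketch below makes this reduction concrete by pooling each cell's G x 512 gene-embedding matrix into a single 512-dimensional vector, averaging over genes. Mean pooling is an assumption chosen for illustration. This tutorial does not document how get_cell_embeddings pools internally, and its cls_policy argument suggests that a <CLS> token may be involved. The random matrices stand in for real model output.

```python
import numpy as np
import pandas as pd


def mean_pool_cells(per_cell: dict) -> pd.DataFrame:
    """Average each cell's (G x d) gene-embedding matrix over genes -> one row per cell."""
    rows = {cell: np.asarray(mat, dtype=float).mean(axis=0) for cell, mat in per_cell.items()}
    return pd.DataFrame.from_dict(rows, orient="index")


rng = np.random.default_rng(0)
# Toy raw outputs: G (expressed genes) varies between cells, but d = 512 is fixed
per_cell = {
    "cell_1": rng.normal(size=(120, 512)),
    "cell_2": rng.normal(size=(85, 512)),
    "cell_3": rng.normal(size=(240, 512)),
}
cell_emb = mean_pool_cells(per_cell)
print(cell_emb.shape)
```

Pooling makes cells comparable even though each cell expresses a different number of genes. Every cell ends up as a 512-column row.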

Gene Embeddings

GREmLN's inference API also supports reducing GREmLN's raw output to gene-level embeddings via the get_gene_embeddings function. For a dataset with any number of cells, get_gene_embeddings produces a dataframe with G rows and 512 columns, where row i represents the embedding for gene i and G is the total number of expressed genes in the entire dataset.
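A gene-level table can be formed in a similar way: collect each gene's embedding from every cell in which the gene is expressed, then average. This sketch assumes that each cell's raw output is a DataFrame indexed by gene ID, which is a hypothetical input format. Averaging is an illustrative choice, not necessarily how get_gene_embeddings aggregates. The toy embeddings have 2 dimensions instead of 512 to keep the output readable.

```python
import pandas as pd


def mean_gene_embeddings(per_cell: list) -> pd.DataFrame:
    """Average each gene's embedding over the cells in which it appears -> G_total x d."""
    stacked = pd.concat(per_cell)           # rows from all cells; a gene may appear many times
    return stacked.groupby(level=0).mean()  # one row per unique gene ID


cell_a = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["GENE_X", "GENE_Y"])
cell_b = pd.DataFrame([[5.0, 6.0], [7.0, 8.0]], index=["GENE_X", "GENE_Z"])
genes = mean_gene_embeddings([cell_a, cell_b])
print(genes)
```

The result has one row for every gene expressed in at least one cell. This is why G here counts genes across the whole dataset rather than within a single cell.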

Contact and Acknowledgments

For issues with this quickstart please contact the developers:

Special thank you to Maximilian Lombardo (mlombardo@chanzuckerberg.com) for their consultation on this quickstart.


Responsible Use

We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services. Should you have any security or privacy issues or questions related to the services, please reach out to our team at security@chanzuckerberg.com or privacy@chanzuckerberg.com respectively.