Quickstart: CELL-Diff

Learning Goals

  • Learn about CELL-Diff model inputs and outputs
  • Run CELL-Diff model inference with a protein sequence
  • Compare a simulated image with reference images

Introduction

CELL-Diff is a suite of unified diffusion models that facilitate bidirectional transformations between protein sequences and microscopy images. It takes reference images of nuclei, microtubules, and the endoplasmic reticulum (ER) as conditional inputs, along with either an image or the sequence of a protein of interest, and generates the corresponding output. For instance, given reference images and a protein sequence, CELL-Diff can produce a simulated microscopy image showing where the protein of interest localizes in the cells depicted by the reference images.

CELL-Diff has two models: one pre-trained on Human Protein Atlas (HPA) data and another that was further fine-tuned on OpenCell data.

This quickstart demonstrates inference with the HPA-trained model, using reference images and a protein sequence as inputs. Reference images from the HPA are used along with the sequence of nucleophosmin 1 (NPM1), a protein found primarily in the nucleolus, a subcompartment of the nucleus.

Setup

Google Colab and the CELL-Diff repo must be set up to complete this tutorial. If you are running the tutorial locally, use an environment manager (for example, conda or venv) to isolate its dependencies.

Setup Google Colab

This tutorial is a notebook that can be run within the Google Colab interface.

To start, connect to the T4 GPU runtime, hosted for free by Google Colab, using the dropdown menu in the upper-right corner of this notebook. A GPU significantly speeds up model inference, but a CPU can also be used.
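To confirm that the GPU runtime is active, a quick check like the sketch below can be run after connecting. `available_device` is a hypothetical helper, not part of CELL-Diff; it only reports what PyTorch can see and does not control which device the CELL-Diff scripts use.

```python
def available_device() -> str:
    """Return "cuda" if PyTorch sees a GPU, otherwise "cpu"."""
    try:
        import torch  # preinstalled on Colab; also installed later by install.sh
    except ImportError:
        # Assumption: without PyTorch we cannot query the GPU, so report CPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"Available device: {available_device()}")
```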

Note that this tutorial will use commands written for Google Colab, and some of those commands may need to be modified to work with other computing setups.

Setup CELL-Diff

The CELL-Diff GitHub repo provides a convenient code wrapper for running CELL-Diff inference.

Clone Repository

To run CELL-Diff in Google Colab, start by cloning the CELL-Diff repo and navigating to the newly created CELL-Diff folder using the commands below. The folder also appears in the Google Colab file browser, which you can open by clicking the folder icon in the left-hand sidebar of this notebook.

# clone the CELL-Diff repo
!git clone https://github.com/BoHuangLab/CELL-Diff.git

# navigate to the CELL-Diff directory
%cd /content/CELL-Diff

Overview of CELL-Diff

The CELL-Diff repo contains several items in its top-level directory, which are described in the table below.

| File or Directory | Description / Purpose |
| --- | --- |
| `cell_diff/` | Directory containing the core implementation of the CELL-Diff model. |
| `data/` | Directory for datasets used by the model, including the demo dataset for this quickstart. |
| `img/` | Directory for images related to the GitHub project, such as the hero image in the README. |
| `LICENSE` | Licensing information; CELL-Diff is licensed under the MIT License. |
| `README.md` | Overview of CELL-Diff, including installation instructions and usage guidelines. |
| `install.sh` | Shell script for setting up the environment and installing the dependencies CELL-Diff needs. |
| `run_evaluate_hpa.sh` | Shell script to evaluate the model on the Human Protein Atlas (HPA) dataset. |
| `run_image_prediction_hpa.sh` | Shell script to generate protein images from sequences using the HPA-trained model. |
| `run_image_prediction_opencell.sh` | Shell script to generate protein images from sequences using the OpenCell-trained model. |

The demo dataset for this tutorial is in the data/hpa folder. It contains 5 sets of reference images of the ER, microtubules, and nuclei saved as PNG files (ER.png, microtubule.png, and nucleus.png), each set in its own subdirectory numbered 1-5. (The data/OpenCell directory has the same organization but contains only nuclei reference images.) This tutorial uses the images in subdirectory 1; to view them, install matplotlib and plot the images using the code cells below.
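Before plotting, it can help to confirm that every numbered subdirectory contains the three expected PNG files. The sketch below is a hypothetical helper (`find_reference_sets` is not part of CELL-Diff); the file names come from the description above, and the check only looks for the files' presence, not their contents.

```python
from pathlib import Path

# File names expected in each numbered HPA reference subdirectory.
HPA_REFERENCE_FILES = ("ER.png", "microtubule.png", "nucleus.png")


def find_reference_sets(root, required=HPA_REFERENCE_FILES):
    """Return {subdirectory name: list of missing files} for each subdirectory of root."""
    report = {}
    for subdir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        report[subdir.name] = [name for name in required if not (subdir / name).is_file()]
    return report


hpa_root = Path("/content/CELL-Diff/data/hpa")
if hpa_root.is_dir():
    for name, missing in find_reference_sets(hpa_root).items():
        # An empty list means the reference set is complete.
        print(name, "complete" if not missing else f"missing {missing}")
```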

!pip install matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Read reference images of ER, microtubules, and nucleus
ER = mpimg.imread('/content/CELL-Diff/data/hpa/1/ER.png')
MT = mpimg.imread('/content/CELL-Diff/data/hpa/1/microtubule.png')
Nuc = mpimg.imread('/content/CELL-Diff/data/hpa/1/nucleus.png')

# Create a figure with a row of 3 subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Display each image in a subplot
axes[0].imshow(ER, cmap='gray')
axes[0].set_title("ER")
axes[0].axis('off')  # Hide axes for clarity

axes[1].imshow(MT, cmap='gray')
axes[1].set_title("MT")
axes[1].axis('off')

axes[2].imshow(Nuc, cmap='gray')
axes[2].set_title("Nuc")
axes[2].axis('off')

# Adjust layout and show the figure
plt.tight_layout()
plt.show()
Reference images of ER, Microtubules, and Nucleus

The packages required for model inference are listed in install.sh, whose contents are reproduced below for convenience.

pip install torch torchvision torchaudio
pip install tqdm
pip install timm
pip install fair-esm
pip install loguru
pip install wandb
pip install transformers
pip install einops
pip install frc
pip install pytorch-fid

Install those packages in the following cell. This may take a few minutes.

# install requirements for model inference
!bash install.sh
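Once install.sh finishes, a quick check like the sketch below can confirm that the packages are importable. The helper `missing_modules` and the mapping from pip package names to import names are our own assumptions (for example, we assume `fair-esm` imports as `esm` and `pytorch-fid` as `pytorch_fid`); if a name is reported missing, check that package's documentation before reinstalling.

```python
import importlib.util

# pip package name (from install.sh) -> assumed Python import name.
INSTALL_SH_IMPORTS = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "tqdm": "tqdm",
    "timm": "timm",
    "fair-esm": "esm",
    "loguru": "loguru",
    "wandb": "wandb",
    "transformers": "transformers",
    "einops": "einops",
    "frc": "frc",
    "pytorch-fid": "pytorch_fid",
}


def missing_modules(import_names):
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in import_names if importlib.util.find_spec(name) is None]


missing = missing_modules(INSTALL_SH_IMPORTS.values())
print("All packages importable." if not missing else f"Missing: {missing}")
```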

All of the pretrained models are available in a public Amazon Web Services (AWS) S3 bucket. The next cell installs the AWS Command Line Interface (AWS CLI), which is used to access them.

!pip install awscli

To list the contents of the S3 bucket containing the CELL-Diff models, run the cell below.

!aws s3 ls --no-sign-request s3://czi-celldiff-public/checkpoints/

Three model checkpoints are available; this tutorial uses the HPA-trained model, hpa_checkpoint.pt. Next, create a subdirectory in CELL-Diff called model_weights and download the HPA-trained model into it using the commands below. The download may take a few minutes because the model is about 2 GB.

# make a directory called model_weights
!mkdir -p model_weights

# download HPA-trained model in the model_weights directory
!aws s3 cp --no-sign-request s3://czi-celldiff-public/checkpoints/hpa_checkpoint.pt /content/CELL-Diff/model_weights
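An interrupted download can leave a truncated checkpoint that only fails later at load time. The sketch below is a rough size check: `checkpoint_looks_complete` is a hypothetical helper, and the 1.5 GB lower bound is an arbitrary threshold we chose based on the stated size of about 2 GB, so it can flag an obviously incomplete file but cannot prove the file is intact.

```python
import os

# Arbitrary lower bound (assumption) below the ~2 GB size stated for the HPA checkpoint.
MIN_CHECKPOINT_BYTES = int(1.5e9)


def checkpoint_looks_complete(path, min_bytes=MIN_CHECKPOINT_BYTES):
    """Return True if path is an existing file of at least min_bytes bytes."""
    return os.path.isfile(path) and os.path.getsize(path) >= min_bytes


ckpt = "/content/CELL-Diff/model_weights/hpa_checkpoint.pt"
if os.path.isfile(ckpt):
    size_gb = os.path.getsize(ckpt) / 1e9
    print(f"{ckpt}: {size_gb:.2f} GB, looks complete: {checkpoint_looks_complete(ckpt)}")
else:
    print(f"{ckpt} not found")
```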

Run Model Inference

To run inference with the HPA-trained model, set the environment variables in the table below and then call run_image_prediction_hpa.sh. (For the OpenCell-trained model, use run_image_prediction_opencell.sh instead.)

| Environment Variable | Description |
| --- | --- |
| `save_dir` | Directory where output files (e.g., generated images) will be saved. |
| `cell_morphology_image_path` | Path to the directory containing the reference images saved as PNG files. |
| `test_sequence` | Protein sequence of the target protein (e.g., NPM1) as a string, used as input for the model. |
| `loadcheck_path` | Path to the pretrained model weights file (e.g., `./hpa_checkpoint.pt`). |
| `seed` | Integer random seed used to initialize the diffusion model; different seeds can yield different simulated images. |
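Because `test_sequence` is passed as a plain string, stray whitespace or unexpected characters from a copy-pasted sequence are easy to miss. The sketch below is a conservative pre-check with hypothetical helpers; restricting input to the 20 standard one-letter amino acid codes is our assumption, and the model's tokenizer may accept additional symbols (such as X).

```python
# The 20 standard amino acids in one-letter code. Treating anything else
# (e.g., X, B, Z, U, lowercase letters, digits) as invalid is our assumption.
STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def clean_sequence(seq: str) -> str:
    """Remove all whitespace (e.g., line breaks from a copied sequence) and uppercase."""
    return "".join(seq.split()).upper()


def invalid_residues(seq: str) -> list:
    """Return the sorted characters in seq that are not standard one-letter codes."""
    return sorted(set(seq) - STANDARD_AMINO_ACIDS)


def is_valid_protein_sequence(seq: str) -> bool:
    """Return True if seq is non-empty and contains only standard one-letter codes."""
    return bool(seq) and not invalid_residues(seq)


# Example using the first residues of the NPM1 sequence from this tutorial.
seq = clean_sequence("MEDSMDMDMS PLRPQNYLFG")
print(seq, is_valid_protein_sequence(seq))
```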

The cell below defines those variables and calls run_image_prediction_hpa.sh to run model inference with the HPA-trained model.

# Run the image generation script
!\
save_dir='/content/CELL-Diff/output' \
cell_morphology_image_path='/content/CELL-Diff/data/hpa/1/' \
test_sequence='MEDSMDMDMSPLRPQNYLFGCELKADKDYHFKVDNDENEHQLSLRTVSLGAGAKDELHIVEAEAMNYEGSPIKVTLATLKMSVQPTVSLGGFEITPPVVLRLKCGSGPVHISGQHLVAVEEDAESEDEEEEDVKLLSISGKRSAPGGGSKVPQKKVKLAADEDDDDDDEEDDDEDDDDDDFDDEEAEEKAPVKKSIRDTPAKNAQKSNQNGKDSKPSSTPRSKGQESFKKQEKTPKTPKGPSSVEDIKAKMQASIEKGGSLPKVEAKFINYVKNCFRMTDQEAIQDLWQWRKSL' \
loadcheck_path='/content/CELL-Diff/model_weights/hpa_checkpoint.pt' \
seed=6 \
bash run_image_prediction_hpa.sh
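The shell cell above uses `VAR=value command` syntax, which sets the variables only for that one command. The same call can be sketched in Python with `subprocess.run`, which makes it easier to loop over seeds or reference sets. `build_inference_env` and `run_hpa_inference` are hypothetical helpers, not part of CELL-Diff; they assume the script reads exactly the environment variables listed in the table above.

```python
import os
import subprocess


def build_inference_env(save_dir, cell_morphology_image_path, test_sequence,
                        loadcheck_path, seed, base_env=None):
    """Return a copy of base_env (default: os.environ) with the CELL-Diff variables set."""
    env = dict(os.environ if base_env is None else base_env)
    env.update({
        "save_dir": str(save_dir),
        "cell_morphology_image_path": str(cell_morphology_image_path),
        "test_sequence": test_sequence,
        "loadcheck_path": str(loadcheck_path),
        "seed": str(seed),  # environment variables must be strings
    })
    return env


def run_hpa_inference(env, repo_dir="/content/CELL-Diff"):
    """Run run_image_prediction_hpa.sh from repo_dir; raises CalledProcessError on failure."""
    return subprocess.run(["bash", "run_image_prediction_hpa.sh"],
                          cwd=repo_dir, env=env, check=True)


# Example (not executed here; npm1_sequence would hold the NPM1 string used above):
# env = build_inference_env("/content/CELL-Diff/output", "/content/CELL-Diff/data/hpa/1/",
#                           npm1_sequence, "/content/CELL-Diff/model_weights/hpa_checkpoint.pt", 6)
# run_hpa_inference(env)
```

If you loop over several seeds, giving each run its own `save_dir` (for example, one subdirectory per seed) should keep the outputs from overwriting each other; this assumes the script creates `save_dir` when it does not exist, which we have not verified.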

Model Outputs

When inference is run to generate images from a protein sequence, CELL-Diff writes two output files to save_dir: pred_protein_img.png, the simulated microscopy image of the protein of interest, and pred_protein_img_cat.png, the simulated image overlaid with the nuclei reference image.
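Before reading the outputs, a small check like the sketch below can confirm that inference actually wrote both files. `missing_outputs` is a hypothetical helper; the file names are the two outputs described above, and the directory is the `save_dir` used in the inference cell.

```python
from pathlib import Path

# Output file names described above.
EXPECTED_OUTPUTS = ("pred_protein_img.png", "pred_protein_img_cat.png")


def missing_outputs(save_dir, expected=EXPECTED_OUTPUTS):
    """Return the expected output files that are not present in save_dir."""
    save_dir = Path(save_dir)
    return [name for name in expected if not (save_dir / name).is_file()]


missing = missing_outputs("/content/CELL-Diff/output")
print("Both outputs found." if not missing else f"Missing outputs: {missing}")
```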

To examine the simulated image, the next cell plots it in a grid with the reference images. Recall that the NPM1 protein sequence was used as input, and this protein primarily localizes to the nucleolus, a subcompartment of the nucleus.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Read reference images of ER, microtubules, nucleus, and simulated protein
ER = mpimg.imread('/content/CELL-Diff/data/hpa/1/ER.png')
MT = mpimg.imread('/content/CELL-Diff/data/hpa/1/microtubule.png')
Nuc = mpimg.imread('/content/CELL-Diff/data/hpa/1/nucleus.png')
Sim_protein = mpimg.imread('/content/CELL-Diff/output/pred_protein_img.png')

# Create a figure with a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# Display each image in a subplot
axes[0, 0].imshow(ER, cmap='gray')
axes[0, 0].set_title("ER")
axes[0, 0].axis('off')  # Hide axes for clarity

axes[0, 1].imshow(MT, cmap='gray')
axes[0, 1].set_title("MT")
axes[0, 1].axis('off')

axes[1, 0].imshow(Nuc, cmap='gray')
axes[1, 0].set_title("Nucleus")
axes[1, 0].axis('off')

axes[1, 1].imshow(Sim_protein, cmap='gray')
axes[1, 1].set_title("NPM1 Simulated Protein")
axes[1, 1].axis('off')

# Adjust layout and show the figure
plt.tight_layout()
plt.show()
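Visual inspection can be complemented with a rough number. The sketch below estimates the fraction of simulated protein intensity that falls inside a nucleus mask obtained by thresholding the nucleus reference image at its mean intensity. This metric and the mean threshold are our own crude assumptions, not part of CELL-Diff (Otsu thresholding would be a more standard choice); it measures nuclear enrichment, not nucleolar localization specifically, and it assumes the simulated and reference images have the same resolution.

```python
import numpy as np


def to_grayscale(img):
    """Collapse an (H, W, C) RGB or RGBA image to (H, W) by averaging color channels."""
    img = np.asarray(img, dtype=float)
    if img.ndim == 3:
        img = img[..., :3].mean(axis=-1)  # drop any alpha channel before averaging
    return img


def fraction_in_nucleus(protein_img, nucleus_img, threshold=None):
    """Fraction of total protein intensity inside the thresholded nucleus mask."""
    protein = to_grayscale(protein_img)
    nucleus = to_grayscale(nucleus_img)
    if protein.shape != nucleus.shape:
        raise ValueError(f"shape mismatch: {protein.shape} vs {nucleus.shape}")
    if threshold is None:
        threshold = nucleus.mean()  # crude global threshold (assumption)
    mask = nucleus > threshold
    total = protein.sum()
    if total == 0:
        return 0.0
    return float(protein[mask].sum() / total)


# Example with the arrays loaded in the previous cell:
# print(f"Nuclear fraction of simulated NPM1 signal: {fraction_in_nucleus(Sim_protein, Nuc):.2f}")
```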
Model outputs

Contact and Acknowledgments

For issues with this quickstart, please contact virtualcellmodels@chanzuckerberg.com.

Special thank you to Dihan Zheng from Professor Bo Huang's lab for their consultation on this tutorial.

Responsible Use

We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.

Should you have any security or privacy issues or questions related to the services, please reach out to our team at security@chanzuckerberg.com or privacy@chanzuckerberg.com respectively.