Quickstart: Ensemble 3D UNet Soup

Estimated time to complete: 30 minutes

Learning Goals

  • Learn about model inputs and outputs
  • Run inference with an ensemble of 3D U-Net models to identify particles within tomograms

Prerequisites

  • NVIDIA T4 GPU (available free through Google Colab) or a comparable CUDA-capable GPU
  • Python 3.8+

Introduction

The Ensemble 3D UNet Soup model is an ensemble of 3D U-Net models of three sizes (tiny, medium, large), where each member is a "model soup", a weight average of several fine-tuned checkpoints of the same architecture. Each model is pre-trained on simulated data and fine-tuned on real annotated tomograms, and ensembling plus test-time augmentation are used to optimize performance. This model achieved 8th place in the CZII CryoET Object Identification Kaggle competition.
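
To make the "soup" idea concrete, here is a minimal, illustrative sketch of uniform weight averaging across fine-tuned checkpoints of one architecture. This is not the repo's code; the helper name and the assumption that each file holds a plain state dict are hypothetical.

# Illustrative only: average several fine-tuned checkpoints of the SAME
# architecture into a single "soup" model (uniform soup)
import torch

def make_uniform_soup(checkpoint_paths):
    soup = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location="cpu")  # assumes a plain state dict
        if soup is None:
            soup = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in soup:
                soup[k] += state[k].float()
    # divide by the number of checkpoints to get the averaged weights
    return {k: v / len(checkpoint_paths) for k, v in soup.items()}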

The model identifies 6 different particle types:

  • apo-ferritin (60Å radius)
  • beta-amylase (65Å radius) - limited testing
  • beta-galactosidase (90Å radius)
  • ribosome (150Å radius)
  • thyroglobulin (130Å radius)
  • virus-like-particle (135Å radius)

Input/Output Specifications

  • Input: 3D cryo-electron tomograms in copick format (.zarr files)
  • Processing: Sliding window inference with test-time augmentation and watershed post-processing
  • Output: CSV file with particle coordinates (x, y, z)

Data Requirements for Your Own Data

To use this model with your own data, you need the following (an illustrative directory layout follows this list):

  1. Copick configuration file pointing to your tomogram data
  2. Tomogram data in .zarr format
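
As a point of reference, the static root for this tutorial's data follows a copick-style layout along these lines (run names, voxel spacings, and tomogram types will differ for your data; consult the copick documentation for the exact specification of your version):

tomograms/test/static/
└── ExperimentRuns/
    └── TS_5_4/
        └── VoxelSpacing10.000/
            └── denoised.zarr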

Setup

Setup Google Colab

To start, connect to the T4 GPU runtime hosted for free by Google Colab:

  1. Click Runtime → Change runtime type
  2. Select T4 GPU under Hardware accelerator
  3. Click Save
  4. Restart the session after the dependencies are installed (see below)

Setup Local Environment

If running locally, ensure you have the following (the snippet after this list can help verify your setup):

  • NVIDIA GPU with CUDA drivers installed
  • Python 3.8+ environment
  • At least 16GB VRAM available
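
The following quick check (assuming PyTorch is already installed) confirms a CUDA GPU is visible and reports its VRAM:

# Verify that a CUDA-capable GPU is available and has enough memory
import torch

assert torch.cuda.is_available(), "No CUDA GPU detected"
print(torch.cuda.get_device_name(0))
vram_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f"VRAM: {vram_gib:.1f} GiB")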

Clone Repository and Install Dependencies

Estimated installation time: Up to 10 minutes

Note that you will need to restart the session after installing dependencies.

# Clone the repository
!git clone https://github.com/IAmPara0x/czii-8th-solution.git
%cd czii-8th-solution

# Install all required dependencies (this may take 5-10 minutes)
!pip install -r requirements.txt

# Restart the notebook after installing dependencies
# Reset the directory
%cd /content/czii-8th-solution
Output:

/content/czii-8th-solution

Download Model Weights

# Download model weights using gdown
!pip install -q gdown

# Create model_weights directory
!mkdir -p model_weights

!gdown 1qBtxL0iT_8n2gvTe2etTYj0lnhmgF1nw -O model_weights/tiny_unet_soup.pth
!gdown 1qDsSDnyX8-KU18AFylqHbutGO-TKLqaW -O model_weights/medium_unet_soup.pth
!gdown 1Jw6FBH1OGQ9SJ4hmf4CNDGOwGLqn5HEz -O model_weights/large_unet-soup-folds-69-86-99.pth
!gdown 1aEz-EF_jwSCb1B2ATPWH-v97xlMyuM7g -O model_weights/large_unet-ema-soup.pth

print("Model weights downloaded successfully!")

Download Tomogram Data

Estimated download time: Up to 3 minutes

# Download one tomogram for testing
!gdown --folder 1HKCrVvuIvjLXBhNICW6FiSfR5cFTSbNH
print("Tomogram data downloaded successfully!")

Run Model Inference

Now we'll demonstrate how to run the ensemble model on cryo-electron tomography data. The process involves:

  1. Model Setup - Import libraries and setup model configuration and constants
  2. Model Loading - Load the ensemble of trained models
  3. Data Setup - Prepare your copick configuration and tomogram data
  4. Inference - Run sliding window inference with test-time augmentation (roughly 90 seconds per tomogram; see the conceptual sketch after this list)
  5. Post-processing - Apply watershed segmentation to extract particle locations
  6. Output - Generate CSV file with particle coordinates
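
To build intuition for step 4, here is a conceptual sketch of sliding-window inference with a single flip-based test-time augmentation pass, written against MONAI's sliding_window_inference. The parameters mirror the log output shown later, but the function below is an illustration, not the repo's infer_tomograph_locations_watershed_ensemble:

# Conceptual sketch: sliding-window inference + one flip TTA pass (illustrative)
import torch
from monai.inferers import sliding_window_inference

def predict_with_flip_tta(model, volume):
    # volume: (1, 1, D, H, W) float tensor on the GPU
    sw_kwargs = dict(roi_size=(160, 384, 384), sw_batch_size=1,
                     overlap=0.25, mode="gaussian", padding_mode="reflect")
    with torch.no_grad():
        logits = sliding_window_inference(volume, predictor=model, **sw_kwargs)
        # TTA: run again on the x-flipped volume, flip the result back, average
        flipped = sliding_window_inference(torch.flip(volume, dims=[-1]),
                                           predictor=model, **sw_kwargs)
        logits = (logits + torch.flip(flipped, dims=[-1])) / 2
    return logits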

Step 1: Setup Model Configuration and Constants

# Import inference functions and dependencies
import sys
sys.path.append('.')

# from github repo
from quick_start_aux_functions import load_models, infer_tomograph_locations_watershed_ensemble, ModelSpec, Model

# Import other necessary libraries
import copick
import pandas as pd
import torch
import numpy as np
import json

# Model configurations for the three UNet sizes; out_channels=7 covers the
# 6 particle classes plus a background channel
model_config_tiny = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,
    "channels": (32, 64, 128, 128),
    "strides": (2, 2, 1),
    "num_res_units": 1,
}

model_config_medium = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,
    "channels": (32, 64, 128, 256),
    "strides": (2, 2, 1),
    "num_res_units": 2,
}

model_config_large = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,
    "channels": (32, 96, 256, 384),
    "strides": (2, 2, 1),
    "num_res_units": 2,
}

# The 6 particle classes (the model's 7th output channel is background)
LABELS_7 = ["apo-ferritin", "beta-amylase", "beta-galactosidase", "ribosome", "thyroglobulin", "virus-like-particle"]

PARTICLE_RADIUS_7 = {
    "apo-ferritin": 60,
    "beta-amylase": 65,
    "beta-galactosidase": 90,
    "ribosome": 150,
    "thyroglobulin": 130,
    "virus-like-particle": 135
}
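
These dictionaries match the constructor arguments of MONAI's UNet, so (assuming the repo builds its networks with MONAI, which the configuration keys suggest) each can be instantiated directly:

# Illustrative: build one of the architectures from its config dict
from monai.networks.nets import UNet

tiny_unet = UNet(**model_config_tiny)  # 3D U-Net with 7 output channels
print(sum(p.numel() for p in tiny_unet.parameters()), "parameters")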

Step 2: Load the Ensemble of Trained Models

Each ModelSpec pairs a checkpoint path with its architecture config; the final boolean marks checkpoints whose EMA (exponential moving average) weights should be swapped in at load time, as the loading log below shows.

model_specs = [
    ModelSpec("model_weights/tiny_unet_soup.pth", model_config_tiny, False),
    ModelSpec("model_weights/medium_unet_soup.pth", model_config_medium, False),
    ModelSpec("model_weights/large_unet-soup-folds-69-86-99.pth", model_config_large, True),
    ModelSpec("model_weights/large_unet-ema-soup.pth", model_config_large, True),
]

device_id = 0
models = load_models(model_specs, device_id)
Output:

can't find state_dict key, loading checkpoint directly
model does NOT uses ema, discarding ema_model
can't find state_dict key, loading checkpoint directly
model does NOT uses ema, discarding ema_model
can't find state_dict key, loading checkpoint directly
model does uses ema, swapping ema_model into model
can't find state_dict key, loading checkpoint directly
model does uses ema, swapping ema_model into model

Step 3: Create Copick Configuration File

# Create a copick configuration file for your tomogram data.
# For your own data, point static_root at your tomograms and overlay_root at a
# writable directory for picks and other overlays.

copick_config = {
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",
    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": True,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": True,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153, 63, 0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": True,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [76, 0, 92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": True,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [0, 92, 49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": True,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [43, 206, 72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": True,
            "pdb_id": "6N4V",
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        }
    ],
    "overlay_root": "./data/overlay",
    "overlay_fs_args": {
        "auto_mkdir": True
    },
    "static_root": "./tomograms/test/static"
}

with open("test_copick.config", "w") as f:
    json.dump(copick_config, f, indent=2)

Step 4: Run Inference on Tomograms

Estimated inference time: 90 seconds per tomogram

# Run inference

root = copick.from_file("/content/czii-8th-solution/test_copick.config")
all_runs = root.runs

print(f"Found {len(all_runs)} tomogram runs to process")

voxel_size = 10
tomo_type = "denoised"

# This will process all runs and save results to submission.csv
print("Starting inference on all tomograms...")

# Run inference on all tomograms
locs = []
for i, run in enumerate(all_runs):
    print(f"Processing tomogram {i+1}/{len(all_runs)}: {run.name}")

    # Get the tomogram data at 10Å voxel spacing
    tomo = run.get_voxel_spacing(voxel_size)
    tomo_data = tomo.get_tomograms(tomo_type)[0].numpy()

    # Run ensemble inference with test-time augmentation and watershed post-processing
    loc_df = infer_tomograph_locations_watershed_ensemble(models, tomo_data, run.name)
    locs.append(loc_df)

    # Clear GPU cache to prevent memory issues
    torch.cuda.empty_cache()
    print(f"Completed inference on {run.name}")

# Combine all results and save to CSV
final_results = pd.concat(locs)
final_results.to_csv("submission.csv", index=False)
Output:

Found 1 tomogram runs to process
Starting inference on all tomograms...
Processing tomogram 1/1: TS_5_4
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}

6it [00:31,  5.27s/it]

Completed inference on TS_5_4
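
For intuition, the watershed post-processing converts each class's probability map into discrete particle centers. The sketch below shows the general idea with scikit-image and SciPy; the threshold, minimum distance, and helper name are assumptions, and the repo's infer_tomograph_locations_watershed_ensemble implements the real logic:

# Illustrative watershed post-processing for ONE class probability map
import numpy as np
from scipy import ndimage
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

def extract_particle_centers(prob_map, threshold=0.5, min_distance=5):
    mask = prob_map > threshold
    components, _ = ndimage.label(mask)
    # local maxima of the probability map seed the watershed
    peaks = peak_local_max(prob_map, min_distance=min_distance, labels=components)
    markers = np.zeros(prob_map.shape, dtype=int)
    markers[tuple(peaks.T)] = np.arange(1, len(peaks) + 1)
    basins = watershed(-prob_map, markers, mask=mask)
    # the centroid of each basin is one predicted particle location (in voxels)
    return ndimage.center_of_mass(prob_map, basins, list(range(1, basins.max() + 1)))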

Explore Model Outputs

The ensemble model produces the following output:

Primary Output: CSV File

  • File: submission.csv
  • Columns:
    • experiment: Name of the tomogram run
    • particle_type: One of the 6 supported particle types
    • x, y, z: 3D coordinates in Angstroms
# Examine output
final_results
Output:

    experiment        particle_type            x            y            z
0       TS_5_4         apo-ferritin   129.126002  4403.058563   380.127615
1       TS_5_4         apo-ferritin   355.100429  5900.885071   552.504864
2       TS_5_4         apo-ferritin   468.975093  5923.138979   598.022406
3       TS_5_4         apo-ferritin   519.783592  4278.294706   654.674033
4       TS_5_4         apo-ferritin   554.883747  4369.815789   598.542861
..         ...                  ...          ...          ...          ...
216     TS_5_4  virus-like-particle  2631.654423  4222.250606   971.266489
217     TS_5_4  virus-like-particle  3003.754538  4954.748190  1172.217550
218     TS_5_4  virus-like-particle  3135.223016  3576.066420   376.307184
219     TS_5_4  virus-like-particle  3295.876607  3028.801345   676.558135
220     TS_5_4  virus-like-particle  3443.012295  6185.960700  1022.491603

[221 rows x 5 columns]
print(final_results['particle_type'].value_counts())
Output:

particle_type
thyroglobulin          78
apo-ferritin           56
ribosome               43
beta-galactosidase     33
virus-like-particle    11
Name: count, dtype: int64
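
Coordinates are reported in Angstroms. Because these tomograms use a 10 Å voxel size, dividing by voxel_size converts them to voxel indices, which is exactly what the visualization below does. For example:

# Convert Angstrom coordinates to (z, y, x) voxel indices at 10 Å spacing
voxel_coords = final_results[["z", "y", "x"]].to_numpy() / voxel_size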

Visualize Particle Localization

colors = {
    'apo-ferritin': 'cyan',
    'beta-amylase': 'orange',
    'beta-galactosidase': 'red',
    'ribosome': 'blue',
    'thyroglobulin': 'green',
    'virus-like-particle': 'magenta'
}

import copick
import matplotlib.pyplot as plt
import json
import pandas as pd
import numpy as np

# Load tomogram using copick (same as in inference)
root = copick.from_file("/content/czii-8th-solution/test_copick.config")

# Get the single run directly
run_name = 'TS_5_4'
target_run = root.get_run(run_name)

# Get tomogram data using copick (same parameters as inference)
voxel_size = 10
tomo_type = "denoised"
tomo = target_run.get_voxel_spacing(voxel_size)
tomo_data = tomo.get_tomograms(tomo_type)[0].numpy()

slice_idx = 60
z_min = slice_idx - 5
z_max = slice_idx + 5
# Ground-truth picks (present only if the training overlay has been downloaded)
base_gt_path = '/content/czii-8th-solution/tomograms/train/overlay/ExperimentRuns/TS_5_4/Picks/'
particle_types = list(colors.keys())

print(f"Visualizing slice {slice_idx}, particles in z-range [{z_min}, {z_max})")
print(f"Tomogram shape: {tomo_data.shape}")

# Load ground truth coordinates
ground_truth_coords = {}
for pt in particle_types:
    try:
        with open(f'{base_gt_path}{pt}.json', 'r') as f:
            x, y = [], []
            data = json.load(f)['points']
            for p in data:
                z = float(p['location']['z']) / 10
                if z_min <= z < z_max:
                    x.append(float(p['location']['x']) / 10)
                    y.append(float(p['location']['y']) / 10)
            ground_truth_coords[pt] = (x, y)
            print(f"Found {len(x)} {pt} GT points")
    except FileNotFoundError:
        print(f"{pt} ground truth not found")

# Load predictions
run_preds = final_results[final_results['experiment'] == run_name]
z_filtered_preds = run_preds[(run_preds['z']/10 >= z_min) & (run_preds['z']/10 < z_max)]

pred_coords = {}
for pt in particle_types:
    type_preds = z_filtered_preds[z_filtered_preds['particle_type'] == pt]
    pred_coords[pt] = ((type_preds['x'] / 10).tolist(), (type_preds['y'] / 10).tolist())
    print(f"Found {len(type_preds)} {pt} predictions")

# Get the slice from copick-loaded data
tomo_slice = tomo_data[slice_idx]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
cmap_args = {'cmap':'gray', 'vmin':-0.00005, 'vmax':0.00005}

axes[0].imshow(tomo_slice, **cmap_args)
axes[0].set_title(f'Original Tomogram - Slice {slice_idx}')
axes[0].axis('off')

axes[1].imshow(tomo_slice, **cmap_args)
axes[1].set_title('Ground Truth Labels')
for pt, (x, y) in ground_truth_coords.items():
    if x and y:
        axes[1].scatter(x, y, c=colors[pt], s=30, alpha=0.8, marker='o',
                        edgecolors='white', linewidth=0.5, label=f'{pt} GT ({len(x)})')
axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
axes[1].axis('off')

axes[2].imshow(tomo_slice, **cmap_args)
axes[2].set_title('Model Predictions')
for pt, (x, y) in pred_coords.items():
    if x and y:
        axes[2].scatter(x, y, c=colors[pt], s=30, alpha=0.8, marker='x',
                        linewidth=2, label=f'{pt} Pred ({len(x)})')
axes[2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
axes[2].axis('off')

plt.tight_layout()
plt.show()
Output:

Visualizing slice 60, particles in z-range [55, 65)
Tomogram shape: (184, 630, 630)
apo-ferritin ground truth not found
beta-amylase ground truth not found
beta-galactosidase ground truth not found
ribosome ground truth not found
thyroglobulin ground truth not found
virus-like-particle ground truth not found
Found 11 apo-ferritin predictions
Found 0 beta-amylase predictions
Found 4 beta-galactosidase predictions
Found 1 ribosome predictions
Found 6 thyroglobulin predictions
Found 1 virus-like-particle predictions

Contact and Acknowledgments

For issues with this quickstart, please contact Sergio Alvarez at sasjsergioalvarezjunior@gmail.com.

Special thank you to Karyna Rosario Cora and Dannielle McCarthy for their consultation on this quickstart.

Responsible Use

We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.