Quickstart: Ensemble 3D UNet Soup
Estimated time to complete: 30 minutes
Learning Goals
- Learn about model inputs and outputs
- Run inference with an ensemble of 3D U-Net models to identify particles within tomograms
Prerequisites
- T4 GPU
- Python 3.8+
Introduction
The Ensemble 3D UNet Soup model is a weighted ensemble (a "model soup") of multiple 3D U-Net architectures (tiny, medium, large). Each model is pre-trained on simulated data and fine-tuned on real annotated tomograms, and ensembling plus test-time augmentation are used to boost performance. This model achieved 8th place in the CZII CryoET Object Identification Kaggle competition.
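The "soup" refers to averaging the weights of several fine-tuned checkpoints of the same architecture into a single model. Below is a minimal, uniform-weight sketch of the idea; the competition solution's exact weighting may differ, and soup_state_dicts is an illustrative helper, not a function from the repository.
import torch

def soup_state_dicts(checkpoint_paths):
    """Average parameters across checkpoints of one architecture (uniform soup)."""
    n = len(checkpoint_paths)
    soup = None
    for path in checkpoint_paths:
        sd = torch.load(path, map_location="cpu")
        if soup is None:
            soup = {k: v.float() / n for k, v in sd.items()}
        else:
            for k, v in sd.items():
                soup[k] += v.float() / n
    return soup  # load into a model with model.load_state_dict(soup)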
The model identifies 6 different particle types:
- apo-ferritin (60Å radius)
- beta-amylase (65Å radius) - limited testing
- beta-galactosidase (90Å radius)
- ribosome (150Å radius)
- thyroglobulin (130Å radius)
- virus-like-particle (135Å radius)
Input/Output Specifications
- Input: 3D cryo-electron tomograms in copick format (.zarr files)
- Processing: Sliding window inference with test-time augmentation and watershed post-processing
- Output: CSV file with particle coordinates (x, y, z)
Data Requirements for Your Own Data
To use this model with your own data, you need:
- Copick configuration file pointing to your tomogram data
- Tomogram data in `.zarr` format
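Before wiring up a copick configuration, you can sanity-check a tomogram by opening its zarr store directly. A minimal sketch, assuming an OME-Zarr multiscale layout; the path below is hypothetical.
import zarr

# Open the multiscale OME-Zarr store read-only (hypothetical path)
store = zarr.open("path/to/VoxelSpacing10.000/denoised.zarr", mode="r")
arr = store["0"]  # scale level "0" is the highest resolution
print(arr.shape, arr.dtype)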
Setup
Setup Google Colab
To start, connect to the T4 GPU runtime hosted for free by Google Colab:
- Click on `Runtime` → `Change runtime type`
- Select `T4 GPU` under `Hardware accelerator`
- Click `Save`
- After installing the requirements below, restart the session
Setup Local Environment
If running locally, ensure you have:
- NVIDIA GPU with CUDA drivers installed
- Python 3.8+ environment
- At least 16GB VRAM available
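Either way, you can confirm the GPU is visible before proceeding:
import torch

# Confirm CUDA is available before running inference
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))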
Clone Repository and Install Dependencies
Estimated installation time: Up to 10 minutes
Note that you will need to restart the session after installing dependencies.
# Clone the repository
!git clone https://github.com/IAmPara0x/czii-8th-solution.git
%cd czii-8th-solution
# Install all required dependencies (this may take 5-10 minutes)
!pip install -r requirements.txt
# Restart the notebook after installing dependencies
# Reset the directory
%cd /content/czii-8th-solution
Output:
/content/czii-8th-solution
Download Model Weights
# Download model weights using gdown
!pip install -q gdown
# Create model_weights directory
!mkdir -p model_weights
!gdown 1qBtxL0iT_8n2gvTe2etTYj0lnhmgF1nw -O model_weights/tiny_unet_soup.pth
!gdown 1qDsSDnyX8-KU18AFylqHbutGO-TKLqaW -O model_weights/medium_unet_soup.pth
!gdown 1Jw6FBH1OGQ9SJ4hmf4CNDGOwGLqn5HEz -O model_weights/large_unet-soup-folds-69-86-99.pth
!gdown 1aEz-EF_jwSCb1B2ATPWH-v97xlMyuM7g -O model_weights/large_unet-ema-soup.pth
print("Model weights downloaded successfully!")
Download Tomogram Data
Estimated download time: Up to 3 minutes
# Download one tomogram for testing
!gdown --folder 1HKCrVvuIvjLXBhNICW6FiSfR5cFTSbNH
print("Tomogram data downloaded successfully!")
Run Model Inference
Now we'll demonstrate how to run the ensemble model on cryo-electron tomography data. The process involves:
- Model Setup - Import libraries and setup model configuration and constants
- Model Loading - Load the ensemble of trained models
- Data Setup - Prepare your copick configuration and tomogram data
- Inference - Run sliding window inference with test-time augmentation (90 seconds per tomogram; a conceptual sketch follows this list)
- Post-processing - Apply watershed segmentation to extract particle locations
- Output - Generate CSV file with particle coordinates
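For context, here is a minimal single-model sketch of the inference step, using MONAI's sliding_window_inference with flip-based test-time augmentation and the same window settings the inference log shows below. The repo's infer_tomograph_locations_watershed_ensemble wraps this for the whole ensemble and adds watershed post-processing; this sketch is illustrative only.
import torch
from monai.inferers import sliding_window_inference

def predict_with_tta(model, volume):
    """Sliding-window prediction averaged over axis flips (conceptual sketch)."""
    x = torch.as_tensor(volume)[None, None].float().cuda()  # (batch, channel, D, H, W)
    preds = []
    for dims in [(), (2,), (3,), (4,)]:  # identity plus one flip per spatial axis
        xf = torch.flip(x, dims) if dims else x
        p = sliding_window_inference(
            xf, roi_size=(160, 384, 384), sw_batch_size=1, predictor=model,
            overlap=0.25, mode="gaussian", padding_mode="reflect",
        )
        preds.append(torch.flip(p, dims) if dims else p)  # undo the flip
    return torch.stack(preds).mean(0).softmax(dim=1)  # per-class probabilities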
Step 1: Setup Model Configuration and Constants
# Import inference functions and dependencies
import sys
sys.path.append('.')
# from github repo
from quick_start_aux_functions import load_models, infer_tomograph_locations_watershed_ensemble, ModelSpec, Model
# Import other necessary libraries
import copick
import pandas as pd
import torch
import numpy as np
import json
# Model configurations for different UNet architectures
model_config_tiny = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,  # 6 particle classes + 1 background channel
    "channels": (32, 64, 128, 128),
    "strides": (2, 2, 1),
    "num_res_units": 1,
}

model_config_medium = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,
    "channels": (32, 64, 128, 256),
    "strides": (2, 2, 1),
    "num_res_units": 2,
}

model_config_large = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 7,
    "channels": (32, 96, 256, 384),
    "strides": (2, 2, 1),
    "num_res_units": 2,
}

# The 6 particle classes (the 7th model output channel is background)
LABELS_7 = ["apo-ferritin", "beta-amylase", "beta-galactosidase", "ribosome", "thyroglobulin", "virus-like-particle"]

# Particle radii in Ångströms
PARTICLE_RADIUS_7 = {
    "apo-ferritin": 60,
    "beta-amylase": 65,
    "beta-galactosidase": 90,
    "ribosome": 150,
    "thyroglobulin": 130,
    "virus-like-particle": 135,
}
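These dictionaries match the keyword arguments of MONAI's UNet constructor, so each one can be instantiated directly. A quick sketch under that assumption (the repo's load_models handles model construction for you):
from monai.networks.nets import UNet

# Instantiate the smallest architecture and count its parameters
net = UNet(**model_config_tiny)
print(f"{sum(p.numel() for p in net.parameters()):,} parameters")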
Step 2: Load the Ensemble of Trained Models
model_specs = [
    ModelSpec("model_weights/tiny_unet_soup.pth", model_config_tiny, False),
    ModelSpec("model_weights/medium_unet_soup.pth", model_config_medium, False),
    ModelSpec("model_weights/large_unet-soup-folds-69-86-99.pth", model_config_large, True),
    ModelSpec("model_weights/large_unet-ema-soup.pth", model_config_large, True),
]
device_id = 0
models = load_models(model_specs, device_id)
Output:
can't find state_dict key, loading checkpoint directly
model does NOT uses ema, discarding ema_model
can't find state_dict key, loading checkpoint directly
model does NOT uses ema, discarding ema_model
can't find state_dict key, loading checkpoint directly
model does uses ema, swapping ema_model into model
can't find state_dict key, loading checkpoint directly
model does uses ema, swapping ema_model into model
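The log lines show how each checkpoint is resolved: all four are raw state dicts (no "state_dict" key), and the two large models carry exponential-moving-average (EMA) weights that get swapped in. A generic sketch of that pattern; this is illustrative, not the repo's exact code.
def resolve_state_dict(ckpt, use_ema):
    """Pick the weights to load from a checkpoint dict (illustrative)."""
    sd = ckpt.get("state_dict", ckpt)  # fall back to the raw checkpoint
    if use_ema and "ema_model" in ckpt:
        sd = ckpt["ema_model"]         # prefer the EMA weights when asked
    return sd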
Step 3: Create Copick Configuration File
# Create a copick configuration file for your tomogram data
copick_config = {
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",
    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": True,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": True,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153, 63, 0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": True,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [76, 0, 92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": True,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [0, 92, 49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": True,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [43, 206, 72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": True,
            "pdb_id": "6N4V",
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        }
    ],
    "overlay_root": "./data/overlay",
    "overlay_fs_args": {
        "auto_mkdir": True
    },
    "static_root": "./tomograms/test/static"
}

with open("test_copick.config", "w") as f:
    json.dump(copick_config, f, indent=2)
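To run on your own tomograms, the same configuration works with only the path fields changed (the paths below are placeholders):
# Point the config at your own data (placeholder paths)
copick_config["static_root"] = "/path/to/your/static"
copick_config["overlay_root"] = "/path/to/your/overlay"
with open("my_copick.config", "w") as f:
    json.dump(copick_config, f, indent=2)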
Step 4: Run Inference on Tomograms
Estimated inference time: 90 seconds per tomogram
# Run inference
root = copick.from_file("/content/czii-8th-solution/test_copick.config")
all_runs = root.runs
print(f"Found {len(all_runs)} tomogram runs to process")
voxel_size = 10
tomo_type = "denoised"
# This will process all runs and save results to submission.csv
print("Starting inference on all tomograms...")
# Run inference on all tomograms
locs = []
for i, run in enumerate(all_runs):
    print(f"Processing tomogram {i+1}/{len(all_runs)}: {run.name}")
    # Get the tomogram data at 10Å voxel spacing
    tomo = run.get_voxel_spacing(voxel_size)
    tomo_data = tomo.get_tomograms(tomo_type)[0].numpy()
    # Run ensemble inference with test-time augmentation and watershed post-processing
    loc_df = infer_tomograph_locations_watershed_ensemble(models, tomo_data, run.name)
    locs.append(loc_df)
    # Clear GPU cache to prevent memory issues
    torch.cuda.empty_cache()
    print(f"Completed inference on {run.name}")

# Combine all results and save to CSV
final_results = pd.concat(locs)
final_results.to_csv("submission.csv", index=False)
Output:
Found 1 tomogram runs to process
Starting inference on all tomograms...
Processing tomogram 1/1: TS_5_4
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
running TS_5_4 with {'roi_size': (160, 384, 384), 'sw_batch_size': 1, 'overlap': 0.25, 'mode': 'gaussian', 'padding_mode': 'reflect'}
6it [00:31, 5.27s/it]
Completed inference on TS_5_4
Explore Model Outputs
The ensemble model produces the following output:
Primary Output: CSV File
- File: `submission.csv`
- Columns:
  - `experiment`: Name of the tomogram run
  - `particle_type`: One of the 6 supported particle types
  - `x`, `y`, `z`: 3D coordinates in Ångströms
# Examine output
final_results
Output:
experiment particle_type x y z
0 TS_5_4 apo-ferritin 129.126002 4403.058563 380.127615
1 TS_5_4 apo-ferritin 355.100429 5900.885071 552.504864
2 TS_5_4 apo-ferritin 468.975093 5923.138979 598.022406
3 TS_5_4 apo-ferritin 519.783592 4278.294706 654.674033
4 TS_5_4 apo-ferritin 554.883747 4369.815789 598.542861
.. ... ... ... ... ...
216 TS_5_4 virus-like-particle 2631.654423 4222.250606 971.266489
217 TS_5_4 virus-like-particle 3003.754538 4954.748190 1172.217550
218 TS_5_4 virus-like-particle 3135.223016 3576.066420 376.307184
219 TS_5_4 virus-like-particle 3295.876607 3028.801345 676.558135
220 TS_5_4 virus-like-particle 3443.012295 6185.960700 1022.491603
[221 rows x 5 columns]
print(final_results['particle_type'].value_counts())
Output:
particle_type
thyroglobulin 78
apo-ferritin 56
ribosome 43
beta-galactosidase 33
virus-like-particle 11
Name: count, dtype: int64
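Coordinates are reported in Ångströms; to index into the 10 Å tomogram array (as the visualization below does), divide by the voxel size:
# Convert Ångström coordinates to voxel indices at 10 Å spacing
vox = final_results.copy()
vox[["x", "y", "z"]] = vox[["x", "y", "z"]] / voxel_size
vox.head()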
Visualize Particle Localization
colors = {
    'apo-ferritin': 'cyan',
    'beta-amylase': 'orange',
    'beta-galactosidase': 'red',
    'ribosome': 'blue',
    'thyroglobulin': 'green',
    'virus-like-particle': 'magenta'
}
import copick
import matplotlib.pyplot as plt
import json
import pandas as pd
import numpy as np
# Load tomogram using copick (same as in inference)
root = copick.from_file("/content/czii-8th-solution/test_copick.config")
# Get the single run directly
run_name = 'TS_5_4'
target_run = root.get_run(run_name)
# Get tomogram data using copick (same parameters as inference)
voxel_size = 10
tomo_type = "denoised"
tomo = target_run.get_voxel_spacing(voxel_size)
tomo_data = tomo.get_tomograms(tomo_type)[0].numpy()
slice_idx = 60
z_min = slice_idx - 5
z_max = slice_idx + 5
base_gt_path = '/content/czii-8th-solution/tomograms/train/overlay/ExperimentRuns/TS_5_4/Picks/'
particle_types = list(colors.keys())
print(f"Visualizing slice {slice_idx}, particles in z-range [{z_min}, {z_max})")
print(f"Tomogram shape: {tomo_data.shape}")
# Load ground truth coordinates
ground_truth_coords = {}
for pt in particle_types:
    try:
        with open(f'{base_gt_path}{pt}.json', 'r') as f:
            data = json.load(f)['points']
        x, y = [], []
        for p in data:
            # Convert Ångström coordinates to voxel indices (10 Å spacing)
            z = float(p['location']['z']) / 10
            if z_min <= z < z_max:
                x.append(float(p['location']['x']) / 10)
                y.append(float(p['location']['y']) / 10)
        ground_truth_coords[pt] = (x, y)
        print(f"Found {len(x)} {pt} GT points")
    except FileNotFoundError:
        print(f"{pt} ground truth not found")
# Load predictions
run_preds = final_results[final_results['experiment'] == run_name]
z_filtered_preds = run_preds[(run_preds['z']/10 >= z_min) & (run_preds['z']/10 < z_max)]
pred_coords = {}
for pt in particle_types:
    type_preds = z_filtered_preds[z_filtered_preds['particle_type'] == pt]
    pred_coords[pt] = ((type_preds['x'] / 10).tolist(), (type_preds['y'] / 10).tolist())
    print(f"Found {len(type_preds)} {pt} predictions")
# Get the slice from copick-loaded data
tomo_slice = tomo_data[slice_idx]
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
cmap_args = {'cmap':'gray', 'vmin':-0.00005, 'vmax':0.00005}
axes[0].imshow(tomo_slice, **cmap_args)
axes[0].set_title(f'Original Tomogram - Slice {slice_idx}')
axes[0].axis('off')
axes[1].imshow(tomo_slice, **cmap_args)
axes[1].set_title('Ground Truth Labels')
for pt, (x, y) in ground_truth_coords.items():
    if x and y:
        axes[1].scatter(x, y, c=colors[pt], s=30, alpha=0.8, marker='o',
                        edgecolors='white', linewidth=0.5, label=f'{pt} GT ({len(x)})')
axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
axes[1].axis('off')
axes[2].imshow(tomo_slice, **cmap_args)
axes[2].set_title('Model Predictions')
for pt, (x, y) in pred_coords.items():
    if x and y:
        axes[2].scatter(x, y, c=colors[pt], s=30, alpha=0.8, marker='x',
                        linewidth=2, label=f'{pt} Pred ({len(x)})')
axes[2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
axes[2].axis('off')
plt.tight_layout()
plt.show()
Output:
Visualizing slice 60, particles in z-range [55, 65)
Tomogram shape: (184, 630, 630)
apo-ferritin ground truth not found
beta-amylase ground truth not found
beta-galactosidase ground truth not found
ribosome ground truth not found
thyroglobulin ground truth not found
virus-like-particle ground truth not found
Found 11 apo-ferritin predictions
Found 0 beta-amylase predictions
Found 4 beta-galactosidase predictions
Found 1 ribosome predictions
Found 6 thyroglobulin predictions
Found 1 virus-like-particle predictions
(Figure: three panels showing the original tomogram slice, the ground-truth labels, and the model predictions overlaid as colored markers.)
Contact and Acknowledgments
For issues with this quickstart, please contact Sergio Alvarez at sasjsergioalvarezjunior@gmail.com.
Special thanks to Karyna Rosario Cora and Dannielle McCarthy for their consultation on this quickstart.
Responsible Use
We are committed to advancing the responsible development and use of artificial intelligence. Please follow our Acceptable Use Policy when engaging with our services.