Semantic Segmentation Training

This page explains how to train CryoSiam's semantic segmentation model from scratch using your own annotated tomograms.
It follows the four-step CLI workflow:

  1. Ground-truth filtering
  2. Training data preprocessing
  3. Patch creation
  4. Model training

CryoSiam supports large-scale, multi-GPU semantic segmentation training with flexible augmentation, patch sampling, and dense output prediction.


1. Training Workflow Overview

Step 1. Filter Ground-Truth Labels

Ensures the labels are:

  • cleaned
  • within valid ranges
  • aligned with the tomograms
  • converted to the expected class indices

Command:

cryosiam semantic_filter_ground_truth --config_file=configs/semantic_training.yaml
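To make "converted to the expected class indices" concrete, here is a minimal sketch of one way raw label values could be mapped to contiguous class indices. It is an illustration only, not CryoSiam's implementation: the function name `remap_labels`, its arguments, and the choice of 0 as the background index are assumptions.

```python
def remap_labels(labels, class_ids, background=0):
    """Map raw label values to contiguous class indices (hypothetical helper).

    labels: flat iterable of integer label values, one per voxel.
    class_ids: raw label values to keep; their order defines the new index.
    Any value not listed in class_ids is sent to `background`.
    """
    # Index 0 is reserved for background, so kept classes start at 1.
    lookup = {raw: idx for idx, raw in enumerate(class_ids, start=1)}
    return [lookup.get(value, background) for value in labels]


raw = [0, 5, 5, 12, 99, 7]
print(remap_labels(raw, class_ids=[5, 7, 12]))  # [0, 1, 1, 3, 0, 2]
```

The out-of-range value 99 falls back to background, which is one simple way to keep every label inside the range the network's output layer expects.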

Step 2. Preprocess Training Data

Normalizes tomograms, aligns labels, optionally applies lamella masks, and prepares the dataset.

Command:

cryosiam semantic_train_preprocess --config_file=configs/semantic_training.yaml
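As a rough illustration of the normalization part of this step, the sketch below clips intensities to a range and then standardizes them. It assumes the `min`/`max` and `mean`/`std` keys of the configuration act as clipping bounds and standardization parameters, as the config reference describes them; the function `normalize` is hypothetical and is not taken from CryoSiam.

```python
def normalize(values, vmin=0.0, vmax=1.0, mean=0.0, std=1.0):
    """Clip intensities to [vmin, vmax], then standardize (assumed order)."""
    # Clipping first keeps outlier voxels from dominating the value range.
    clipped = [min(max(v, vmin), vmax) for v in values]
    # Standardize with the configured mean and standard deviation.
    return [(v - mean) / std for v in clipped]


print(normalize([-0.5, 0.25, 2.0]))  # [0.0, 0.25, 1.0]
```

With the defaults from the example configuration (`min: 0`, `max: 1`, `mean: 0`, `std: 1`) the standardization is a no-op and only the clipping has an effect.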

Step 3. Create Training Patches

CryoSiam splits full tomograms into overlapping 3D patches for efficient GPU training.

Command:

cryosiam semantic_train_create_patches --config_file=configs/semantic_training.yaml
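The relationship between patch size, overlap, and the number of patches can be sketched as follows. This assumes a plain sliding window whose stride is `patch_size * (1 - patch_overlap)`, with a final patch clamped to the volume border; CryoSiam's actual sampling may differ, and the helper names are hypothetical.

```python
def patch_starts(dim, patch, overlap):
    """Start offsets of sliding-window patches along one axis."""
    if dim <= patch:
        return [0]  # volume is not larger than one patch along this axis
    stride = max(1, int(round(patch * (1 - overlap))))
    starts = list(range(0, dim - patch + 1, stride))
    # Add a last patch flush with the border if the stride does not reach it.
    if starts[-1] != dim - patch:
        starts.append(dim - patch)
    return starts


def count_patches(volume_shape, patch_size, overlap):
    """Total number of patches for a 3D volume."""
    total = 1
    for dim, patch in zip(volume_shape, patch_size):
        total *= len(patch_starts(dim, patch, overlap))
    return total


print(patch_starts(256, 128, 0.5))                           # [0, 64, 128]
print(count_patches((256, 512, 512), (128, 128, 128), 0.5))  # 147
```

Under these assumptions, an overlap of 0.5 with 128-voxel patches gives a stride of 64 voxels, so raising `patch_overlap` quickly increases the number of patches written to disk.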

Step 4. Train the Model

Runs the full training pipeline using PyTorch Lightning.

Command:

cryosiam semantic_train --config_file=configs/semantic_training.yaml

Produces:

  • trained model checkpoint (.ckpt)
  • training logs
  • validation reports
  • TensorBoard files
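The four commands above share one config file, so they can be chained from a small driver script. The sketch below is one possible way to do that; the subcommand names and the `--config_file` flag are the ones shown on this page, while `build_commands`, `run_workflow`, and the `dry_run` switch are hypothetical helpers.

```python
import subprocess

# Subcommands in the order given on this page.
STEPS = [
    "semantic_filter_ground_truth",
    "semantic_train_preprocess",
    "semantic_train_create_patches",
    "semantic_train",
]


def build_commands(config_file):
    """Return one argument list per workflow step."""
    return [["cryosiam", step, f"--config_file={config_file}"] for step in STEPS]


def run_workflow(config_file, dry_run=True):
    """Print each command and, unless dry_run is set, execute it."""
    for cmd in build_commands(config_file):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the chain as soon as one step fails.
            subprocess.run(cmd, check=True)


run_workflow("configs/semantic_training.yaml")  # dry run: only prints the commands
```

Call `run_workflow(..., dry_run=False)` to actually execute the steps; stopping at the first failure avoids training on incompletely prepared data.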

2. Expected Folder Structure

dataset/
├── raw/                 # original tomograms
├── labels/              # ground-truth semantic segmentations
├── processed/           # preprocessed tomograms & labels
├── patches/             # training patches created here
└── logs/                # training logs & checkpoints
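If the layout above does not exist yet, it can be created with a few lines of Python. This is a convenience sketch only: the folder names come from the tree above, and `create_layout` is a hypothetical helper, not part of CryoSiam.

```python
import tempfile
from pathlib import Path

SUBFOLDERS = ["raw", "labels", "processed", "patches", "logs"]


def create_layout(root):
    """Create the dataset subfolders under `root` and return their paths."""
    root = Path(root)
    paths = [root / name for name in SUBFOLDERS]
    for path in paths:
        # exist_ok=True makes the call safe to repeat on an existing dataset.
        path.mkdir(parents=True, exist_ok=True)
    return paths


# Demonstration in a throwaway directory; pass your real dataset path instead.
with tempfile.TemporaryDirectory() as tmp:
    for path in create_layout(Path(tmp) / "dataset"):
        print(path.name, path.is_dir())
```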

3. Example Configuration (semantic_training.yaml)

data_folder: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_regression/predictions/'
labels_folder: '/g/mahamid/stojanov/simulated_datasets/simulated_data/tomograms/semantic_gt_for_training'
patches_folder: '/scratch/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes/patches_denoised_128'
temp_dir: '/scratch/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes'
log_dir: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes'
prediction_folder: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes/predictions'
pretrained_model: '/g/zaugg/stojanov/simulated_datasets/experiments/dense_simsiam/version_1/model/last.ckpt'
file_extension: '.mrc'

train_files: ['sample_1.mrc', 'sample_2.mrc', 'sample_3.mrc', 'sample_4.mrc', 'sample_5.mrc',
              'sample_6.mrc', 'sample_7.mrc', 'sample_8.mrc', 'sample_9.mrc', 'sample_10.mrc',
              'sample_11.mrc', 'sample_12.mrc', 'sample_13.mrc', 'sample_14.mrc', 'sample_15.mrc',
              'sample_16.mrc', 'sample_17.mrc', 'sample_18.mrc', 'sample_19.mrc', 'sample_20.mrc',
              'sample_21.mrc', 'sample_22.mrc', 'sample_23.mrc', 'sample_24.mrc', 'sample_25.mrc',
              'sample_26.mrc', 'sample_27.mrc', 'sample_28.mrc', 'sample_29.mrc', 'sample_30.mrc']

test_files: null
val_files: null
validation_ratio: 0.1

continue_training: False

parameters:
  nodes: 1
  gpu_devices: 8
  data:
    patch_size: [ 128, 128, 128 ]
    patch_overlap: 0.5
    min: 0
    max: 1
    mean: 0
    std: 1
  transforms:
    low_pass_sigma_range: [ 0.5, 2 ]
    high_pass_sigma_range: [ 0.1, 0.5 ]
    high_pass_sigma2_range: [ 4, 5 ]
    noise_sigma_range: [ 0.1, 0.5 ]
    combine_transforms: True
    use_noisy_input: False
    scale_intensity_factors: null
  network:
    in_channels: 1
    spatial_dims: 3
    out_channels: 89
    dense_dim: 64
    filters: [ 32, 64 ]
    kernel_size: 3
    padding: 1
    threshold: 0.5
    distance_prediction: True
    use_dice_loss: True
    unfreeze_decoder: True
    unfreeze_backbone: True

hyper_parameters:
  cache_rate: 0
  val_interval: 1
  batch_size: 3
  optimizer: 'adamw'
  lr: 0.001
  momentum: 0.9
  weight_decay: 0.00001
  max_epochs: 200
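With `val_files: null` and `validation_ratio: 0.1`, a fraction of the training tomograms is held out for validation. The sketch below shows what such a split could look like for the 30 files listed above; how CryoSiam actually selects the validation files (rounding, shuffling, seeding) is not documented here, so `split_train_val` and its `seed` argument are assumptions.

```python
import random


def split_train_val(files, validation_ratio, seed=0):
    """Hold out a fraction of files for validation (hypothetical helper)."""
    files = list(files)
    if validation_ratio <= 0:
        return files, []
    # Keep at least one validation file when a ratio is requested.
    n_val = max(1, round(len(files) * validation_ratio))
    shuffled = files[:]
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible split
    return sorted(shuffled[n_val:]), sorted(shuffled[:n_val])


train_files = [f"sample_{i}.mrc" for i in range(1, 31)]
train, val = split_train_val(train_files, validation_ratio=0.1)
print(len(train), len(val))  # 27 3
```

Splitting by whole tomogram rather than by patch keeps overlapping patches of the same volume from appearing in both sets.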

⚙️ 4. Config Reference

🔹 Top-level Keys

| Key | Type | Must change? | Description |
|---|---|---|---|
| data_folder | str | ✓ | Tomograms or precomputed predictions used as input. |
| labels_folder | str | ✓ | Ground-truth semantic labels. |
| patches_folder | str | ✓ | Where training patches will be saved. |
| temp_dir | str | ✗ | Temporary folder for intermediate files. |
| log_dir | str | ✗ | Logs, checkpoints, TensorBoard. |
| prediction_folder | str | ✗ | Output folder for prediction checks. |
| pretrained_model | str | ✗ | Optional starting point for training. |
| file_extension | str | ✗ | .mrc or .rec file format. |
| train_files | list[str] | ✓/✗ | Explicit list of training tomograms. |
| validation_ratio | float | ✗ | Auto-split for validation. |
| continue_training | bool | ✗ | Resume previous training run. |

🔹 parameters

Hardware

| Key | Type | Description |
|---|---|---|
| nodes | int | Number of compute nodes. |
| gpu_devices | int/list[int] | GPUs per node. |

Data

| Key | Type | Description |
|---|---|---|
| patch_size | list[int] | 3D patch size |
| patch_overlap | float | Fraction of overlap |
| min/max | float | Intensity clipping |
| mean/std | float | Normalization |

Augmentations

| Transform | Purpose |
|---|---|
| low_pass_sigma_range | Gaussian blur |
| high_pass_sigma_range | High-pass filtering |
| noise_sigma_range | Add Gaussian noise |
| combine_transforms | Mix augmentations |
| use_noisy_input | Train with noisy tomogram |

Network

| Key | Type | Purpose |
|---|---|---|
| in_channels | int | Usually 1 |
| spatial_dims | int | Always 3 |
| out_channels | int | Number of semantic labels |
| filters | list | UNet filter sizes |
| dense_dim | int | Bottleneck |
| distance_prediction | bool | Multi-head distance + label |
| use_dice_loss | bool | Balanced segmentation loss |

🎯 Best Practices

  • Use ≥4 GPUs for 128³ patches
  • Increase patch_overlap for smoother boundaries
  • Use augmentation to reduce overfitting
  • Always validate on unseen tomograms
  • Start with a pretrained model when possible
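A back-of-the-envelope calculation helps explain the GPU recommendation for 128³ patches. The sketch below only counts the memory of a single dense float32 tensor and ignores intermediate activations, gradients, and optimizer state, so real usage will be considerably higher; `tensor_mib` is a hypothetical helper, and the numbers use `out_channels: 89` and `batch_size: 3` from the example configuration.

```python
from math import prod


def tensor_mib(spatial_shape, channels=1, batch_size=1, bytes_per_value=4):
    """Size in MiB of a dense tensor (float32 by default)."""
    n_values = batch_size * channels * prod(spatial_shape)
    return n_values * bytes_per_value / 2**20


patch = (128, 128, 128)
print(tensor_mib(patch))                              # 8.0    one input patch
print(tensor_mib(patch, channels=89))                 # 712.0  per-class output for one patch
print(tensor_mib(patch, channels=89, batch_size=3))   # 2136.0 per-class output for a batch of 3
```

The per-class output alone is roughly 2 GiB per batch under these assumptions, which is why large patches with many classes benefit from spreading the batch across several GPUs.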