Semantic Segmentation Training
This page explains how to train CryoSiam's semantic segmentation model on your own annotated tomograms, either from scratch or starting from a pretrained model.
It follows the four-step CLI workflow:
- Ground-truth filtering
- Training data preprocessing
- Patch creation
- Model training
CryoSiam supports large-scale, multi-GPU semantic segmentation training with flexible augmentation, patch sampling, and dense output prediction.
1. Training Workflow Overview
Step 1. Filter Ground-Truth Labels
Ensures the labels are:
- cleaned,
- within valid ranges,
- aligned with the tomograms, and
- converted to the expected class indices.
Command:
cryosiam semantic_filter_ground_truth --config_file=configs/semantic_training.yaml
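The label checks above can be illustrated with a small, self-contained sketch. It is not CryoSiam's implementation: the helper names `filter_labels` and `check_shape_match`, the `class_map` argument, and the choice to send unmapped values to a background index of 0 are assumptions made for illustration. A real label volume would be a 3D array rather than the flat list used here.

```python
from typing import Dict, List, Sequence


def filter_labels(
    labels: Sequence[int],
    class_map: Dict[int, int],
    background: int = 0,
) -> List[int]:
    """Map raw label values to contiguous class indices.

    Hypothetical helper (not part of CryoSiam). Values missing from
    ``class_map`` (e.g. stray or out-of-range IDs) become ``background``.
    """
    return [class_map.get(value, background) for value in labels]


def check_shape_match(label_shape: Sequence[int], tomo_shape: Sequence[int]) -> None:
    """Raise if a label volume does not have the same shape as its tomogram."""
    if tuple(label_shape) != tuple(tomo_shape):
        raise ValueError(
            f"label shape {tuple(label_shape)} != tomogram shape {tuple(tomo_shape)}"
        )


# Raw annotation IDs 0, 7 and 12 become classes 0, 1 and 2; the stray 99 becomes background.
raw = [0, 7, 7, 12, 99, 0]
print(filter_labels(raw, {0: 0, 7: 1, 12: 2}))  # [0, 1, 1, 2, 0, 0]
check_shape_match((64, 128, 128), (64, 128, 128))  # shapes agree, so nothing is raised
```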
Step 2. Preprocess Training Data
Normalizes tomograms, aligns labels, optionally applies lamella masks, and prepares the dataset.
Command:
cryosiam semantic_train_preprocess --config_file=configs/semantic_training.yaml
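Normalization and lamella masking can be sketched as follows. This assumes per-tomogram z-score normalization, with statistics taken only inside the mask and voxels outside the mask set to zero. Both choices are illustrative assumptions, not a description of what `semantic_train_preprocess` actually does, and `normalize_tomogram` is a hypothetical name.

```python
import statistics
from typing import List, Optional, Sequence


def normalize_tomogram(
    voxels: Sequence[float],
    mask: Optional[Sequence[int]] = None,
) -> List[float]:
    """Z-score normalize a flattened tomogram (hypothetical helper, assumed scheme).

    If ``mask`` is given, mean and std are computed only from voxels where the
    mask is non-zero, and voxels outside the mask are set to 0.0.
    """
    if mask is not None and len(mask) != len(voxels):
        raise ValueError("mask and tomogram must have the same number of voxels")
    inside = [v for i, v in enumerate(voxels) if mask is None or mask[i]]
    mean = statistics.fmean(inside)  # raises if the mask selects no voxels
    std = statistics.pstdev(inside) or 1.0  # avoid division by zero on constant input
    out = [(v - mean) / std for v in voxels]
    if mask is not None:
        out = [value if mask[i] else 0.0 for i, value in enumerate(out)]
    return out


# The outlier 100.0 lies outside the (hypothetical) lamella mask, so it is
# ignored for the statistics and zeroed in the output.
print(normalize_tomogram([1.0, 3.0, 100.0], mask=[1, 1, 0]))  # [-1.0, 1.0, 0.0]
```

Computing the statistics inside the mask keeps empty regions outside the lamella from skewing the intensity scale of the sample itself.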
Step 3. Create Training Patches
CryoSiam splits full tomograms into overlapping 3D patches for efficient GPU training.
Command:
cryosiam semantic_train_create_patches --config_file=configs/semantic_training.yaml
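To see how `patch_size` and `patch_overlap` interact, the sketch below computes sliding-window start positions along one axis and the resulting patch count per tomogram. It assumes the stride is `patch_size * (1 - patch_overlap)` and that a final patch is shifted back to end exactly at the volume border; CryoSiam's own patch extraction may place patches differently. `patch_starts` and `patches_per_volume` are hypothetical names.

```python
import math
from typing import List, Sequence


def patch_starts(length: int, patch: int, overlap: float) -> List[int]:
    """Start indices of overlapping patches covering ``length`` voxels on one axis.

    Assumed scheme: stride = patch * (1 - overlap), last patch flush with the border.
    """
    if not 0.0 <= overlap < 1.0:
        raise ValueError("overlap must be in [0, 1)")
    if length <= patch:
        return [0]  # a single (possibly padded) patch covers the whole axis
    stride = max(1, int(patch * (1.0 - overlap)))
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)  # extra patch so the border is covered
    return starts


def patches_per_volume(
    shape: Sequence[int], patch: Sequence[int], overlap: float
) -> int:
    """Total number of 3D patches: product of the per-axis counts."""
    return math.prod(len(patch_starts(n, p, overlap)) for n, p in zip(shape, patch))


# 128^3 patches with 50% overlap on a hypothetical 200 x 512 x 512 tomogram
print(patch_starts(512, 128, 0.5))  # [0, 64, 128, 192, 256, 320, 384]
print(patches_per_volume((200, 512, 512), (128, 128, 128), 0.5))  # 147
```

Under this scheme, raising `patch_overlap` from 0.5 to 0.75 halves the stride to 32 voxels and nearly doubles the patch count along each axis (13 instead of 7 for a 512-voxel axis), so disk usage and epoch time grow several-fold in 3D.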
Step 4. Train the Model
Runs the full training pipeline using PyTorch Lightning.
Command:
cryosiam semantic_train --config_file=configs/semantic_training.yaml
Produces:
- trained model checkpoints (.ckpt)
- training logs
- validation reports
- TensorBoard files
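To locate the newest checkpoint (for example, to set `pretrained_model` or to inspect a finished run), a small helper can search `log_dir` recursively. The directory layout is an assumption: the example configuration points at a `.../version_1/model/last.ckpt` file, but the helper does not depend on that layout. `latest_checkpoint` is hypothetical and not part of the CryoSiam CLI.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file under ``log_dir``, or None.

    Hypothetical helper; it only relies on checkpoints having a .ckpt suffix.
    """
    checkpoints = list(Path(log_dir).rglob("*.ckpt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


# Example (placeholder path):
# print(latest_checkpoint("/path/to/log_dir"))
```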
2. Expected Folder Structure
dataset/
├── raw/ # original tomograms
├── labels/ # ground-truth semantic segmentations
├── processed/ # preprocessed tomograms & labels
├── patches/ # training patches created here
└── logs/ # training logs & checkpoints
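The layout above can be created in one go with `pathlib`. The sub-folder names come from the tree; the root path is a placeholder you would replace with your own, and `make_dataset_layout` is a hypothetical helper.

```python
from pathlib import Path
from typing import Dict

# Sub-folder names taken from the expected folder structure
SUBFOLDERS = ("raw", "labels", "processed", "patches", "logs")


def make_dataset_layout(root: str) -> Dict[str, Path]:
    """Create the expected sub-folders under ``root`` and return their paths.

    Hypothetical helper; safe to re-run because existing folders are kept.
    """
    paths = {name: Path(root) / name for name in SUBFOLDERS}
    for path in paths.values():
        path.mkdir(parents=True, exist_ok=True)
    return paths


# Example (placeholder root):
# layout = make_dataset_layout("/path/to/dataset")
# print(layout["patches"], layout["logs"])
```

Path keys in `semantic_training.yaml`, such as `patches_folder` and `log_dir`, can then point at the corresponding folders.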
3. Example Configuration (semantic_training.yaml)
data_folder: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_regression/predictions/'
labels_folder: '/g/mahamid/stojanov/simulated_datasets/simulated_data/tomograms/semantic_gt_for_training'
patches_folder: '/scratch/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes/patches_denoised_128'
temp_dir: '/scratch/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes'
log_dir: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes'
prediction_folder: '/g/zaugg/stojanov/simulated_datasets/final_models/dense_simsiam_semantic_complexes/predictions'
pretrained_model: '/g/zaugg/stojanov/simulated_datasets/experiments/dense_simsiam/version_1/model/last.ckpt'
file_extension: '.mrc'
train_files: ['sample_1.mrc', 'sample_2.mrc', 'sample_3.mrc', 'sample_4.mrc', 'sample_5.mrc', 'sample_6.mrc', 'sample_7.mrc', 'sample_8.mrc', 'sample_9.mrc', 'sample_10.mrc', 'sample_11.mrc', 'sample_12.mrc', 'sample_13.mrc', 'sample_14.mrc', 'sample_15.mrc', 'sample_16.mrc', 'sample_17.mrc', 'sample_18.mrc', 'sample_19.mrc', 'sample_20.mrc', 'sample_21.mrc', 'sample_22.mrc', 'sample_23.mrc', 'sample_24.mrc', 'sample_25.mrc', 'sample_26.mrc', 'sample_27.mrc', 'sample_28.mrc', 'sample_29.mrc', 'sample_30.mrc']
test_files: null
val_files: null
validation_ratio: 0.1
continue_training: False
parameters:
  nodes: 1
  gpu_devices: 8
  data:
    patch_size: [ 128, 128, 128 ]
    patch_overlap: 0.5
    min: 0
    max: 1
    mean: 0
    std: 1
  transforms:
    low_pass_sigma_range: [ 0.5, 2 ]
    high_pass_sigma_range: [ 0.1, 0.5 ]
    high_pass_sigma2_range: [ 4, 5 ]
    noise_sigma_range: [ 0.1, 0.5 ]
    combine_transforms: True
    use_noisy_input: False
    scale_intensity_factors: null
  network:
    in_channels: 1
    spatial_dims: 3
    out_channels: 89
    dense_dim: 64
    filters: [ 32, 64 ]
    kernel_size: 3
    padding: 1
    threshold: 0.5
    distance_prediction: True
    use_dice_loss: True
    unfreeze_decoder: True
    unfreeze_backbone: True
hyper_parameters:
  cache_rate: 0
  val_interval: 1
  batch_size: 3
  optimizer: 'adamw'
  lr: 0.001
  momentum: 0.9
  weight_decay: 0.00001
  max_epochs: 200
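With `val_files: null`, the `validation_ratio` key controls an automatic split. The sketch below assumes the split is done per tomogram with a seeded shuffle and keeps at least one validation file; CryoSiam may split differently (for example, at the patch level), and `split_files` and its `seed` argument are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_files(
    files: Sequence[str], validation_ratio: float, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Split tomogram filenames into (train, validation) lists.

    Hypothetical helper; assumes a per-tomogram split, not a per-patch one.
    """
    if not 0.0 < validation_ratio < 1.0:
        raise ValueError("validation_ratio must be in (0, 1)")
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # reproducible order for a given seed
    n_val = max(1, round(len(shuffled) * validation_ratio))
    return shuffled[n_val:], shuffled[:n_val]


# The 30 training tomograms from the example config with validation_ratio: 0.1
train_files = [f"sample_{i}.mrc" for i in range(1, 31)]
train, val = split_files(train_files, validation_ratio=0.1)
print(len(train), len(val))  # 27 3
```

Splitting whole tomograms, rather than individual patches, prevents overlapping patches from the same tomogram from ending up in both sets, which would inflate validation scores.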
⚙️ 4. Config Reference

🔹 Top-level Keys

| Key | Type | Must change? | Description |
|---|---|---|---|
| data_folder | str | ✓ | Tomograms or precomputed predictions used as input. |
| labels_folder | str | ✓ | Ground-truth semantic labels. |
| patches_folder | str | ✓ | Where training patches will be saved. |
| temp_dir | str | ✗ | Temporary folder for intermediate files. |
| log_dir | str | ✗ | Logs, checkpoints, and TensorBoard files. |
| prediction_folder | str | ✗ | Output folder for prediction checks. |
| pretrained_model | str | ✗ | Optional checkpoint to start training from. |
| file_extension | str | ✗ | Tomogram file format: .mrc or .rec. |
| train_files | list[str] | ✓/✗ | Explicit list of training tomograms. |
| validation_ratio | float | ✗ | Fraction of the data split off automatically for validation. |
| continue_training | bool | ✗ | Resume a previous training run. |

🔹 parameters

Hardware

| Key | Type | Description |
|---|---|---|
| nodes | int | Number of compute nodes. |
| gpu_devices | int/list[int] | GPUs per node. |

Data

| Key | Type | Description |
|---|---|---|
| patch_size | list[int] | 3D patch size in voxels. |
| patch_overlap | float | Fraction of overlap between neighboring patches. |
| min / max | float | Intensity clipping range. |
| mean / std | float | Normalization mean and standard deviation. |

Augmentations

| Transform | Purpose |
|---|---|
| low_pass_sigma_range | Gaussian blur (low-pass filtering). |
| high_pass_sigma_range | High-pass filtering. |
| noise_sigma_range | Additive Gaussian noise. |
| combine_transforms | Mix augmentations. |
| use_noisy_input | Train with the noisy tomogram as input. |

Network

| Key | Type | Purpose |
|---|---|---|
| in_channels | int | Number of input channels (usually 1). |
| spatial_dims | int | Always 3 (volumetric data). |
| out_channels | int | Number of semantic labels. |
| filters | list[int] | Number of filters per UNet level. |
| dense_dim | int | Bottleneck feature dimension. |
| distance_prediction | bool | Adds a distance-prediction head alongside the label head. |
| use_dice_loss | bool | Use Dice loss, which balances classes during segmentation training. |

🎯 Best Practices
- Use ≥4 GPUs for 128³ patches.
- Increase patch_overlap for smoother boundaries.
- Use augmentation to reduce overfitting.
- Always validate on unseen tomograms.
- Start with a pretrained model when possible.
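The GPU recommendation is easier to reason about with the per-step arithmetic. The sketch assumes data-parallel training (one process per GPU, as in PyTorch Lightning's DDP strategy), where `batch_size` applies to each GPU; check how your CryoSiam version distributes batches before relying on these numbers. Both helper names are hypothetical.

```python
from typing import Sequence


def effective_batch_size(batch_size: int, gpu_devices: int, nodes: int = 1) -> int:
    """Patches consumed per optimizer step, assuming batch_size is per GPU (DDP)."""
    return batch_size * gpu_devices * nodes


def patch_voxels(patch_size: Sequence[int]) -> int:
    """Number of voxels in one 3D patch."""
    z, y, x = patch_size
    return z * y * x


# Values from the example configuration
print(effective_batch_size(batch_size=3, gpu_devices=8, nodes=1))  # 24
print(patch_voxels([128, 128, 128]))  # 2097152
# One float32 input patch alone takes 2097152 * 4 bytes = 8 MiB, before any activations.
```

Under this assumption, four GPUs with `batch_size: 3` process 12 patches per step and eight GPUs process 24, while the memory needed on each GPU stays the same.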