A QuickStart Guide to Using Medical Imagery for Machine Learning

Applying machine learning to a sector as important as healthcare can have an enormous positive impact on the world.

In this article, we will look at how medical imaging datasets are structured, how to pre-process them and get them ready for machine learning models, and explore some models and techniques known to work well with such data.

Basic Terms & Dataset Overview

Medical imaging modalities such as CT and MRI capture the full 3D structure of a body part as a set of 2D images, or ‘slices’, which can be combined to reconstruct the volume. Medical datasets usually provide these slices in specialised image formats.

Two of the most basic terms one must know-

Study ID: A unique identifier for an imaging procedure performed on a patient or subject. Each study ID might have multiple series and series IDs associated with it.

Series ID: An identifier for a subset of images within a study, so that images can be organised further based on some common property [e.g. a common plane/view used (explained below)].
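As a minimal sketch of how these IDs appear in practice, the snippet below groups a folder of DICOM files by their study and series IDs using the pydicom library (covered in the File Formats section below); the folder name is a placeholder-

import os
from collections import defaultdict

import pydicom

# Map each (study ID, series ID) pair to its list of slice files;
# 'dicom_folder' is a hypothetical path
series_map = defaultdict(list)
for name in os.listdir("dicom_folder"):
    if not name.endswith(".dcm"):
        continue
    path = os.path.join("dicom_folder", name)
    # stop_before_pixels=True reads only the metadata, which is faster
    dcm = pydicom.dcmread(path, stop_before_pixels=True)
    series_map[(dcm.StudyInstanceUID, dcm.SeriesInstanceUID)].append(path)

for (study_id, series_id), files in series_map.items():
    print(f"Study {study_id}, Series {series_id}: {len(files)} slices")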

Planes/Views Used

These 2D slices can be acquired along three primary anatomical planes:

  • Axial (top to bottom)
  • Sagittal (left to right)
  • Coronal (front to back)

and hence the slices must be stacked along their respective axis to reconstruct the 3D volume, as the indexing sketch below illustrates.
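As a rough illustration, assuming the axial slices of a scan are stacked along the last axis of a NumPy array (real scans also carry orientation metadata that must be respected), slices in the other two planes can be recovered by indexing along different axes-

import numpy as np

# Hypothetical volume built by stacking 40 axial slices of shape 256x256
volume = np.random.rand(256, 256, 40)

axial_slice = volume[:, :, 20]      # fix the top-to-bottom axis
coronal_slice = volume[120, :, :]   # fix the front-to-back axis
sagittal_slice = volume[:, 120, :]  # fix the left-to-right axis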

Contrast Mechanisms

Medical scanners can adjust how different parts of the body appear in images by changing their brightness intensity or contrast to make certain areas stand out. These mechanisms, which decide which parts appear darker and which lighter, are standardised and given names (e.g. in MR scans, we have T1-weighted images, where fat appears bright, and T2-weighted images, where fluid appears bright).

File Formats

There are various specialised image formats that need to be dealt with when working with such data, like-

DICOM

DICOM stands for Digital Imaging and Communications in Medicine. It is a crucial standard in medical imaging and uses ‘.dcm’ as its extension. The format also carries a bundle of metadata describing the patient, the scanner used, the plane of the image, the scanning mechanism used, and so on.

The Python library ‘pydicom’ is used to perform operations on these files. Let us use it to read a DICOM file and access its metadata-

import pydicom

# Read DICOM file
dcm = pydicom.dcmread("sample_dicom.dcm")

# Access metadata
print(f"Patient Name: {dcm.PatientName}")
print(f"Modality: {dcm.Modality}")

We can also convert the DICOM to a PNG/JPG to view it easily-

import pydicom
import numpy as np
from PIL import Image

def dicom_to_image(dicom_file, output_file):
    # Read the DICOM file
    dicom = pydicom.dcmread(dicom_file)

    # Get pixel array
    pixel_array = dicom.pixel_array

    # Rescale pixel values to 0-255
    pixel_array = pixel_array - np.min(pixel_array)
    pixel_array = pixel_array / np.max(pixel_array)
    pixel_array = (pixel_array * 255).astype(np.uint8)

    # Create PIL Image
    image = Image.fromarray(pixel_array)

    # Save as PNG or JPG
    image.save(output_file)
    print(f"Image saved as {output_file}")

# Usage
dicom_file = "sample_dicom.dcm"
output_file = "output_image.png"  # or "output_image.jpg" for JPG
dicom_to_image(dicom_file, output_file)
NIfTI

The NIfTI (Neuroimaging Informatics Technology Initiative) file format is a standard way of storing brain imaging data. It was created to make it easier for researchers and doctors to share and analyse brain scans, although it is now commonly used for other scans as well. The format uses the extension ‘.nii’, or ‘.nii.gz’ when compressed.

It is also common practice to convert 2D DICOM slices into a single 3D NIfTI file, which can then be viewed using specialised software for this domain.

The Python library ‘NiBabel’ is used to deal with this format; let’s try it out-

import nibabel as nib

# Load NIfTI file
img = nib.load('sample_nifti.nii')

# Get image data as numpy array
data = img.get_fdata()

# Access header information
print(f"Image shape: {img.shape}")
print(f"Data type: {img.get_data_dtype()}")

# Get affine transformation matrix
affine = img.affine
print(f"Affine matrix:\n{affine}")

Let’s also try converting DICOM files of the same series into a NIfTI file so we can visualise it later using one of the specialised tools discussed below-

import os
import pydicom
import nibabel as nib
import numpy as np

def dicom_series_to_nifti(dicom_folder, output_file):
    # Read all DICOM files in the folder
    dicom_files = [pydicom.dcmread(os.path.join(dicom_folder, f))
                   for f in os.listdir(dicom_folder)
                   if f.endswith('.dcm')]

    # Sort files by Instance Number
    dicom_files.sort(key=lambda x: int(x.InstanceNumber))

    # Extract pixel data and create 3D volume
    pixel_arrays = [dcm.pixel_array for dcm in dicom_files]
    volume = np.stack(pixel_arrays, axis=-1)

    # Build a simple affine matrix from spacing, thickness and position
    first_slice = dicom_files[0]
    pixel_spacing = first_slice.PixelSpacing
    slice_thickness = first_slice.SliceThickness
    image_position = first_slice.ImagePositionPatient

    affine = np.eye(4)
    affine[0, 0] = pixel_spacing[0]
    affine[1, 1] = pixel_spacing[1]
    affine[2, 2] = slice_thickness
    affine[:3, 3] = image_position

    # Create NIfTI image and save
    nifti_image = nib.Nifti1Image(volume, affine)
    nib.save(nifti_image, output_file)

# Usage
dicom_folder = 'dicom_series_folder'
output_file = 'output.nii.gz'
dicom_series_to_nifti(dicom_folder, output_file)

These are the two most popular formats, but many others, such as the MetaImage Header Archive (.mha), are also in use.
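As a quick illustration, the SimpleITK library (which also appears later in this article) can read many such formats, including .mha, with the same call; the filename here is a placeholder-

import SimpleITK as sitk

# SimpleITK infers the format from the file extension
image = sitk.ReadImage("sample_scan.mha")
array = sitk.GetArrayFromImage(image)  # convert to a NumPy array for ML work
print(array.shape)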

Visualisation Software

Various specialised software tools exist specifically to deal with medical imagery, whether for simply viewing, segmenting, or analysing it.

Some popular tools include-

  • ITK-SNAP: A tool for segmenting structures in 3D medical images
  • 3D Slicer: An open-source platform for medical image informatics, image processing, and three-dimensional visualization
  • OsiriX: A DICOM viewer for medical imaging, popular among radiologists

Popular Tools, Models and Techniques

Study-Level Cropping

One of the most important things when training a model on medical data is to understand the exact problem/disease/condition we are dealing with and read up on it. Often, due to the nature of the problem, only a certain region-of-interest within the scans is of use to us.

Crops of these regions-of-interest from the base scans are called study-level crops. A model works much better when trained on these focused crops instead of the full scan.

Usually, automating study-level cropping over an entire dataset requires training a segmentation model that can segment the parts-of-interest from the scans, from which we can derive the bounding box we need, as sketched below.
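As a minimal sketch of that last step, assuming we already have a binary segmentation mask as a NumPy array aligned with the scan, the bounding box (and hence the crop) can be derived from the nonzero voxels-

import numpy as np

def crop_from_mask(scan, mask, margin=5):
    # Find the extent of the nonzero (segmented) voxels
    coords = np.argwhere(mask > 0)
    mins = np.maximum(coords.min(axis=0) - margin, 0)
    maxs = np.minimum(coords.max(axis=0) + margin + 1, mask.shape)
    # Cut the same box out of the original scan
    box = tuple(slice(lo, hi) for lo, hi in zip(mins, maxs))
    return scan[box]

# Usage with dummy data: a random 'scan' and a small cubic mask
scan = np.random.rand(64, 64, 64)
mask = np.zeros_like(scan)
mask[20:30, 25:35, 30:40] = 1
print(crop_from_mask(scan, mask).shape)  # (20, 20, 20)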

TotalSegmentator

TotalSegmentator is a tool, also available as a Python library, built on nnU-Net and trained on a large corpus of medical data. It allows users to automatically segment all major body structures in CT scans, and now MRI scans as well, with a single call.

It has various modes: the default modes can segment all supported structures, while specialised modes exist for particular categories of structures. Let’s use it to segment a few axial DICOM MRI slices of the lumbar spine that we have-

import os
import numpy as np
import nibabel as nib
import SimpleITK as sitk
import matplotlib.pyplot as plt
from totalsegmentator.python_api import totalsegmentator

def convert_dicom_to_nifti(dicom_directory, nifti_output_path):
    """
    Convert DICOM series to NIfTI format.

    Args:
    dicom_directory (str): Path to the directory containing DICOM files.
    nifti_output_path (str): Path where the NIfTI file will be saved.
    """
    reader = sitk.ImageSeriesReader()
    series_ids = reader.GetGDCMSeriesIDs(dicom_directory)

    if not series_ids:
        raise ValueError("No DICOM series found in the specified directory.")

    dicom_files = reader.GetGDCMSeriesFileNames(dicom_directory, series_ids[0])
    reader.SetFileNames(dicom_files)
    image = reader.Execute()

    sitk.WriteImage(image, nifti_output_path)
    print(f"Converted DICOM to NIfTI: {nifti_output_path}")

def segment_and_visualize(nifti_file, segmentation_output_dir, visualization_output_dir):
    """
    Perform segmentation on a NIfTI file and visualize the results.

    Args:
    nifti_file (str): Path to the input NIfTI file.
    segmentation_output_dir (str): Directory to store segmentation results.
    visualization_output_dir (str): Directory to store visualization images.
    """
    # Perform segmentation
    totalsegmentator(nifti_file, segmentation_output_dir, task="total_mr", verbose=True)

    # Create visualization directory
    os.makedirs(visualization_output_dir, exist_ok=True)

    # Visualize original slices
    visualize_slices(nifti_file,
                     output_path=os.path.join(visualization_output_dir, "original_slices.png"))

    # Visualize vertebrae segmentation
    vertebrae_segmentation = os.path.join(segmentation_output_dir, "vertebrae.nii.gz")
    visualize_slices(nifti_file,
                     segmentation_path=vertebrae_segmentation,
                     output_path=os.path.join(visualization_output_dir, "vertebrae_segmentation.png"))

def visualize_slices(image_path, segmentation_path=None, output_path=None, num_slices=9):
    """
    Visualize multiple slices of a 3D image with optional segmentation overlay.

    Args:
    image_path (str): Path to the NIfTI image file.
    segmentation_path (str, optional): Path to the segmentation NIfTI file.
    output_path (str, optional): Path to save the visualization.
    num_slices (int): Number of slices to visualize.
    """
    image_data = nib.load(image_path).get_fdata()

    if segmentation_path:
        segmentation_data = nib.load(segmentation_path).get_fdata()

    total_slices = image_data.shape[2]
    slice_indices = np.linspace(0, total_slices - 1, num_slices, dtype=int)

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    for i, ax in enumerate(axes.flat):
        if i < num_slices:
            slice_num = slice_indices[i]
            ax.imshow(image_data[:, :, slice_num].T, cmap='gray')
            if segmentation_path:
                ax.imshow(segmentation_data[:, :, slice_num].T, alpha=0.5, cmap='jet')
            ax.set_title(f'Slice {slice_num}')
            ax.axis('off')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        plt.close(fig)
        print(f"Saved visualization to {output_path}")
    else:
        plt.show()

# Main execution
if __name__ == "__main__":
    # Define paths
    dicom_directory = "sample_dicom_slices_folder"
    nifti_file = "mri_scan.nii.gz"
    segmentation_output_dir = "segmentation_results"
    visualization_output_dir = "visualizations"

    # Convert DICOM to NIfTI
    convert_dicom_to_nifti(dicom_directory, nifti_file)

    # Perform segmentation and visualization
    segment_and_visualize(nifti_file, segmentation_output_dir, visualization_output_dir)

Let us see the original slices as well as the segmented slices-

[Figure: Original slices]
[Figure: Segmented slices from TotalSegmentator]

MaxViT-UNet

This architecture uses MaxViT (Multi-Axis Vision Transformer) as its backbone, a model that combines the features of a ViT with those of a convolutional model. On top of this, the U-Net part allows for precise localisation, making the combination well suited to medical segmentation.

The research paper titled ‘MaxViT-UNet: Multi-Axis Attention for Medical Image Segmentation’ also demonstrates its prowess at the task.
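To make the hybrid idea concrete, here is a heavily simplified PyTorch sketch of a block that applies convolution for local features followed by self-attention for global context, the basic combination such architectures build on; this is an illustration of the idea only, not the paper’s actual architecture-

import torch
import torch.nn as nn

class HybridBlock(nn.Module):
    """Convolution for local features, then self-attention for global context."""
    def __init__(self, channels, heads=4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        # Flatten spatial positions into a token sequence for attention
        tokens = x.flatten(2).transpose(1, 2)           # (B, H*W, C)
        attended, _ = self.attn(tokens, tokens, tokens)
        tokens = self.norm(tokens + attended)           # residual + norm
        return tokens.transpose(1, 2).reshape(b, c, h, w)

# Usage: one block on a dummy feature map
block = HybridBlock(channels=32)
out = block(torch.randn(1, 32, 32, 32))
print(out.shape)  # torch.Size([1, 32, 32, 32])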

2.5D CNN

This architecture bridges the gap between 2D and 3D, but there is actually nothing ‘2.5D’ about the model itself.

What’s 2.5D is the data that we feed into it: a common practice is to stack adjacent slices, or to select ‘n’ evenly spaced slices from a series and stack them together, giving depth to the data and hence making it ‘3D’ without it actually being 3D.

If we are using study-level crops to create a 2.5D image and the crops also capture some noise, we can additionally stack a segmentation mask of the actual part-of-interest as an extra channel; this can make the model even better, as sketched below.
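Here is a minimal sketch of building such a 2.5D input, assuming the slices (and an optional mask) are available as 2D NumPy arrays; all names are placeholders-

import numpy as np

def make_25d_input(slices, center, n=5, mask=None):
    """Stack n adjacent slices (and optionally a mask) into a multi-channel image."""
    half = n // 2
    # Clamp indices at the volume boundaries
    idxs = [min(max(center + o, 0), len(slices) - 1) for o in range(-half, half + 1)]
    channels = [slices[i] for i in idxs]
    if mask is not None:
        channels.append(mask)  # extra channel marking the part-of-interest
    return np.stack(channels, axis=0)  # shape: (n [+1], H, W)

# Usage with dummy data: 20 slices of 128x128
volume = [np.random.rand(128, 128) for _ in range(20)]
x = make_25d_input(volume, center=10, n=5, mask=np.zeros((128, 128)))
print(x.shape)  # (6, 128, 128)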

Multi-View CNN

If our data provides us with different views, we can use all of them through a Multi-View CNN (MVCNN), which allows for multiple input streams. Each stream processes one view separately, so the information from every view is captured efficiently.

Once the data has passed through the different streams, the information is combined, or fused. For the fusion step we have many options to consider, each with its own strengths and weaknesses (e.g. if the different views have some sort of sequence that we would like to capture, we can use an LSTM or GRU layer; otherwise a simple concatenation works), as in the sketch below.
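Here is a rough PyTorch sketch of this pattern with two views and simple concatenation fusion; the layer sizes and view names are arbitrary placeholders-

import torch
import torch.nn as nn

class MultiViewCNN(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        # One small convolutional stream per view (here: axial and sagittal)
        def stream():
            return nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            )
        self.axial_stream = stream()
        self.sagittal_stream = stream()
        # Fusion by concatenation, followed by a classifier head
        self.head = nn.Linear(32 * 2, num_classes)

    def forward(self, axial, sagittal):
        fused = torch.cat([self.axial_stream(axial),
                           self.sagittal_stream(sagittal)], dim=1)
        return self.head(fused)

# Usage with dummy single-channel views
model = MultiViewCNN()
logits = model(torch.randn(4, 1, 128, 128), torch.randn(4, 1, 128, 128))
print(logits.shape)  # torch.Size([4, 2])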