Source code for unravel.register.reg

#!/usr/bin/env python3

"""
Use ``reg`` from UNRAVEL to register an average template brain/atlas to a resampled autofl brain. 

Prereqs: 
    ``reg_prep``, [``seg_copy_tifs``], & [``seg_brain_mask``]

Inputs:
    - template to register (e.g., gubra_template_CCFv3_30um.nii.gz, an iDISCO/LSFM template in CCFv3 space)
    - reg_inputs/autofl_50um_masked.nii.gz (from ``reg_prep``)
    - atlas/atlas_CCFv3_2020_30um.nii.gz (default; from Allen Brain Institute)

Outputs:
    - reg_outputs/autofl_50um_masked_fixed_reg_input.nii.gz (padded fixed image used for registration with ANTsPy)
    - reg_outputs/<atlas>_in_tissue_space.nii.gz (warped atlas to tissue space for checking reg)
    - transformation matrices and deformation fields in reg_outputs

Note:
    - Images in reg_inputs are not padded.
    - Images in reg_outputs have 15% padding.
    - ort_code is a 3 letter orientation code of the fixed image if not set in fixed_img (e.g., RAS)
    - Letter options: A/P=Anterior/Posterior, L/R=Left/Right, S/I=Superior/Inferior
    - The side of the brain at the positive direction of the x, y, and z axes determines the 3 letters (axis order xyz)

Next commands: 
    ``reg_check`` for assessing registration, ``vstats_prep`` for preparing voxel-wise stats inputs, or ``rstats`` for regional stats.

Usage for tissue registration:
------------------------------
    reg -m <path/template.nii.gz> -bc -sm 0.4 -ort <3 letter orientation code> [-m2 atlas/atlas_CCFv3_2020_30um.nii.gz] [-f reg_inputs/autofl_50um_masked.nii.gz] [-mas reg_inputs/autofl_50um_brain_mask.nii.gz] [-ro reg_outputs] [-d list of paths] [-p sample??] [-v]

Usage for atlas to atlas registration:
--------------------------------------
    reg -m <path/atlas1.nii.gz> -f <path/atlas2.nii.gz> -m2 <path/atlas2.nii.gz> [-d list of paths] [-p sample??] [-v]

Usage for template to template registration:
--------------------------------------------
    reg -m <path/template1.nii.gz> -f <path/template2.nii.gz> -m2 <path/template2.nii.gz> -inp linear [-d list of paths] [-p sample??] [-v]
"""

import os
import subprocess
import ants
import nibabel as nib
from ants import n4_bias_field_correction, registration
from pathlib import Path
import numpy as np
from rich import print
from rich.live import Live
from rich.traceback import install
from scipy.ndimage import gaussian_filter

from unravel.image_io.reorient_nii import reorient_nii
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.img_io import resolve_path
from unravel.core.img_tools import pad
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg, print_func_name_args_times, initialize_progress_bar, get_samples
from unravel.warp.warp import warp


def parse_args():
    """Define the command-line interface for ``reg`` and parse the arguments.

    The module docstring is passed to the parser (``docstring=__doc__``).

    Returns:
        The namespace returned by ``parser.parse_args()``, with attributes named after the
        long options (e.g., ``moving_img``, ``fixed_img``, ``mask``, ``reg_outputs``, ``bias_correct``,
        ``smooth``, ``ort_code``, ``moving_img2``, ``interpol``, ``init_time``, ``dirs``, ``pattern``, ``verbose``).
    """
    parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__)

    reqs = parser.add_argument_group('Required arguments')
    reqs.add_argument('-m', '--moving_img', help='path/moving_img.nii.gz (e.g., average template optimally matching tissue)', required=True, action=SM)

    opts = parser.add_argument_group('Optional arguments')
    opts.add_argument('-f', '--fixed_img', help='reg_inputs/autofl_50um_masked.nii.gz (from ``reg_prep``)', default="reg_inputs/autofl_50um_masked.nii.gz", action=SM)
    opts.add_argument('-mas', '--mask', help="Brain mask for bias correction. Default: reg_inputs/autofl_50um_brain_mask.nii.gz. or pass in None", default="reg_inputs/autofl_50um_brain_mask.nii.gz", action=SM)
    opts.add_argument('-ro', '--reg_outputs', help="Name of folder w/ outputs from ``reg`` (e.g., transforms). Default: reg_outputs", default="reg_outputs", action=SM)
    opts.add_argument('-bc', '--bias_correct', help='Perform N4 bias field correction. Default: False', action='store_true', default=False)
    opts.add_argument('-sm', '--smooth', help='Sigma value for smoothing the fixed image. Default: 0 for no smoothing. Use 0.4 for autofl', default=0, type=float, action=SM)
    opts.add_argument('-ort', '--ort_code', help='3 letter orientation code of fixed image if not set in fixed_img (e.g., RAS)', action=SM)
    opts.add_argument('-m2', '--moving_img2', help='path/atlas.nii.gz (outputs <reg_outputs>/<atlas>_in_tissue_space.nii.gz for checking reg; Default: atlas/atlas_CCFv3_2020_30um.nii.gz)', default='atlas/atlas_CCFv3_2020_30um.nii.gz', action=SM)
    # The backslash escapes "[" in the help text; it is doubled here so the literal has no invalid "\[" escape sequence
    opts.add_argument('-inp', '--interpol', help='Interpolation method for warping -m2 to padded fixed img space (nearestNeighbor, multiLabel \\[default], linear, bSpline)', default="multiLabel", action=SM)
    # Kept as a string because it is passed straight into the ``reg_affine_initializer`` command line
    opts.add_argument('-it', '--init_time', help='Time in seconds allowed for ``reg_affine_initializer`` to run. Default: 30', default='30', type=str, action=SM)

    general = parser.add_argument_group('General arguments')
    general.add_argument('-d', '--dirs', help='Paths to sample?? dirs and/or dirs containing them (space-separated) for batch processing. Default: current dir', nargs='*', default=None, action=SM)
    general.add_argument('-p', '--pattern', help='Pattern for directories to process. Default: sample??', default='sample??', action=SM)
    general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False)

    return parser.parse_args()
# TODO: Update padding/unpadding logic to allow for additional padding if needed.
@print_func_name_args_times()
def bias_correction(image_path, mask_path=None, shrink_factor=2, verbose=False):
    """Perform N4 bias field correction on a .nii.gz and return an ndarray.

    Args:
        image_path (str): Path to input image.nii.gz
        mask_path (str, optional): Path to mask image.nii.gz. If falsy, no mask is passed to N4.
        shrink_factor (int): Shrink factor for bias field correction
        verbose (bool): Print output

    Returns:
        The array returned by ``.numpy()`` on the bias-corrected ANTs image.
    """
    ants_img = ants.image_read(str(image_path))
    ants_mask = ants.image_read(str(mask_path)) if mask_path else None

    # Always forward shrink_factor and verbose so they are honored with or without a mask
    ants_img_corrected = n4_bias_field_correction(
        image=ants_img,
        mask=ants_mask,
        shrink_factor=shrink_factor,
        verbose=verbose,
    )

    return ants_img_corrected.numpy()
@log_command
def main():
    """Register the moving image (e.g., template) to the fixed image of each sample.

    For every sample directory found by ``get_samples`` this:
      1. Prepares the fixed image (optional bias correction, 15% padding, optional smoothing,
         optional orientation) and saves it in the reg_outputs folder.
      2. Runs ``reg_affine_initializer`` to create ANTsPy_init_tform.mat.
      3. Applies that matrix to the moving image (initial approximate alignment).
      4. Runs SyN registration with ANTsPy and saves the warped moving image.
      5. Warps ``-m2`` (e.g., the atlas) to the padded fixed image space for checking reg.

    Each step is skipped when its output file already exists. The program exits via
    ``sys.exit()`` if the fixed image or the moving image is missing.
    """
    import sys  # Function-scope import: sys is not imported at module level

    install()
    args = parse_args()
    Configuration.verbose = args.verbose
    verbose_start_msg()

    sample_paths = get_samples(args.dirs, args.pattern, args.verbose)

    progress, task_id = initialize_progress_bar(len(sample_paths), "[red]Processing samples...")
    with Live(progress):
        for sample_path in sample_paths:

            # Directory with outputs (e.g., transforms) from registration
            reg_outputs_path = resolve_path(sample_path, args.reg_outputs)
            reg_outputs_path.mkdir(parents=True, exist_ok=True)

            # Define inputs and outputs for the fixed image
            fixed_img_nii_path = resolve_path(sample_path, args.fixed_img)
            if not fixed_img_nii_path.exists():
                print(f"\n [red]The fixed image to be padded for registration ({fixed_img_nii_path}) does not exist. Exiting.\n")
                sys.exit()
            fixed_img_for_reg = str(Path(args.fixed_img).name).replace(".nii.gz", "_fixed_reg_input.nii.gz")
            fixed_img_for_reg_path = str(Path(reg_outputs_path, fixed_img_for_reg))

            # Preprocess the fixed image
            if not Path(fixed_img_for_reg_path).exists():
                fixed_img_nii = nib.load(fixed_img_nii_path)

                # Optionally perform bias correction on the fixed image (e.g., when it is an autofluorescence image)
                if args.bias_correct:
                    print('\n Bias correcting the registration input\n')
                    # The literal string "None" on the command line disables the mask
                    if args.mask != "None":
                        mask_path = str(resolve_path(sample_path, args.mask))
                    else:
                        mask_path = None
                    fixed_img = bias_correction(str(fixed_img_nii_path), mask_path=mask_path, shrink_factor=2, verbose=args.verbose)
                else:
                    fixed_img = fixed_img_nii.get_fdata(dtype=np.float32)

                # Pad the fixed image with 15% of voxels on all sides (keeps moving img in frame during initial alignment and avoids edge effects)
                print('\n Adding padding to the registration input\n')
                fixed_img = pad(fixed_img, pad_width=0.15)

                # Optionally smooth the fixed image (e.g., when it is an autofluorescence image)
                if args.smooth > 0:
                    print('\n Smoothing the registration input\n')
                    fixed_img = gaussian_filter(fixed_img, sigma=args.smooth)

                # Create NIfTI, set header info, and save the registration input (reference image)
                print('\n Setting header info for the registration input\n')
                fixed_img = fixed_img.astype(np.float32)  # Convert the fixed image to FLOAT32 for ANTsPy
                reg_inputs_fixed_img_nii = nib.Nifti1Image(fixed_img, fixed_img_nii.affine.copy(), fixed_img_nii.header)
                reg_inputs_fixed_img_nii.set_data_dtype(np.float32)

                # Set the orientation of the image (use if not already set correctly in the header; check with ``io_nii_info``)
                if args.ort_code:
                    reg_inputs_fixed_img_nii = reorient_nii(reg_inputs_fixed_img_nii, args.ort_code, zero_origin=True, apply=False, form_code=1)

                # Save the fixed input for registration
                nib.save(reg_inputs_fixed_img_nii, fixed_img_for_reg_path)

            # Generate the initial transform matrix for aligning the moving image to the fixed image
            init_tform_path = str(Path(reg_outputs_path, "ANTsPy_init_tform.mat"))
            if not Path(init_tform_path).exists():

                # Check if required files exist
                if not Path(fixed_img_for_reg_path).exists():
                    print(f"\n [red]The fixed image for registration ({fixed_img_for_reg_path})does not exist. Exiting.\n")
                    sys.exit()
                if not Path(args.moving_img).exists():
                    print(f"\n [red]The moving image for registration ({args.moving_img}) does not exist. Exiting.\n")
                    sys.exit()

                print('\n\n Generating the initial transform matrix for aligning the moving image (e.g., template) to the fixed image (e.g., tissue) \n')
                command = [
                    'reg_affine_initializer',
                    '-f', fixed_img_for_reg_path,
                    '-m', args.moving_img,
                    '-o', init_tform_path,
                    '-t', args.init_time,  # Time in seconds allowed for this step. Increase time out duration if needed.
                ]

                # Discard stderr from the initializer
                subprocess.run(command, stderr=subprocess.DEVNULL)

            # Reset per sample so images from a previous sample are never reused by mistake
            fixed_image = None
            transformed_image = None

            # Perform initial approximate alignment of the moving image to the fixed image
            init_align_out = str(Path(reg_outputs_path, str(Path(args.moving_img).name).replace(".nii.gz", "__initial_alignment_to_fixed_img.nii.gz")))
            if not Path(init_align_out).exists():
                print('\n Applying the initial transform matrix to aligning the moving image to the fixed image \n')
                fixed_image = ants.image_read(fixed_img_for_reg_path)
                moving_image = ants.image_read(args.moving_img)
                transformed_image = ants.apply_transforms(
                    fixed=fixed_image,
                    moving=moving_image,
                    transformlist=[init_tform_path]
                )
                # init_align_out already includes reg_outputs_path, so write to it directly (same path as the existence check)
                ants.image_write(transformed_image, init_align_out)

            # Define final output and skip processing if it exists
            output = str(Path(reg_outputs_path, str(Path(args.moving_img).name).replace(".nii.gz", "__warped_to_fixed_image.nii.gz")))
            if not Path(output).exists():

                # If the initial alignment was made in an earlier run, load the registration inputs from disk
                if transformed_image is None:
                    fixed_image = ants.image_read(fixed_img_for_reg_path)
                    transformed_image = ants.image_read(init_align_out)

                # Perform registration (reg is a dict with multiple outputs)
                print('\n Running registration \n')
                output_prefix = str(Path(reg_outputs_path, "ANTsPy_"))
                reg = ants.registration(
                    fixed=fixed_image,  # e.g., fixed autofluo image
                    moving=transformed_image,  # e.g., the initially aligned moving image (e.g., template)
                    type_of_transform='SyN',  # SyN = symmetric normalization
                    grad_step=0.1,  # Gradient step size
                    syn_metric='CC',  # Cross-correlation
                    syn_sampling=2,  # Corresponds to CC radius
                    reg_iterations=(100, 70, 50, 20),  # Convergence criteria
                    outprefix=output_prefix,
                    verbose=args.verbose
                )

                # Save the warped moving image output
                ants.image_write(reg['warpedmovout'], output)  # The interpolation method is not NN or multiLabel
                print(f"\nTransformed moving image saved to: \n{output}")

                # Save the warped fixed image output (optional)
                # warpedfixout = str(Path(reg_outputs_path, str(Path(args.fixed_img).name).replace(".nii.gz", "__warped_to_moving_image.nii.gz")))
                # ants.image_write(reg['warpedfixout'], warpedfixout)
                # print(f"\nTransformed fixed image saved to: \n{warpedfixout}")

            # Warp the atlas image to the tissue image for checking reg (naming prioritizes the common usage)
            warped_atlas = str(Path(reg_outputs_path, str(Path(args.moving_img2).name).replace(".nii.gz", "_in_tissue_space.nii.gz")))
            if not Path(warped_atlas).exists():
                print('\n Warping the atlas to padded fixed image space for checking reg: reg_outputs/<atlas>_in_tissue_space.nii.gz\n')
                warp(reg_outputs_path, args.moving_img2, fixed_img_for_reg_path, warped_atlas, inverse=False, interpol=args.interpol)

            progress.update(task_id, advance=1)

    verbose_end_msg()
# Run the command-line entry point only when this file is executed as a script
if __name__ == '__main__':
    main()