Source code for unravel.register.adjust_template

#!/usr/bin/env python3
"""
Use ``reg_adjust_template`` (``rat``) from UNRAVEL to adjust intensities in a template for a single set of atlas label IDs.

Note:
    - The boundary between the adjusted region and the rest of the image can be sharp or blurred by applying a Gaussian blur to a two-sided boundary band around the region.
    - The boundary band is computed as:
        inside band  = region - erosion(region, band_width)
        outside band = dilation(region, band_width) - region
        band = inside OR outside
    - We blur the full (modified) image, then replace only band voxels with blurred values.
    
Example of use (separate steps):
--------------------------------
# 1) Fiber tracts: scale by 0.85
reg_adjust_template -i template.nii.gz -a atlas.nii.gz -ids <FIBER_IDS...> --scale 0.85 -bb -o template_fibers.nii.gz

# 2) HPF: scale by 0.8 (using previous output)
reg_adjust_template -i template_fibers.nii.gz -a atlas.nii.gz -ids <HPF_IDS...> --scale 0.8 -bb -o template_fibers_hpf.nii.gz

# 3) DG/CAsp: scale by 2.0 (using previous output)
reg_adjust_template -i template_fibers_hpf.nii.gz -a atlas.nii.gz -ids <DG_CAsp_IDS...> --scale 2.0 -bb -o template_fibers_hpf_dg.nii.gz

Usage:
------
    reg_adjust_template -i path/image.nii.gz -a path/atlas_labels.nii.gz -ids <ID1> <ID2> ... -o path/output.nii.gz [--scale <factor> | --set <value>] [-bb -s <float> -bw <int>] [-dt <dtype>] [-v]
"""

from __future__ import annotations

import numpy as np
from scipy.ndimage import gaussian_filter, binary_erosion, binary_dilation
from rich.traceback import install
from rich import print

from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.img_io import load_nii, save_3D_img
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg


[docs] def parse_args(): p = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__) req = p.add_argument_group('Required arguments') req.add_argument('-i', '--input', help='Input image (.nii.gz)', required=True, action=SM) req.add_argument('-a', '--atlas', help='Atlas label image (.nii.gz) in same space', required=True, action=SM) req.add_argument('-ids', '--label_IDs', help='Atlas label IDs defining the region to modify.', required=True, nargs='*', type=int, action=SM) req.add_argument('-o', '--output', help='Output image (.nii.gz)', required=True, action=SM) # Exactly one of these must be provided op = p.add_argument_group('Operation (choose one)') op.add_argument('--scale', help='Multiply region voxels by this factor.', type=float, default=None, action=SM) op.add_argument('--set', help='Set region voxels to this absolute value.', type=float, default=None, action=SM) opt = p.add_argument_group('Optional arguments') opt.add_argument('-bb', '--band_blur', help='Blur only a two-sided boundary band (inside+outside) around the region after modifying it.', action='store_true', default=False) opt.add_argument('-s', '--sigma', help='Gaussian sigma (voxels) used when --band_blur is set. Default: 1.0', type=float, default=1.0, action=SM) opt.add_argument('-bw', '--band_width', help='Boundary band thickness in voxels (>=1). Default: 1', type=int, default=1, action=SM) opt.add_argument('-dt', '--dtype', help='Optional dtype for output (e.g., uint16). If omitted, preserves input dtype.', default=None, action=SM) gen = p.add_argument_group('General arguments') gen.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False) return p.parse_args()
[docs] def mask_from_ids(atlas: np.ndarray, ids: list[int]) -> np.ndarray: """Binary mask where atlas label is in ids.""" return np.isin(atlas, np.asarray(ids, dtype=atlas.dtype))
[docs] def boundary_band_mask(region_mask: np.ndarray, band_width: int = 1) -> np.ndarray: """ Two-sided boundary band: inside band = region - erode(region, n) outside band = dilate(region, n) - region band = inside OR outside """ if band_width < 1: raise ValueError("--band_width must be >= 1") er = region_mask.copy() dl = region_mask.copy() for _ in range(band_width): er = binary_erosion(er) dl = binary_dilation(dl) inside = region_mask & (~er) outside = dl & (~region_mask) return inside | outside
[docs] def safe_cast_like(x: np.ndarray, dtype: np.dtype) -> np.ndarray: """Cast with rounding/clipping if integer.""" dt = np.dtype(dtype) if np.issubdtype(dt, np.integer): info = np.iinfo(dt) x = np.rint(x) x = np.clip(x, info.min, info.max) return x.astype(dt) return x.astype(dt)
[docs] @log_command def main(): install() args = parse_args() Configuration.verbose = args.verbose verbose_start_msg() # Enforce exactly one operation if (args.scale is None) == (args.set is None): raise ValueError("Specify exactly one of --scale or --set") img = load_nii(args.input) atlas = load_nii(args.atlas) if img.shape != atlas.shape: raise ValueError(f"Shape mismatch: input {img.shape} vs atlas {atlas.shape}") in_dtype = img.dtype work = img.astype(np.float32, copy=True) reg_mask = mask_from_ids(atlas, args.label_IDs) nvox = int(reg_mask.sum()) if args.verbose: print(f"\nRegion voxels: {nvox}\n") # Modify region if args.scale is not None: if args.verbose: print(f"Scaling region by {args.scale}") work[reg_mask] *= float(args.scale) else: if args.verbose: print(f"Setting region to {args.set}") work[reg_mask] = float(args.set) # Optional two-sided band blur for that region if args.band_blur: band = boundary_band_mask(reg_mask, band_width=args.band_width) if args.verbose: print( f"Band voxels: {int(band.sum())} " f"(sigma={args.sigma}, band_width={args.band_width})" ) blurred = gaussian_filter(work, sigma=args.sigma) work[band] = blurred[band] out_dtype = np.dtype(args.dtype) if args.dtype else in_dtype out = safe_cast_like(work, out_dtype) save_3D_img(out, args.output, reference_img=args.input, verbose=args.verbose) verbose_end_msg()
if __name__ == "__main__": main()