Source code for unravel.voxel_stats.apply_mask

#!/usr/bin/env python3

""" 
Use ``vstats_apply_mask`` from UNRAVEL to zeros out voxels in image based on a mask and direction args.

Usage:
------
    vstats_apply_mask -i input_image.nii.gz -mas mask.nii.gz [-dil 0] [--mean] [-tmas brain_mask.nii.gz] [-omas other_mask.nii.gz] [-di less | greater] [-o output_image.nii.gz] [-md parameters/metadata.txt] [--reg_res 50] [-mi] [-d list of paths] [-p sample??] [-v]

Usage to zero out voxels in image where mask > 0 (e.g., to exclude voxels representing artifacts):
--------------------------------------------------------------------------------------------------
    vstats_apply_mask -mas 6e10_seg_ilastik_2/sample??_6e10_seg_ilastik_2.nii.gz -i 6e10_rb20 -o 6e10_rb20_wo_artifacts -di greater 

Usage to zero out voxels in image where mask < 1 (e.g., to preserve signal from segmented microglia clusters):
--------------------------------------------------------------------------------------------------------------
    vstats_apply_mask -mas iba1_seg_ilastik_2/sample??_iba1_seg_ilastik_2.nii.gz -i iba1_rb20 -o iba1_rb20_clusters

Usage to replace voxels in image with the mean intensity in the brain where mask > 0:
-------------------------------------------------------------------------------------
    vstats_apply_mask -mas FOS_seg_ilastik/FOS_seg_ilastik_2.nii.gz -i FOS -o FOS_wo_halo.zarr -di greater -m
"""

import nibabel as nib
import numpy as np
from pathlib import Path
from rich import print
from rich.live import Live
from rich.traceback import install
from scipy.ndimage import binary_dilation, zoom

from unravel.register.reg_prep import reg_prep
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM

from unravel.core.config import Configuration 
from unravel.core.img_io import load_3D_img, load_image_metadata_from_txt, resolve_path, save_as_tifs, save_as_nii, save_as_zarr
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg, print_func_name_args_times, initialize_progress_bar, get_samples


[docs] def parse_args(): parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__) reqs = parser.add_argument_group('Required arguments') reqs.add_argument('-i', '--input', help='Image input path relative to ./ or ./sample??/', required=True, action=SM) reqs.add_argument('-mas', '--seg_mask', help='rel_path/mask_to_apply.nii.gz (in full res tissue space)', required=True, action=SM) opts = parser.add_argument_group('Optional arguments') opts.add_argument("-dil", "--dilation", help="Number of dilation iterations to perform on full res seg_mask (slow but precise). Default: 0", default=0, type=int, action=SM) opts.add_argument('-m', '--mean', help='If provided, conditionally replace values w/ the mean intensity in the brain', action='store_true', default=False) opts.add_argument('-tmas', '--tissue_mask', help='For the mean itensity. rel_path/brain_mask.nii.gz. Default: reg_inputs/autofl_50um_brain_mask.nii.gz', default="reg_inputs/autofl_50um_brain_mask.nii.gz", action=SM) opts.add_argument('-omas', '--other_mask', help='For restricting application of -mas. E.g., reg_inputs/autofl_50um_brain_mask_outline.nii.gz (from ./UNRAVEL/_other/uncommon_scripts/brain_mask_outline.py)', default=None, action=SM) opts.add_argument('-di', '--direction', help='"greater" to zero out where mask > 0, "less" (default) to zero out where mask < 1', default='less', choices=['greater', 'less'], action=SM) opts.add_argument('-o', '--output', help='Image output path relative to ./ or ./sample??/', action=SM) opts.add_argument('-md', '--metadata', help='path/metadata.txt. Default: parameters/metadata.txt', default="parameters/metadata.txt", action=SM) opts.add_argument('-r', '--reg_res', help='Resample input to this res in microns for ``reg``. Default: 50', default=50, type=int, action=SM) compatability = parser.add_argument_group('Compatability options') compatability.add_argument('-mi', '--miracl', help="Include reorientation step to mimic MIRACL's tif to .nii.gz conversion", action='store_true', default=False) general = parser.add_argument_group('General arguments') general.add_argument('-d', '--dirs', help='Paths to sample?? dirs and/or dirs containing them (space-separated) for batch processing. Default: current dir', nargs='*', default=None, action=SM) general.add_argument('-p', '--pattern', help='Pattern for directories to process. Default: sample??', default='sample??', action=SM) general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False) return parser.parse_args()
[docs] @print_func_name_args_times() def load_mask(mask_path): """Load .nii.gz and return to an ndarray with a binary dtype""" mask_nii = nib.load(mask_path) return np.asanyarray(mask_nii.dataobj, dtype=np.bool_).squeeze()
[docs] @print_func_name_args_times() def mean_intensity_in_brain(img, tissue_mask): """Z-score the image using the mask. Args: - img (str): the ndarray to be z-scored. - mask (str): the brain mask ndarray""" # Zero out voxels outside the mask masked_data = img * tissue_mask # Calculate mean for masked data masked_nonzero = masked_data[masked_data != 0] # Exclude zero voxels and flatten the array (1D) mean_intensity = masked_nonzero.mean() return mean_intensity
[docs] @print_func_name_args_times() def dilate_mask(mask, iterations): """Dilate the given mask (ndarray) by a specified number of iterations.""" dilated_mask = binary_dilation(mask, iterations=iterations) return dilated_mask
[docs] @print_func_name_args_times() def scale_bool_to_full_res(ndarray, full_res_dims): """Scale ndarray to match x, y, z dimensions provided. Uses nearest-neighbor interpolation by default to preserve a binary data type.""" zoom_factors = (full_res_dims[0] / ndarray.shape[0], full_res_dims[1] / ndarray.shape[1], full_res_dims[2] / ndarray.shape[2]) return zoom(ndarray, zoom_factors, order=0).astype(np.bool_)
[docs] @print_func_name_args_times() def apply_mask_to_ndarray(ndarray, mask_ndarray, other_mask=None, mask_condition='less', new_value=0): """Replace voxels in the ndarray with a new_value based on mask conditions. Optionally use a second mask to restrict application spatially.""" if mask_ndarray.shape != ndarray.shape: raise ValueError("Primary mask and input image must have the same shape") if other_mask is not None and other_mask.shape != ndarray.shape: raise ValueError("Other mask and input image must have the same shape") # Combine masks if other_mask is provided, using logical AND (both masks need to be True) if other_mask is not None: mask_ndarray = np.logical_and(mask_ndarray, other_mask) # Both masks must be True to remain True # Apply the combined mask to the ndarray if mask_condition == 'greater': ndarray[mask_ndarray] = new_value # mask_ndarray already represents where mask is True elif mask_condition == 'less': ndarray[~mask_ndarray] = new_value # Use logical NOT to flip True/False return ndarray
[docs] @log_command def main(): install() args = parse_args() Configuration.verbose = args.verbose verbose_start_msg() sample_paths = get_samples(args.dirs, args.pattern, args.verbose) progress, task_id = initialize_progress_bar(len(sample_paths), "[red]Processing samples...") with Live(progress): for sample_path in sample_paths: # Define output output = resolve_path(sample_path, args.output, make_parents=True) if output.exists(): print(f"\n\n {output.name} already exists. Skipping.\n") continue # Load image img = load_3D_img(sample_path / args.input, return_res=False) # Load metadata metadata_path = sample_path / args.metadata xy_res, z_res, _, _, _ = load_image_metadata_from_txt(metadata_path) if xy_res is None: print(" [red1]./sample??/parameters/metadata.txt is missing. Generate w/ io_metadata") import sys ; sys.exit() # Resample to registration resolution to get the mean intensity in the brain img_resampled = reg_prep(img, xy_res, z_res, args.reg_res, int(1), args.miracl) # Load 50 um tissue mask tissue_mask_img = load_3D_img(sample_path / args.tissue_mask) # Calculate mean intensity in brain if args.mean: mean_intensity = mean_intensity_in_brain(img_resampled, tissue_mask_img) # Check if "sample??_" is in the mask path and replace it with the actual sample name if f"{args.pattern}_" in args.seg_mask: dynamic_mask_path = args.seg_mask.replace(f"{args.pattern}_", f"{sample_path.name}_") else: dynamic_mask_path = args.seg_mask # Load full res mask with the updated or original path mask = load_mask(sample_path / dynamic_mask_path) # Dilate the primary mask if args.dilation > 0: mask = dilate_mask(mask, args.dilation) # Load the other mask and scale to full resolution if args.other_mask: other_mask_img = load_mask(sample_path / args.other_mask) metadata_path = sample_path / args.metadata xy_res, z_res, x_dim, y_dim, z_dim = load_image_metadata_from_txt(metadata_path) original_dimensions = np.array([x_dim, y_dim, z_dim]) other_mask_img = scale_bool_to_full_res(other_mask_img, original_dimensions).astype(np.bool_) # Apply mask to image if args.mean: masked_img = apply_mask_to_ndarray(img, mask, other_mask=other_mask_img, mask_condition=args.direction, new_value=mean_intensity) else: masked_img = apply_mask_to_ndarray(img, mask, other_mask=other_mask_img, mask_condition=args.direction, new_value=0) # Save masked image output.parent.mkdir(parents=True, exist_ok=True) if str(output).endswith(".zarr"): save_as_zarr(masked_img, output) elif str(output).endswith('.nii.gz'): save_as_nii(masked_img, output, xy_res, z_res, img.dtype) else: output.mkdir(parents=True, exist_ok=True) save_as_tifs(masked_img, output, "xyz") progress.update(task_id, advance=1) verbose_end_msg()
if __name__ == '__main__': main()