Source code for unravel.cluster_stats.index

#!/usr/bin/env python3

"""
Use ``cstats_index`` from UNRAVEL to create a cluster index with valid clusters from a given NIfTI image.

Outputs:
    - path/valid_clusters/rev_cluster_index_valid_clusters.nii.gz
    - path/valid_clusters/cluster_``*``_sunburst.csv

Note:
    - Default csv: UNRAVEL/unravel/core/csvs/sunburst_IDPath_Abbrv.csv
    - CCFv3-2020_info.csv or CCFv3-2017_info.csv

Usage
-----
    cstats_index -ci path/rev_cluster_index.nii.gz -ids 1 2 3 [-a atlas/atlas_CCFv3_2020_30um.nii.gz] [-vcd valid_clusters_dir] [-rgb] [-scsv sunburst_IDPath_Abbrv.csv] [-in CCFv3-2020_info.csv] [-v]
"""

from pathlib import Path
import nibabel as nib
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from rich import print
from rich.traceback import install

from unravel.cluster_stats.sunburst import sunburst
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg


[docs] def parse_args(): parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__) reqs = parser.add_argument_group('Required arguments') reqs.add_argument('-ci', '--cluster_idx', help='Path to the reverse cluster index NIfTI file.', required=True, action=SM) reqs.add_argument('-ids', '--valid_cluster_ids', help='Space-separated list of valid cluster IDs.', nargs='*', type=int, required=True, action=SM) opts = parser.add_argument_group('Optional args') opts.add_argument('-vcd', '--valid_clusters_dir', help='path/name_of_the_output_directory. Default: valid_clusters', default='_valid_clusters', action=SM) opts.add_argument('-a', '--atlas', help='path/atlas.nii.gz. Default: atlas/atlas_CCFv3_2020_30um.nii.gz', default='atlas/atlas_CCFv3_2020_30um.nii.gz', action=SM) opts.add_argument('-rgb', '--output_rgb_lut', help='Output sunburst_RGBs.csv if flag provided (for Allen brain atlas coloring)', action='store_true') opts.add_argument('-scsv', '--sunburst_csv', help='CSV name or path/name.csv. Default: sunburst_IDPath_Abbrv.csv', default='sunburst_IDPath_Abbrv.csv', action=SM) opts.add_argument('-in', '--info', help='CSV name or path/name.csv. Default: CCFv3-2020_info.csv', default='CCFv3-2020_info.csv', action=SM) general = parser.add_argument_group('General arguments') general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False) return parser.parse_args()
# TODO: Look into consolidating csvs
[docs] def generate_sunburst(cluster, img, atlas, xyz_res_in_um, data_type, output_dir, sunburst_csv_path, info_csv_path, output_rgb_lut): """Generate a sunburst plot for a given cluster. Args: - cluster (int): the cluster ID. - img (ndarray): the input image ndarray. - atlas (ndarray): the atlas ndarray. - atlas_res_in_um (tuple): the atlas resolution in microns. For example, (25, 25, 25) - data_type (type): the data type of the image. - output_dir (Path): the output directory. """ mask = (img == cluster) if np.any(mask): cluster_image = np.where(mask, cluster, 0).astype(data_type) cluster_sunburst_path = output_dir / f'cluster_{cluster}_sunburst.csv' sunburst_df = sunburst(cluster_image, atlas, xyz_res_in_um, cluster_sunburst_path, sunburst_csv_path, info_csv_path, output_rgb_lut)
[docs] @log_command def main(): install() args = parse_args() Configuration.verbose = args.verbose verbose_start_msg() output_dir = Path(args.valid_clusters_dir) output_dir.mkdir(exist_ok=True, parents=True) output_image_path = output_dir / str(Path(args.cluster_idx).name).replace('.nii.gz', f'_{output_dir.name}.nii.gz') if output_image_path.exists(): print(f"\n {output_image_path.name} already exists. Skipping.") return # Load the cluster index and set the data type nii = nib.load(args.cluster_idx) img = np.asanyarray(nii.dataobj, dtype=nii.header.get_data_dtype()).squeeze() max_cluster_id = int(img.max()) data_type = np.uint16 if max_cluster_id >= 256 else np.uint8 img = img.astype(data_type) # Load the atlas and get the resolution in microns atlas_nii = nib.load(args.atlas) atlas = np.asanyarray(atlas_nii.dataobj, dtype=atlas_nii.header.get_data_dtype()).squeeze() atlas_res = atlas_nii.header.get_zooms() # (x, y, z) in mm xyz_res_in_um = atlas_res[0] * 1000 # Write valid cluster indices to a file with open(output_dir / 'valid_clusters.txt', 'w') as file: file.write(' '.join(map(str, args.valid_cluster_ids))) # Generate the valid cluster index valid_cluster_index = np.zeros_like(img, dtype=data_type) for cluster in args.valid_cluster_ids: valid_cluster_index = np.where(img == cluster, cluster, valid_cluster_index) # Parallel processing of sunburst plots with ThreadPoolExecutor() as executor: futures = [executor.submit(generate_sunburst, cluster, img, atlas, xyz_res_in_um, data_type, output_dir, args.sunburst_csv, args.info, args.output_rgb_lut) for cluster in args.valid_cluster_ids] for future in futures: future.result() # Wait for all threads to complete print(f' Saved valid cluster index: {output_image_path}') nib.save(nib.Nifti1Image(valid_cluster_index, nii.affine, nii.header), output_image_path) # Generate the sunburst plot for the valid cluster index sunburst_df = sunburst(valid_cluster_index, atlas, xyz_res_in_um, output_dir / 'valid_clusters_sunburst.csv', args.sunburst_csv, args.info, args.output_rgb_lut) print(sunburst_df) verbose_end_msg()
if __name__ == '__main__': main()