#!/usr/bin/env python3
"""
Use ``cstats_validation`` from UNRAVEL to warp a cluster index from atlas space to tissue space, crop clusters, apply a segmentation mask, and quantify cell/label densities.
Prereqs:
- ``cstats_fdr`` to generate a cluster index in atlas space (a map of clusters of significant voxels)
- ``seg_ilastik`` to generate a segmentation mask in tissue space (e.g., to label c-Fos+ cells)
Inputs:
- path/rev_cluster_index.nii.gz to warp from atlas space (rev = reverse, i.e., cluster IDs are from large to small)
- rel_path/seg_img.nii.gz. 1st glob match processed
Outputs:
- ./sample??/clusters/<cluster_index_dir>/outer_bounds.txt
- ./sample??/clusters/<cluster_index_dir>/<args.density>_data.csv
- cluster_index_dir = Path(args.moving_img).name w/o "_rev_cluster_index" and ".nii.gz"
Note:
- For -s, if a dir name is provided, the command will load ./sample??/seg_dir/sample??_seg_dir.nii.gz.
- If a relative path is provided, the command will load the image at the specified path.
Next command:
``cstats_summary``
Usage:
------
cstats_validation -m <path/rev_cluster_index_to_warp_from_atlas_space.nii.gz> -s <rel_path/seg_img.nii.gz> [-de cell_density | label_density] [-o rel_path/cluster_data.csv] [-c 1 3 4] [optional output: -n rel_path/native_cluster_index.zarr] [-fri autofl_50um_masked_fixed_reg_input.nii.gz] [-inp nearestNeighbor] [-ro reg_outputs] [-r 50] [-md parameters/metadata.txt] [-zo 0] [-mi] [-cc 6] [-d list of paths] [-p sample??] [-v]
"""
import concurrent.futures
import cc3d
import numpy as np
import os
import pandas as pd
from pathlib import Path
from rich import print
from rich.live import Live
from rich.traceback import install
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.img_io import load_3D_img, load_image_metadata_from_txt, load_nii_subset, resolve_path
from unravel.core.img_tools import cluster_IDs
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg, initialize_progress_bar, get_samples, print_func_name_args_times
from unravel.warp.to_native import to_native
[docs]
def parse_args():
    """Build the CLI and parse sys.argv. Returns the parsed argparse Namespace."""
    parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__)

    # Required inputs: the atlas-space cluster index and the tissue-space segmentation
    required = parser.add_argument_group('Required arguments')
    required.add_argument('-m', '--moving_img', help='path/*_rev_cluster_index.nii.gz to warp from atlas space', required=True, action=SM)
    required.add_argument('-s', '--seg', help='rel_path/seg_img.nii.gz. 1st glob match processed', required=True, action=SM)

    # General optional outputs/selection
    optional = parser.add_argument_group('Optional args')
    optional.add_argument('-de', '--density', help='Density to measure: cell_density (default) or label_density', default='cell_density', choices=['cell_density', 'label_density'], action=SM)
    optional.add_argument('-o', '--output', help='rel_path/clusters_info.csv. Default: clusters/<cluster_index_dir>/cluster_data.csv', default=None, action=SM)
    optional.add_argument('-c', '--clusters', help='Clusters to process: all or list of clusters (e.g., 1 3 4). Processes all clusters by default', nargs='*', default='all', action=SM)

    # Arguments forwarded to to_native() for atlas-to-tissue warping
    native_opts = parser.add_argument_group('Optional args for to_native()')
    native_opts.add_argument('-n', '--native_idx', help='Load/save native cluster index from/to rel_path/native_image.zarr (fast) or rel_path/native_image.nii.gz if provided', default=None, action=SM)
    native_opts.add_argument('-fri', '--fixed_reg_in', help='Fixed input for registration (unravel.register.reg). Default: autofl_50um_masked_fixed_reg_input.nii.gz', default="autofl_50um_masked_fixed_reg_input.nii.gz", action=SM)
    native_opts.add_argument('-inp', '--interpol', help='Interpolator for ants.apply_transforms (nearestNeighbor [default], multiLabel [slow])', default="nearestNeighbor", action=SM)
    native_opts.add_argument('-ro', '--reg_outputs', help="Name of folder w/ outputs from unravel.register.reg (e.g., transforms). Default: reg_outputs", default="reg_outputs", action=SM)
    native_opts.add_argument('-r', '--reg_res', help='Resolution of registration inputs in microns. Default: 50', default='50', type=int, action=SM)
    native_opts.add_argument('-md', '--metadata', help='path/metadata.txt. Default: parameters/metadata.txt', default="parameters/metadata.txt", action=SM)
    native_opts.add_argument('-zo', '--zoom_order', help='SciPy zoom order for scaling to full res. Default: 0 (nearest-neighbor)', default='0', type=int, action=SM)

    # MIRACL compatibility (tif -> nii reorientation handling)
    compat = parser.add_argument_group('Compatability options for to_native()')
    compat.add_argument('-mi', '--miracl', help='Mode for compatibility (accounts for tif to nii reorienting)', action='store_true', default=False)

    # Connectivity forwarded to count_cells()
    cc_opts = parser.add_argument_group('Optional args for count_cells()')
    cc_opts.add_argument('-cc', '--connect', help='Connected component connectivity (6, 18, or 26). Default: 6', type=int, default=6, action=SM)

    # Batch-processing and verbosity controls
    general_opts = parser.add_argument_group('General arguments')
    general_opts.add_argument('-d', '--dirs', help='Paths to sample?? dirs and/or dirs containing them (space-separated) for batch processing. Default: current dir', nargs='*', default=None, action=SM)
    general_opts.add_argument('-p', '--pattern', help='Pattern for directories to process. Default: sample??', default='sample??', action=SM)
    general_opts.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False)

    return parser.parse_args()
# TODO: QC. Aggregate .csv results for all samples if args.dirs, script to load image subset.
# TODO: Make config file for defaults or a command_generator.py script
# TODO: Consider adding an option to quantify mean IF intensity in each cluster in segmented voxels. Also make a script for mean IF intensity in clusters in atlas space.
# TODO: Use glob for -s to load the first match. If no match, print a message and continue to the next sample. Afterwards, update in help: "For -s, if a dir name is provided, the command will load ./sample??/seg_dir/sample??_seg_dir.nii.gz."
# TODO: Consider removing the -o option. Have I used this so far? If not, remove it.
[docs]
@print_func_name_args_times()
def crop_outer_space(native_cluster_index, output_path):
    """Crop away the empty space surrounding all clusters.

    Writes the crop bounds to <output_path.parent>/outer_bounds.txt and returns
    (native_cluster_index_cropped, outer_xmin, outer_xmax, outer_ymin, outer_ymax,
    outer_zmin, outer_zmax).
    """
    bounds = []
    for axes in ((1, 2), (0, 2), (0, 1)):  # collapse the other two axes per dimension
        present = np.any(native_cluster_index, axis=axes)
        lo = np.argmax(present)  # first slice containing any cluster voxel
        hi = len(present) - np.argmax(present[::-1])  # one past the last such slice
        # NOTE(review): hi is already an exclusive bound covering the last cluster
        # slice; the extra +1 widens the crop by one slice. Kept to match the bounds
        # written to outer_bounds.txt and used for the matching seg crop downstream.
        bounds.append((lo, hi + 1))
    (outer_xmin, outer_xmax), (outer_ymin, outer_ymax), (outer_zmin, outer_zmax) = bounds

    # Crop the index to the computed bounding box
    native_cluster_index_cropped = native_cluster_index[outer_xmin:outer_xmax, outer_ymin:outer_ymax, outer_zmin:outer_zmax]

    # Persist the bounds so later steps (e.g., seg cropping) can reuse them
    with open(f"{output_path.parent}/outer_bounds.txt", "w") as file:
        file.write(f"{outer_xmin}:{outer_xmax}, {outer_ymin}:{outer_ymax}, {outer_zmin}:{outer_zmax}")

    return native_cluster_index_cropped, outer_xmin, outer_xmax, outer_ymin, outer_ymax, outer_zmin, outer_zmax
[docs]
def cluster_bbox(cluster_ID, native_cluster_index_cropped):
    """Compute the bounding box of one cluster.

    Returns (cluster_ID, xmin, xmax, ymin, ymax, zmin, zmax) where each max is an
    exclusive slice bound (one past the last slice containing the cluster).
    """
    mask = native_cluster_index_cropped == cluster_ID
    bounds = []
    for axes in ((1, 2), (0, 2), (0, 1)):  # x, y, z in turn
        hits = np.any(mask, axis=axes)
        bounds.append(np.argmax(hits))                       # first index with the cluster
        bounds.append(hits.size - np.argmax(hits[::-1]))     # one past the last index
    return (cluster_ID, *bounds)
[docs]
@print_func_name_args_times()
def cluster_bbox_parallel(native_cluster_index_cropped, clusters):
    """Compute bounding boxes for every cluster concurrently.

    Returns a list of cluster_bbox() results (completion order, not input order).
    """
    bboxes = []
    # Size the pool to the smaller of the core count and the number of clusters
    pool_size = min(os.cpu_count(), len(clusters))
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as pool:
        pending = {pool.submit(cluster_bbox, cid, native_cluster_index_cropped): cid for cid in clusters}
        for done in concurrent.futures.as_completed(pending):
            cid = pending[done]
            try:
                bboxes.append(done.result())
            except Exception as exc:
                # Report per-cluster failures without aborting the rest
                print(f'Cluster {cid} generated an exception: {exc}')
    return bboxes
[docs]
def count_cells(seg_in_cluster, connectivity=6):
    """Count cells (objects) in a cluster using connected-components-3d.

    Parameters:
        seg_in_cluster (ndarray): segmentation restricted to one cluster
        connectivity (int): connectivity for cc3d (6, 18, or 26)

    Returns:
        int: number of connected components (cells) in the cluster
    """
    # Convert big-endian data to native byte order (presumably required by cc3d —
    # confirm). NumPy 2.0 removed ndarray.newbyteorder(), so the old
    # byteswap().newbyteorder() idiom breaks; converting through the dtype works
    # on both NumPy 1.x and 2.x.
    if seg_in_cluster.dtype.byteorder == '>':
        seg_in_cluster = seg_in_cluster.astype(seg_in_cluster.dtype.newbyteorder('='))
    seg_in_cluster = seg_in_cluster.astype(np.uint8)

    # Label connected components; only the count is needed
    _, n = cc3d.connected_components(seg_in_cluster, connectivity=connectivity, out_dtype=np.uint32, return_N=True)
    return n
[docs]
def density_in_cluster(cluster_data, native_cluster_index_cropped, seg_cropped, xy_res, z_res, connectivity=6, density='cell_density'):
    """Measure cell count or volume of segmented voxels in the current cluster.

    Parameters:
        cluster_data (tuple): (cluster_ID, xmin, xmax, ymin, ymax, zmin, zmax) from cluster_bbox()
        native_cluster_index_cropped (ndarray): cluster index cropped to the outer bounds
        seg_cropped (ndarray): segmentation image cropped to the same outer bounds (not mutated)
        xy_res (float): in-plane voxel size in microns
        z_res (float): voxel depth in microns
        connectivity (int): cc3d connectivity for cell counting (6, 18, or 26)
        density (str): 'cell_density' (default) or 'label_density'
            (Fix: the old default 'cell_count' was not a valid choice and silently
            selected the label-density branch; default now matches the CLI default.)

    Returns:
        For cell densities: cluster_ID, cell_count, cluster_volume_in_cubic_mm, cell_density, xmin, xmax, ymin, ymax, zmin, zmax
        For label densities: cluster_ID, seg_volume_in_cubic_mm, cluster_volume_in_cubic_mm, label_density, xmin, xmax, ymin, ymax, zmin, zmax
    """
    cluster_ID, xmin, xmax, ymin, ymax, zmin, zmax = cluster_data

    # Crop the cluster's bounding box from the native cluster index
    cropped_cluster = native_cluster_index_cropped[xmin:xmax, ymin:ymax, zmin:zmax]

    # Copy the matching segmentation crop so the zeroing below cannot mutate
    # seg_cropped, which is shared across worker threads (the original zeroed a view)
    seg_in_cluster = seg_cropped[xmin:xmax, ymin:ymax, zmin:zmax].copy()

    # Zero out segmented voxels outside of any cluster
    # NOTE(review): voxels of *other* clusters overlapping this bbox survive this
    # mask and are included in the volume below — confirm this is intended
    seg_in_cluster[cropped_cluster == 0] = 0

    # Cluster volume in mm^3 (voxel volume in um^3 / 1e9)
    cluster_volume_in_cubic_mm = ((xy_res**2) * z_res) * np.count_nonzero(cropped_cluster) / 1e9

    if density == "cell_density":
        # Cells per mm^3 via connected-components counting
        cell_count = count_cells(seg_in_cluster, connectivity=connectivity)
        cell_density = cell_count / cluster_volume_in_cubic_mm
        return cluster_ID, cell_count, cluster_volume_in_cubic_mm, cell_density, xmin, xmax, ymin, ymax, zmin, zmax
    else:
        # Percent of cluster volume occupied by segmented voxels
        seg_volume_in_cubic_mm = ((xy_res**2) * z_res) * np.count_nonzero(seg_in_cluster) / 1e9
        label_density = seg_volume_in_cubic_mm / cluster_volume_in_cubic_mm * 100
        return cluster_ID, seg_volume_in_cubic_mm, cluster_volume_in_cubic_mm, label_density, xmin, xmax, ymin, ymax, zmin, zmax
[docs]
@print_func_name_args_times()
def density_in_cluster_parallel(cluster_bbox_results, native_cluster_index_cropped, seg_cropped, xy_res, z_res, connectivity=6, density='cell_count'):
    """Quantify cell count or segmented volume per cluster concurrently.

    Returns a list of density_in_cluster() results (completion order, not input order).
    """
    measurements = []
    # One worker per cluster, capped at the number of available cores
    pool_size = min(os.cpu_count(), len(cluster_bbox_results))
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Map each future back to its cluster_ID (element 0 of each bbox tuple)
        pending = {
            pool.submit(density_in_cluster, bbox, native_cluster_index_cropped, seg_cropped, xy_res, z_res, connectivity, density): bbox[0]
            for bbox in cluster_bbox_results
        }
        for done in concurrent.futures.as_completed(pending):
            cid = pending[done]
            try:
                measurements.append(done.result())
            except Exception as exc:
                # Report per-cluster failures without aborting the rest
                print(f'Cluster {cid} generated an exception: {exc}')
    return measurements
[docs]
@log_command
def main():
    """Validate clusters per sample: warp the cluster index to tissue space, crop,
    mask with the segmentation, quantify densities, and save a per-sample CSV."""
    install()
    args = parse_args()
    Configuration.verbose = args.verbose
    verbose_start_msg()

    sample_paths = get_samples(args.dirs, args.pattern, args.verbose)

    progress, task_id = initialize_progress_bar(len(sample_paths), "[red]Processing samples...")
    with Live(progress):
        for sample_path in sample_paths:

            # Define the final output and skip the sample if it already exists
            # NOTE(review): the module docstring says "_rev_cluster_index" is stripped,
            # but this strips "_rev_cluster_index_" (with trailing underscore) — confirm
            cluster_index_dir = str(Path(args.moving_img).name).replace(".nii.gz", "").replace("_rev_cluster_index_", "_")
            if args.output:
                output_path = resolve_path(sample_path, args.output)
            else:
                output_path = resolve_path(sample_path, Path("clusters", cluster_index_dir, f"{args.density}_data.csv"), make_parents=True)
            if output_path and output_path.exists():
                print(f"\n\n {output_path} already exists. Skipping.\n")
                continue

            # Load the atlas-space cluster index (used to enumerate cluster IDs)
            rev_cluster_index = load_3D_img(args.moving_img)

            # Resolve the optional native-index path relative to the sample?? folder
            native_idx_path = resolve_path(sample_path, args.native_idx) if args.native_idx else None

            # Load a previously saved native cluster index, or warp one to tissue space
            if native_idx_path and native_idx_path.exists():
                # Fix: pass the resolved path to load_3D_img (the original passed the
                # boolean returned by .exists()) and check the sample-resolved path,
                # which is where to_native() saves it (output=native_idx_path below)
                native_cluster_index = load_3D_img(native_idx_path)
            else:
                fixed_reg_input = Path(sample_path, args.reg_outputs, args.fixed_reg_in)
                if not fixed_reg_input.exists():
                    # Fall back to the unmasked registration input
                    fixed_reg_input = sample_path / args.reg_outputs / "autofl_50um_fixed_reg_input.nii.gz"
                native_cluster_index = to_native(sample_path, args.reg_outputs, fixed_reg_input, args.moving_img, args.metadata, args.reg_res, args.miracl, args.zoom_order, args.interpol, output=native_idx_path)

            # Determine which clusters to process
            if args.clusters == "all":
                clusters = cluster_IDs(rev_cluster_index)
            else:
                clusters = args.clusters
            clusters = [int(cluster) for cluster in clusters]

            # Crop outer space around all clusters (also writes outer_bounds.txt)
            native_cluster_index_cropped, outer_xmin, outer_xmax, outer_ymin, outer_ymax, outer_zmin, outer_zmax = crop_outer_space(native_cluster_index, output_path)

            # Load voxel resolutions from the metadata .txt
            metadata_path = resolve_path(sample_path, args.metadata)
            xy_res, z_res, _, _, _ = load_image_metadata_from_txt(metadata_path)
            if xy_res is None or z_res is None:
                print(" [red bold]./sample??/parameters/metadata.txt missing. cd to sample?? dir and run: io_metadata")
                continue  # Fix: without resolutions the density math below would fail

            # Get bounding boxes for each cluster in parallel
            cluster_bbox_data = cluster_bbox_parallel(native_cluster_index_cropped, clusters)

            # Load the segmentation image and crop it to the outer bounds of all clusters
            seg_path = next(sample_path.glob(str(args.seg)), None)
            if seg_path is None:
                print(f"\n [red bold]No files match the pattern {args.seg} in {sample_path}\n")
                continue
            seg_cropped = load_nii_subset(seg_path, outer_xmin, outer_xmax, outer_ymin, outer_ymax, outer_zmin, outer_zmax)

            # Count cells or measure segmented volume per cluster, in parallel
            cluster_data_results = density_in_cluster_parallel(cluster_bbox_data, native_cluster_index_cropped, seg_cropped, xy_res, z_res, args.connect, args.density)

            # Column names depend only on the chosen density metric (hoisted out of the loop)
            if args.density == "cell_density":
                count_or_vol_header, density_header = "cell_count", "cell_density"
            else:
                count_or_vol_header, density_header = "label_volume", "label_density"

            # Assemble one CSV row per cluster
            data_list = []
            for result in cluster_data_results:
                cluster_ID, cell_count_or_seg_vol, cluster_volume_in_cubic_mm, density_measure, xmin, xmax, ymin, ymax, zmin, zmax = result
                data_list.append({
                    "sample": sample_path.name,
                    "cluster_ID": cluster_ID,
                    count_or_vol_header: cell_count_or_seg_vol,
                    "cluster_volume": cluster_volume_in_cubic_mm,
                    density_header: density_measure,
                    "xmin": xmin, "xmax": xmax, "ymin": ymin, "ymax": ymax, "zmin": zmin, "zmax": zmax
                })

            # Save the rows sorted by cluster_ID in ascending order
            df = pd.DataFrame(data_list)
            df_sorted = df.sort_values(by='cluster_ID', ascending=True)
            df_sorted.to_csv(output_path, index=False)
            print(f"\n Output: [default bold]{output_path}")

            progress.update(task_id, advance=1)

    verbose_end_msg()
# Script entry point when run directly (not on import)
if __name__ == '__main__':
    main()