Source code for unravel.image_tools.spatial_averaging

#!/usr/bin/env python3

"""
Use ``img_spatial_avg`` (``spatial_avg``) from UNRAVEL to load an image and apply 2D or 3D spatial averaging.

Inputs:
    - 3D image: .czi, .nii.gz, .ome.tif series, .tif series, .h5, .zarr

Outputs:
    - Format is determined by the output path's extension: .nii.gz and .zarr are saved in those formats; any other path is saved as a .tif series.

Note:
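    - 3D averaging: Replaces each voxel with the mean of a cubic kernel centered on it (default 3x3x3, i.e., the voxel and its 26 neighbors; set the size with -k).
    - 2D averaging: Applies a square mean kernel (default 3x3; set the size with -k) to each 2D slice independently.
    - The output array is the same size as the input.
    - Edge handling differs by mode: 3D averaging zero-pads the edges, while 2D averaging uses OpenCV's default border mode (reflection).
    - For .nii.gz output, xy and z resolutions are needed. If -x and -z are not both given, they are read from the input image's metadata, so pass both when the input lacks resolution metadata.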

Usage:
------
    img_spatial_avg -i <tif_dir> -o spatial_avg.zarr -sa 2 [-k 3] [-c 0] [-x 3.5232] [-z 6] [-dt uint16] [-r metadata_reference.nii.gz] [-ao xyz] [-d list of paths] [-p sample??] [-v]
"""

from pathlib import Path
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from rich import print
from rich.live import Live
from rich.traceback import install
from scipy.ndimage import uniform_filter

from unravel.core.config import Configuration
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.img_io import load_3D_img, save_as_nii, save_as_tifs, save_as_zarr
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg, print_func_name_args_times, initialize_progress_bar, get_samples


def parse_args():
    """
    Build the command-line parser for ``img_spatial_avg`` and parse ``sys.argv``.

    Returns:
        - argparse.Namespace: Parsed options (input, output, spatial_avg, kernel_size, channel,
          xy_res, z_res, dtype, reference, axis_order, dirs, pattern, verbose).
    """
    parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__)

    required_group = parser.add_argument_group('Required arguments')
    required_group.add_argument(
        '-i', '--input',
        help="Path to full res image (relative to ./sample??/) or glob pattern (e.g., '*.czi'). First match used.",
        required=True, action=SM,
    )
    required_group.add_argument(
        '-o', '--output',
        help='Output path relative to the sample?? directory',
        required=True, action=SM,
    )
    required_group.add_argument(
        '-sa', '--spatial_avg',
        help='2D or 3D spatial averaging. (2 or 3)',
        required=True, type=int, action=SM,
    )

    optional_group = parser.add_argument_group('Optional arguments')
    optional_group.add_argument(
        '-k', '--kernel_size',
        help='Size of the kernel for spatial averaging. Default: 3',
        default=3, type=int, action=SM,
    )
    optional_group.add_argument(
        '-c', '--channel',
        help='Channel number. Default: 0 for autofluo',
        default=0, type=int, action=SM,
    )
    optional_group.add_argument(
        '-x', '--xy_res',
        help='xy resolution in um',
        default=None, type=float, action=SM,
    )
    optional_group.add_argument(
        '-z', '--z_res',
        help='z resolution in um',
        default=None, type=float, action=SM,
    )
    optional_group.add_argument(
        '-dt', '--dtype',
        help='Output data type. Default: uint16',
        default='uint16', action=SM,
    )
    optional_group.add_argument(
        '-r', '--reference',
        help='Reference image for .nii.gz metadata. Default: None',
        default=None, action=SM,
    )
    optional_group.add_argument(
        '-ao', '--axis_order',
        help='Default: xyz. (other option: zyx)',
        default='xyz', action=SM,
    )

    general_group = parser.add_argument_group('General arguments')
    general_group.add_argument(
        '-d', '--dirs',
        help='Paths to sample?? dirs and/or dirs containing them (space-separated) for batch processing. Default: current dir',
        nargs='*', default=None, action=SM,
    )
    general_group.add_argument(
        '-p', '--pattern',
        help='Pattern for directories to process. Default: sample??',
        default='sample??', action=SM,
    )
    general_group.add_argument(
        '-v', '--verbose',
        help='Increase verbosity. Default: False',
        action='store_true', default=False,
    )

    return parser.parse_args()
@print_func_name_args_times()
def spatial_average_3D(arr, kernel_size=3):
    """
    Apply a 3D spatial averaging (mean) filter to a 3D numpy array.

    The mean is computed in floating point. For integer inputs it is rounded to the nearest
    integer and cast back to the input dtype. Filtering integer data directly with
    ``uniform_filter`` would store each separable 1D pass in the integer dtype, truncating
    after every axis and biasing the result downward.

    Parameters:
        - arr (np.ndarray): The input 3D array.
        - kernel_size (int): Edge length of the cubic kernel. Default is 3, i.e., the current voxel and its 26 neighbors.

    Returns:
        - np.ndarray: Array with the same shape and dtype as ``arr`` after spatial averaging.
          Voxels outside the array are treated as 0 (zero padding), which lowers values near the edges.

    Raises:
        - ValueError: If ``arr`` is not 3D.
    """
    if arr.ndim != 3:
        raise ValueError("Input array must be 3D.")

    # float32 for <=16-bit ints and float32 input, float64 for wider types. This avoids per-pass
    # integer truncation while keeping the extra memory for large uint16 volumes at 2x.
    work_dtype = np.result_type(arr.dtype, np.float32)
    averaged = uniform_filter(arr.astype(work_dtype, copy=False), size=kernel_size, mode='constant', cval=0.0)

    if np.issubdtype(arr.dtype, np.integer):
        # Round (rather than truncate) and keep the result within the integer type's range
        info = np.iinfo(arr.dtype)
        np.rint(averaged, out=averaged)
        np.clip(averaged, info.min, info.max, out=averaged)

    return averaged.astype(arr.dtype, copy=False)
def apply_2D_mean_filter(slice, kernel_size=(3, 3)):
    """
    Mean-filter a single 2D slice with a rectangular box kernel.

    Parameters:
        - slice (np.ndarray): The 2D image to filter. (The name shadows the builtin ``slice``;
          it is kept so existing keyword callers still work.)
        - kernel_size (tuple): Kernel dimensions as (height, width). Default: (3, 3).

    Returns:
        - np.ndarray: Output of ``cv2.filter2D`` with ``ddepth=-1`` (same depth as the input)
          and OpenCV's default border handling.
    """
    # Every kernel weight is 1/(h*w), so the convolution is the local mean
    kernel_area = kernel_size[0] * kernel_size[1]
    box_kernel = np.ones(kernel_size, np.float32) / kernel_area
    return cv2.filter2D(slice, -1, box_kernel)
@print_func_name_args_times()
def spatial_average_2D(volume, filter_func, kernel_size=(3, 3), threads=8):
    """
    Apply a 2D filter function to each slice of a 3D volume in parallel.

    Slices are taken along the first axis (``volume[i]``) and filtered independently.

    Parameters:
        - volume (np.ndarray): The input 3D array.
        - filter_func (callable): Called as ``filter_func(slice, kernel_size)``; must return an array
          that can be assigned into ``volume[i]``.
        - kernel_size (tuple): The dimensions of the kernel to be used in the filter: (height, width).
        - threads (int): Maximum number of worker threads to use.

    Returns:
        - np.ndarray: A new array (same shape and dtype as ``volume``) holding the filtered slices.
          An empty volume is returned as an empty array of the same shape and dtype.
    """
    processed_volume = np.empty_like(volume)
    num_slices = len(volume)
    if num_slices == 0:
        # ThreadPoolExecutor rejects max_workers=0, so there is nothing to schedule
        return processed_volume

    # Never start more workers than there are slices to process
    num_workers = min(num_slices, threads)
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map yields results in input order, so result i belongs in slice i
        results = executor.map(lambda plane: filter_func(plane, kernel_size), volume)
        for i, processed_slice in enumerate(results):
            processed_volume[i] = processed_slice
    return processed_volume
@log_command
def main():
    """
    Command-line entry point: spatially average an image in each sample directory and save it.

    For each sample directory (selected with -d/-p), the first file matching -i is loaded. It is
    averaged in 2D or 3D, cast to -dt, and saved to -o, with the format chosen by the extension.
    A sample is skipped if its output already exists or no input file matches.

    Raises:
        - ValueError: If --spatial_avg is not 2 or 3, or --dtype is not uint8, uint16, or float32.
    """
    install()
    args = parse_args()

    Configuration.verbose = args.verbose
    verbose_start_msg()

    # Validate every option before touching any data, so a typo fails immediately rather than
    # after the first sample has been loaded and filtered.
    if args.spatial_avg not in (2, 3):
        raise ValueError("--spatial_avg must be 2 or 3 for 2D or 3D averaging.")
    output_dtypes = {'uint8': np.uint8, 'uint16': np.uint16, 'float32': np.float32}
    if args.dtype not in output_dtypes:
        raise ValueError("Data type must be uint8, uint16, or float32.")
    output_dtype = output_dtypes[args.dtype]

    sample_paths = get_samples(args.dirs, args.pattern, args.verbose)

    progress, task_id = initialize_progress_bar(len(sample_paths), "[red]Processing samples...")
    with Live(progress):
        for sample_path in sample_paths:
            output = Path(sample_path / args.output)
            if output.exists():
                print(f"\n {output} already exists. Skipping.")
                # Skipped samples still count toward the total so the bar can reach 100%
                progress.update(task_id, advance=1)
                continue

            # Use the first file in the sample directory that matches the input pattern
            img_path = next(sample_path.glob(str(args.input)), None)
            if img_path is None:
                print(f"\n [red1]No files match the pattern {args.input} in {sample_path}\n")
                progress.update(task_id, advance=1)
                continue

            # Load the image; read the xy/z resolution from its metadata unless both were given
            if args.xy_res is None or args.z_res is None:
                img, xy_res, z_res = load_3D_img(img_path, channel=args.channel, return_res=True, verbose=args.verbose)
            else:
                img = load_3D_img(img_path, channel=args.channel, verbose=args.verbose)
                xy_res, z_res = args.xy_res, args.z_res

            # Apply spatial averaging (args.spatial_avg was validated to be 2 or 3 above)
            if args.spatial_avg == 3:
                img = spatial_average_3D(img, kernel_size=args.kernel_size)
            else:
                img = spatial_average_2D(img, apply_2D_mean_filter, kernel_size=(args.kernel_size, args.kernel_size))

            # Cast to the requested output data type (no copy if it already matches)
            img = img.astype(output_dtype, copy=False)

            # Save image; the format is chosen from the output path's extension
            if str(output).endswith('.nii.gz'):
                save_as_nii(img, output, xy_res, z_res, data_type=args.dtype, reference=args.reference)
            elif str(output).endswith('.zarr'):
                save_as_zarr(img, output, ndarray_axis_order=args.axis_order)
            else:
                save_as_tifs(img, output, ndarray_axis_order=args.axis_order)

            progress.update(task_id, advance=1)

    verbose_end_msg()
# Run the CLI only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()