#!/usr/bin/env python3
"""
Use ``rstats_mean_IF_summary`` (``rmis``) from UNRAVEL to output plots of mean IF intensities for each region intensity ID.
Prereqs:
- Generate CSV inputs with ``rstats_IF_mean`` or ``rstats_IF_mean_in_seg``
- After ``rstats_IF_mean_in_seg``, aggregate CSV inputs with ``utils_agg_files``
- If needed, add conditions to input CSV file names: ``utils_prepend`` -sk $SAMPLE_KEY -f
Inputs:
- `*`.csv in the working dir with these columns: 'Region_Intensity', 'Mean_IF_Intensity'
Outputs:
- regional_mean_IF_summary/region_<region_id>_<region_abbr>.pdf for each region
- If significant differences are found, a prefix '_' is added to the filename to sort the files
Note:
- The first word of the csv inputs is used for the group names (e.g. Control from Control_sample01_cFos_rb4_atlas_space_z.csv)
- Default csv: UNRAVEL/unravel/core/csvs/CCFv3-2020__regionID_side_IDpath_region_abbr.csv
- Alternatively, use CCFv3-2017__regionID_side_IDpath_region_abbr.csv or provide a custom CSV with the same columns.
- The look up table (LUT) csv must have these columns (other columns are ignored): 'Region_ID', 'Region', 'Abbr'
Usage for t-tests:
------------------
rstats_mean_IF_summary --order Control Treatment --labels Control Treatment -t ttest [-alt two-sided] [--lut CCFv3-2020__regionID_side_IDpath_region_abbr.csv] [-v]
Usage for Tukey's tests w/ reordering and renaming of conditions:
-----------------------------------------------------------------
rstats_mean_IF_summary --order group3 group2 group1 --labels Group_3 Group_2 Group_1 [--lut CCFv3-2020__regionID_side_IDpath_region_abbr.csv] [-v]
Usage with a custom atlas:
--------------------------
atlas=path/custom_atlas.nii.gz ; rstats_mean_IF_summary --region_ids $(img_unique -i $atlas) --order group2 group1 --labels Group_2 Group_1 -t ttest [-alt two-sided] [--lut CCFv3-2020__regionID_side_IDpath_region_abbr.csv] [-v]
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
import textwrap
from rich import print
from rich.traceback import install
from pathlib import Path
# from scipy.stats import ttest_ind, dunnett # scipy.stats.dunnett only exists in SciPy >= 1.11
from scipy.stats import ttest_ind
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg
[docs]
def parse_args():
    """Define and parse this script's command-line arguments.

    The module docstring is passed to ``RichArgumentParser`` as ``docstring`` (presumably shown in
    the help output).

    Returns:
        The parsed arguments, with attributes ``order``, ``labels``, ``test``, ``alternate``,
        ``region_ids``, ``lut`` and ``verbose``.
    """
    parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__)

    reqs = parser.add_argument_group('Required arguments')
    reqs.add_argument('--order', nargs='*', help='Group Order for plotting (must match 1st word of CSVs)', action=SM)
    reqs.add_argument('--labels', nargs='*', help='Group Labels in same order', action=SM)

    opts = parser.add_argument_group('Optional arguments')
    opts.add_argument('-t', '--test', help='Choose between "tukey", "dunnett" (ignore for now), and "ttest" post-hoc tests. (Default: tukey)', default='tukey', choices=['tukey', 'dunnett', 'ttest'], action=SM)
    # Restrict to the values the statistical tests accept, so typos are reported up front
    # instead of being passed through unchecked.
    opts.add_argument('-alt', "--alternate", help="Number of tails and direction for Dunnett's test {'two-sided', 'less' (means < ctrl), 'greater'}. Default: two-sided", default='two-sided', choices=['two-sided', 'less', 'greater'], action=SM)
    opts.add_argument('--region_ids', nargs='*', type=int, help='List of region intensity IDs (Default: process all regions from the lut CSV)', action=SM)
    opts.add_argument('-l', '--lut', help='LUT csv name (in unravel/core/csvs/). Default: CCFv3-2020__regionID_side_IDpath_region_abbr.csv', default="CCFv3-2020__regionID_side_IDpath_region_abbr.csv", action=SM)

    general = parser.add_argument_group('General arguments')
    general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False)

    return parser.parse_args()
# TODO: scipy.stats.dunnett only exists in SciPy >= 1.11. Require that version or implement a fallback for older SciPy.
# TODO: Also output csv to summarise t-test/Tukey/Dunnett results like in ``cstats``. Make symbols transparent. Add option to pass in symbol colors for each group. Add ABA coloring to plots.
# TODO: CSVs are loaded for each region. It would be more efficient to load them once for processing all regions.
# TODO: Update coloring of plots to match ABA colors (i.e., use code from rstats_summary.py)
# TODO: Save a CSV with the results of the statistical tests for each region.
# Render every figure's text in Arial.
mpl.rcParams.update({'font.family': 'Arial'})
[docs]
def load_data(region_id):
    """Collect one region's mean IF intensity from every CSV in the working directory.

    The group of each CSV is the text before the first underscore of its file name
    (e.g. ``Control`` for ``Control_sample01_cFos_rb4_atlas_space_z.csv``). Files are visited in
    sorted name order so the resulting row order does not depend on the directory listing order
    returned by the file system.

    Args:
        region_id: Value to match in each CSV's ``Region_Intensity`` column.

    Returns:
        DataFrame with columns ``group`` and ``mean_intensity``, one row per CSV that contains
        ``region_id``. If a CSV lists the region more than once, only the first row is used.

    Raises:
        ValueError: If no CSV contains ``region_id``.
    """
    records = []
    for filename in sorted(os.listdir()):
        if not filename.endswith('.csv'):
            continue
        df = pd.read_csv(filename)
        matches = df.loc[df["Region_Intensity"] == region_id, "Mean_IF_Intensity"].to_numpy()
        if matches.size > 0:
            records.append({
                'group': filename.split("_")[0],
                'mean_intensity': matches[0],
            })

    if not records:
        raise ValueError(f" [red1]No data found for region ID {region_id}")
    return pd.DataFrame(records)
[docs]
def get_max_region_id_from_csvs():
    """Return the largest ``Region_Intensity`` value across the CSVs in the working directory.

    Returns -1 when there are no CSV files (or no file yields a value greater than -1).
    """
    largest = -1
    csv_names = (name for name in os.listdir() if name.endswith('.csv'))
    for csv_name in csv_names:
        file_max = pd.read_csv(csv_name)["Region_Intensity"].max()
        largest = file_max if file_max > largest else largest
    return largest
[docs]
def get_region_details(region_id, csv_path):
    """Look up a region's name and abbreviation in the LUT CSV.

    Args:
        region_id: Value to match in the LUT's ``Region_ID`` column.
        csv_path: Path to the LUT CSV (needs ``Region_ID``, ``Region`` and ``Abbr`` columns).

    Returns:
        ``(region_name, region_abbr)`` from the first matching LUT row. If ``region_id`` is not in
        the LUT (e.g. IDs taken from a custom atlas), returns ``("Region <id>", "<id>")`` so the
        region can still be plotted and saved.
    """
    region_df = pd.read_csv(csv_path)
    matches = region_df[region_df["Region_ID"] == region_id]
    if matches.empty:
        # Use a generic label instead of failing with an IndexError on ``.iloc[0]``.
        return f"Region {region_id}", str(region_id)
    region_row = matches.iloc[0]
    return region_row["Region"], region_row["Abbr"]
[docs]
def get_all_region_ids(csv_path):
    """Return the LUT's ``Region_ID`` column, in file order, as a plain Python list."""
    lut_df = pd.read_csv(csv_path)
    region_id_column = lut_df["Region_ID"]
    return region_id_column.to_list()
[docs]
def filter_region_ids(region_ids, max_region_id):
    """Keep the region IDs that do not exceed ``max_region_id``, preserving their order."""
    def _within_range(candidate):
        return candidate <= max_region_id

    return list(filter(_within_range, region_ids))
[docs]
def remove_zero_intensity_regions(region_ids):
    """Remove regions with Mean_IF_Intensity of 0 across all input CSVs.

    A region is kept if, in at least one CSV in the working directory, the first row whose
    ``Region_Intensity`` equals the region ID has a ``Mean_IF_Intensity`` different from 0
    (NaN counts as different from 0). Regions missing from every CSV are removed.

    Each CSV is read once, rather than once per region ID.

    Args:
        region_ids: Iterable of region IDs to screen.

    Returns:
        List of the kept region IDs, in their original order.
    """
    region_ids = list(region_ids)
    if not region_ids:
        return []

    nonzero_ids = set()
    for filename in os.listdir():
        if not filename.endswith('.csv'):
            continue
        df = pd.read_csv(filename)
        # Only the first row per region counts, the same row load_data would pick.
        first_rows = df.drop_duplicates(subset="Region_Intensity", keep="first")
        is_nonzero = first_rows["Mean_IF_Intensity"] != 0
        nonzero_ids.update(first_rows.loc[is_nonzero, "Region_Intensity"].tolist())

    return [region_id for region_id in region_ids if region_id in nonzero_ids]
[docs]
def perform_t_tests(df, order, alternative='two-sided'):
    """Run unpaired t-tests (``scipy.stats.ttest_ind``) between every pair of groups in ``order``.

    In each pair, the group that comes later in ``order`` is tested against the earlier one. With
    ``order=[control, treatment]`` and ``alternative='less'``, the test asks whether the treatment
    mean is below the control mean.

    Args:
        df: DataFrame with ``group`` and ``mean_intensity`` columns.
        order: Sequence of group names to compare.
        alternative: ``'two-sided'``, ``'less'`` or ``'greater'``.

    Returns:
        DataFrame with columns ``group1``, ``group2``, ``t-stat``, ``p-value`` and ``p-adj``.
        ``p-adj`` is the Bonferroni-adjusted p-value: the raw p-value times the number of
        comparisons, capped at 1.
    """
    n_comparisons = len(order) * (len(order) - 1) // 2
    rows = []
    for i, group1 in enumerate(order):
        data1 = df.loc[df['group'] == group1, 'mean_intensity'].to_numpy()
        for group2 in order[i + 1:]:
            data2 = df.loc[df['group'] == group2, 'mean_intensity'].to_numpy()
            result = ttest_ind(data2, data1, alternative=alternative)
            p_value = float(result.pvalue)
            rows.append({
                'group1': group1,
                'group2': group2,
                't-stat': float(result.statistic),
                'p-value': p_value,
                'p-adj': min(p_value * n_comparisons, 1.0),  # a NaN p-value stays NaN
            })
    return pd.DataFrame(rows, columns=['group1', 'group2', 't-stat', 'p-value', 'p-adj'])


def _significance_label(p_value):
    """Return the asterisks drawn above a significant comparison."""
    if p_value < 0.0001:
        return '****'
    if p_value < 0.001:
        return '***'
    if p_value < 0.01:
        return '**'
    return '*'


def _run_posthoc_test(df, groups, test_type, alt):
    """Compare ``groups`` (in plotting order) with the selected post-hoc test.

    Returns:
        DataFrame with at least ``group1``, ``group2``, ``p-adj`` and ``reject`` columns. For
        Dunnett and t-tests, ``reject`` means ``p-adj < 0.05``; for Tukey it is statsmodels'
        decision. The DataFrame is empty when fewer than two of ``groups`` have data (or, for
        Dunnett, when the control ``groups[0]`` has none).

    Raises:
        ValueError: For an unknown ``test_type``.
        ImportError: For ``'dunnett'`` when ``scipy.stats.dunnett`` is unavailable.
    """
    if test_type not in ('tukey', 'dunnett', 'ttest'):
        raise ValueError(f"Unknown post-hoc test: {test_type!r}")

    empty = pd.DataFrame({
        'group1': pd.Series(dtype=object),
        'group2': pd.Series(dtype=object),
        'p-adj': pd.Series(dtype=float),
        'reject': pd.Series(dtype=bool),
    })
    present = set(df['group'])
    tested = [group for group in groups if group in present]
    if len(tested) < 2:
        return empty

    if test_type == 'tukey':
        summary = pairwise_tukeyhsd(df['mean_intensity'], df['group']).summary()
        return pd.DataFrame(summary.data[1:], columns=summary.data[0])

    if test_type == 'dunnett':
        try:
            from scipy.stats import dunnett  # added in SciPy 1.11
        except ImportError as err:
            raise ImportError("The 'dunnett' test requires SciPy >= 1.11 (scipy.stats.dunnett).") from err
        control = groups[0]  # the first group in the plotting order is the control
        treatments = [group for group in tested if group != control]
        if control not in present or not treatments:
            return empty
        control_data = df.loc[df['group'] == control, 'mean_intensity'].to_numpy()
        samples = [df.loc[df['group'] == group, 'mean_intensity'].to_numpy() for group in treatments]
        result = dunnett(*samples, control=control_data, alternative=alt)
        test_df = pd.DataFrame({
            'group1': [control] * len(treatments),
            'group2': treatments,
            'p-adj': result.pvalue,
        })
    else:  # 'ttest'
        test_df = perform_t_tests(df, tested, alternative=alt)

    test_df['reject'] = test_df['p-adj'] < 0.05
    return test_df


def plot_data(region_id, order=None, labels=None, csv_path=None, test_type='tukey', alt='two-sided'):
    """Plot one region's mean IF intensity per group and save the figure as a PDF.

    Draws a bar plot (mean ± SE) with a swarm of individual values, runs the chosen post-hoc test,
    and annotates significant pairs with asterisks. The figure is written to
    ``regional_mean_IF_summary/[_]region_<region_id>_<abbr>.pdf``. The leading ``_`` is added only
    when at least one comparison is significant.

    Args:
        region_id: Region intensity ID to plot.
        order: Group names in plotting order (first word of the CSV names). Groups outside
            ``order`` are left out. Defaults to the order in which groups appear in the data.
        labels: Display labels for ``order`` (used only when ``order`` is also given).
        csv_path: LUT CSV used to look up the region's name and abbreviation.
        test_type: ``'tukey'``, ``'dunnett'`` (first group in the order is the control) or ``'ttest'``.
        alt: Alternative hypothesis for Dunnett's test and t-tests.

    Raises:
        ValueError: If no plotted group has data for this region, or ``test_type`` is unknown.
    """
    df = load_data(region_id)
    if 'group' not in df.columns:
        raise KeyError(f" [red1]'group' column not found in the DataFrame for {region_id}. Ensure the CSV files contain the correct data.")

    region_name, region_abbr = get_region_details(region_id, csv_path)

    predefined_colors = [
        '#2D67C8',  # blue
        '#D32525',  # red
        '#27AF2E',  # green
        '#FFD700',  # gold
        '#FF6347',  # tomato
        '#8A2BE2',  # blueviolet
    ]

    # Raw group names in plotting order: --order if given, else order of first appearance.
    groups = list(order) if order else df['group'].unique().tolist()
    if order:
        # Groups outside --order have no x position, label or color, so leave them out.
        df = df[df['group'].isin(groups)].copy()
        if df.empty:
            raise ValueError(f" [red1]No data found for region ID {region_id} in groups {groups}")

    labels_mapping = {group: group for group in groups}
    if order and labels:
        labels_mapping.update(zip(groups, labels))
    label_order = [labels_mapping[group] for group in groups]
    df['group_label'] = df['group'].map(labels_mapping)
    # Cycle through the palette so extra groups still get a color.
    group_colors = {group: predefined_colors[i % len(predefined_colors)] for i, group in enumerate(groups)}

    test_df = _run_posthoc_test(df, groups, test_type, alt)
    significant_comparisons = test_df[test_df['reject'].eq(True)]

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        # Bar plot; an explicit order ties every x position to ``groups``.
        sns.barplot(x='group_label', y='mean_intensity', data=df, order=label_order, color='white',
                    errorbar='se', capsize=0.1, linewidth=2, edgecolor='black', ax=ax)

        # Formatting
        ax.set_ylabel('Mean IF Intensity', weight='bold')
        ax.set_xticks(range(len(label_order)))
        ax.set_xticklabels(label_order, weight='bold')
        ax.tick_params(axis='both', which='major', width=2)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['left'].set_linewidth(2)

        # Swarm plot of the individual values, colored by group
        sns.swarmplot(x='group_label', y='mean_intensity', hue='group', data=df, order=label_order,
                      palette=group_colors, size=8, linewidth=1, edgecolor='black', ax=ax)
        if ax.legend_:
            ax.legend_.remove()  # the hue legend duplicates the x tick labels

        # Vertical spacing for the significance brackets
        y_max = df['mean_intensity'].max()
        y_min = df['mean_intensity'].min()
        height_diff = (y_max - y_min) * 0.1
        if height_diff == 0:
            # All values equal: derive the spacing from the magnitude so the y-limits stay distinct.
            height_diff = abs(y_max) * 0.1 or 1.0
        y_pos = y_max + 0.5 * height_diff

        x_positions = {group: i for i, group in enumerate(groups)}
        for _, row in significant_comparisons.iterrows():
            x1 = x_positions[row['group1']]
            x2 = x_positions[row['group2']]
            ax.plot([x1, x1, x2, x2], [y_pos, y_pos + height_diff, y_pos + height_diff, y_pos], lw=1.5, c='black')
            ax.text((x1 + x2) * .5, y_pos + 0.8 * height_diff, _significance_label(row['p-adj']),
                    horizontalalignment='center', size='xx-large', color='black', weight='bold')
            y_pos += 3 * height_diff

        # Start the y-axis just below the minimum value, allowing for negative values
        ax.set_ylim(y_min - 2 * height_diff, y_pos + 2 * height_diff)
        ax.set_xlabel(None)

        # Save the plot
        output_folder = Path('regional_mean_IF_summary')
        output_folder.mkdir(parents=True, exist_ok=True)
        title = f"{region_name} ({region_abbr})"
        ax.set_title(textwrap.fill(title, 42))  # wrap long region names at 42 characters
        fig.tight_layout()

        safe_abbr = str(region_abbr).replace("/", "-")  # '/' would create a subdirectory
        file_prefix = '_' if not significant_comparisons.empty else ''  # sorts significant regions first
        fig.savefig(output_folder / f"{file_prefix}region_{region_id}_{safe_abbr}.pdf")
    finally:
        plt.close(fig)
[docs]
@log_command
def main():
    """Validate the arguments, select the regions to plot, and plot each one.

    Raises:
        ValueError: If only one of --order/--labels is given, or their lengths differ.
        FileNotFoundError: If the working directory contains no CSV files.
    """
    install()
    args = parse_args()

    Configuration.verbose = args.verbose
    verbose_start_msg()

    if bool(args.order) != bool(args.labels):
        raise ValueError("Both --order and --labels must be provided together.")
    if args.order and len(args.order) != len(args.labels):
        raise ValueError("The number of entries in --order and --labels must match.")

    # Without any CSVs, every region would be filtered out and nothing would be plotted silently.
    csv_files = sorted(filename for filename in os.listdir() if filename.endswith('.csv'))
    if not csv_files:
        raise FileNotFoundError("No CSV files found in the working directory.")

    print('\n[bold]CSVs in the working dir to process (the first word defines the groups): \n')
    for filename in csv_files:
        print(f' {filename}')
    print()

    # Use --region_ids if provided; otherwise process every region ID in the LUT CSV
    lut = Path(__file__).parent.parent / 'core' / 'csvs' / args.lut
    region_ids_to_process = args.region_ids if args.region_ids else get_all_region_ids(lut)

    # Drop IDs above the largest Region_Intensity found in the input CSVs
    max_region_id = get_max_region_id_from_csvs()
    region_ids_to_process = filter_region_ids(region_ids_to_process, max_region_id)

    # Remove regions with Mean_IF_Intensity of 0 across all input CSVs
    region_ids_to_process = remove_zero_intensity_regions(region_ids_to_process)
    if not region_ids_to_process:
        print('[yellow]No regions with non-zero Mean_IF_Intensity to plot.')

    for region_id in region_ids_to_process:
        plot_data(region_id, args.order, args.labels, csv_path=lut, test_type=args.test, alt=args.alternate)

    verbose_end_msg()


if __name__ == '__main__':
    main()