Source code for unravel.abca.sunburst.sunburst_expression

#!/usr/bin/env python3

"""
Use ``abca_sunburst_expression`` or ``sbe`` from UNRAVEL to calculate mean expression for all cell types in the ABCA and make a sunburst plot.

Prereqs: 
    - merfish_filter.py and merfish_join_expression_data.py
    - Or: RNAseq_expression_in_mice.py and RNAseq_filter.py

Note:
    - LUT location: unravel/core/csvs/ABCA/ABCA_sunburst_colors.csv

Next steps:
    - Use input_sunburst.csv to make a sunburst plot or regional volumes in Flourish Studio (https://app.flourish.studio/)
    - It can be pasted into the Data tab (categories columns = cell type columns, Size by = percent column)
    - Preview tab: Hierarchy -> Depth to 5, Colors -> paste content of ..._colors.csv into Custom overrides

Usage:
------ 
    abca_sunburst_expression -i path/VTA_DA_cells_Th_expression.csv -g gene [-o path/out_dir] [-n] [-v]
"""

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import shutil
from pathlib import Path
from rich import print
from rich.traceback import install

from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration 
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg


[docs] def parse_args(): parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__) reqs = parser.add_argument_group('Required arguments') reqs.add_argument('-i', '--input', help='path/cells_filtered_exp.csv', required=True, action=SM) reqs.add_argument('-g', '--gene', help='Gene to analyze', required=True, action=SM) opts = parser.add_argument_group('Optional args') opts.add_argument('-n', '--neurons', help='Filter out non-neuronal cells. Default: False', action='store_true', default=False) opts.add_argument('-c', '--color_max', help='Maximum value for the color scale. Default: 10', default=10, type=float, action=SM) opts.add_argument('-t', '--threshold', help='Log2(CPM+1) threshold for percent gene expression. Default: 6', default=6, type=float, action=SM) opts.add_argument('-o', '--output', help='Output dir path. Default: ABCA_sunburst_cmax10_thr6/', default=None, action=SM) opts.add_argument('-a', '--all', help='Save mean expression and percent expressing for all cells. Default: False', action='store_true', default=False) general = parser.add_argument_group('General arguments') general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False) return parser.parse_args()
[docs] @log_command def main(): install() args = parse_args() Configuration.verbose = args.verbose verbose_start_msg() # Load the CSV file cols = ['neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', args.gene] cells_df = pd.read_csv(args.input, usecols=cols) # Replace blank values in 'neurotransmitter' column with 'NA' cells_df['neurotransmitter'] = cells_df['neurotransmitter'].fillna('NA') if args.neurons: cells_df = cells_df[cells_df['class'].str.split().str[0].astype(int) <= 29] # Groupby cluster to calculate the percentage of cells in each cluster cluster_df = cells_df.groupby('cluster').size().reset_index(name='counts') # Count the number of cells in each cluster cluster_df = cluster_df.sort_values('counts', ascending=False) # Sort the clusters by the number of cells # Add a column for the percentage of cells in each cluster cluster_df['percent'] = cluster_df['counts'] / cluster_df['counts'].sum() * 100 # Drop the 'counts' column cluster_df = cluster_df.drop(columns='counts') # Join the cells_df with the cluster_df cells_df = cells_df.merge(cluster_df, on='cluster') # Drop duplicate rows cells_df = cells_df.drop_duplicates() # Sort by percentage cells_df = cells_df.sort_values('percent', ascending=False).reset_index(drop=True) # Calculate the mean expression and percent expressing for all cells in cells_df all_mean = cells_df[args.gene].mean() all_percent = (cells_df[args.gene] > args.threshold).mean() * 100 # Create the output directory if args.output is None: output_dir = Path(args.input).parent / f'ABCA_sunburst_cmax{args.color_max}_thr{args.threshold}' else: output_dir = Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) # Save the mean expression and percent expressing for all cells (.txt) if args.all: output_path = output_dir / str(Path(args.input).name).replace('.csv', f'_sunburst_expression_thr{args.threshold}_all.txt') with open(output_path, 'w') as f: f.write(f"all_mean: {all_mean}\nall_percent (threshold: {args.threshold}): {all_percent}") print(f"\nSaved mean expression and percent expressing for all cells to {output_path}") # Calculate mean expression and percent expressing at each hierarchy level summary_df = cells_df.copy() hierarchy_levels = ['neurotransmitter', 'class', 'subclass', 'supertype', 'cluster'] for level in hierarchy_levels: summary_df[f'{level}_mean'] = summary_df[level].map(cells_df.groupby(level)[args.gene].mean()) summary_df[f'{level}_percent'] = summary_df[level].map(cells_df.groupby(level)[args.gene].apply(lambda x: (x > args.threshold).mean() * 100)) summary_df = summary_df.drop(columns=[args.gene]).drop_duplicates() # Save the results output_path = output_dir / str(Path(args.input).name).replace('.csv', f'_sunburst_expression_thr{args.threshold}.csv') summary_df.to_csv(output_path, index=False) print(f"\nSaved sunburst expression summary to {output_path}") # Stack labels and values for LUT files label_stack = pd.DataFrame() for level in hierarchy_levels: label_stack = pd.concat([label_stack, summary_df[level].rename('label')], axis=0) # Stack mean expression values and construct the mean expression LUT mean_stack = pd.DataFrame() for level in hierarchy_levels: mean_stack = pd.concat([mean_stack, summary_df[f'{level}_mean'].rename('value')], axis=0) mean_df = pd.concat([label_stack, mean_stack], axis=1) # Combine the label stack and the mean stack mean_df = mean_df.drop_duplicates() mean_df.columns = ['label', 'value'] # Replace the mean value with the hex color (magma_r) mean_df['color'] = mean_df['value'].apply(lambda x: mcolors.rgb2hex(plt.cm.magma_r((x - 0) / (args.color_max - 0)))) mean_df = mean_df.drop(columns=['value']) # Save the mean expression LUT mean_path = str(output_path).replace(f'_expression_thr{args.threshold}.csv', '_mean_expression_lut.txt') for row in mean_df.itertuples(index=False): with open(mean_path, 'a') as f: f.write(f"{row.label}: {row.color}\n") # Stack percent expression values percent_stack = pd.DataFrame() for level in hierarchy_levels: percent_stack = pd.concat([percent_stack, summary_df[f'{level}_percent'].rename('value')], axis=0) percent_df = pd.concat([label_stack, percent_stack], axis=1) percent_df = percent_df.drop_duplicates() percent_df.columns = ['label', 'value'] # Replace the percent value with the hex color (viridis_r) percent_df['color'] = percent_df['value'].apply(lambda x: mcolors.rgb2hex(plt.cm.viridis_r((x - 0) / (100 - 0)))) percent_df = percent_df.drop(columns=['value']) # Save the percent expression LUT percent_path = str(output_path).replace(f'_expression_thr{args.threshold}.csv', f'_percent_expression_thr{args.threshold}_lut.txt') for row in percent_df.itertuples(index=False): with open(percent_path, 'a') as f: f.write(f"{row.label}: {row.color}\n") lut_path = Path(__file__).parent.parent.parent.parent / 'unravel' / 'core' / 'csvs' / 'ABCA' / 'ABCA_sunburst_colors.csv' shutil.copy(lut_path, output_path.parent / lut_path.name) verbose_end_msg()
if __name__ == '__main__': main()