#!/usr/bin/env python3
"""
Use ``cstats_reshape`` (``reshape``) from UNRAVEL to export raw cluster validation data as long, wide and per-cluster CSVs.
Prereqs:
- ``cstats_validation``, ``cstats_org_data``, ``cstats_group_data``, ``utils_prepend``
- or ``cstats_mean_IF`` for reshaping mean IF intensity data (metric column autodetected if cols: sample, cluster_ID, <metric_col> are present)
Input files:
- `*_data.csv` from ``cstats_validation`` after condition prefixes were prepended
e.g., saline_cell_density_data.csv, drug_cell_density_data.csv
Outputs:
- _reshaped/<value_name>_long.csv
- _reshaped/<value_name>_wide.csv
- _reshaped/by_cluster/cluster_<cluster_ID>__<value_name>.csv
Note:
- If hemisphere-specific CSVs are present (e.g. cluster_1_LH.csv and cluster_1_RH.csv), the script will attempt to pool the data.
Usage:
------
cstats_reshape -g saline drug1 drug2 -i '`*cell_density_data.csv`'
Usage with combined groups:
---------------------------
cstats_reshape -g saline MBDB MDAI RMDMA SMDMA --combine `entactogens=MBDB+MDAI+RMDMA+SMDMA`
Usage for mean_IF data:
-----------------------
cstats_reshape -g AwS AwP -o _reshaped_mean_IF
"""
import pandas as pd
from pathlib import Path
from rich import print
from rich.traceback import install
from unravel.cluster_stats.cstats import detect_metric_schema, get_matching_input_csvs, cluster_validation_data_df
from unravel.core.help_formatter import RichArgumentParser, SuppressMetavar, SM
from unravel.core.config import Configuration
from unravel.core.utils import log_command, verbose_start_msg, verbose_end_msg
def parse_args():
    """Build the ``cstats_reshape`` command-line parser and parse the arguments.

    The module docstring is handed to the parser as its help text.

    Returns:
        The namespace returned by ``parser.parse_args()``, with attributes
        ``groups``, ``input``, ``outdir``, ``value_name``, ``support_name``,
        ``combine`` and ``verbose``.
    """
    parser = RichArgumentParser(formatter_class=SuppressMetavar, add_help=False, docstring=__doc__)

    reqs = parser.add_argument_group('Required arguments')
    reqs.add_argument('-g', '--groups', help='Group/condition prefixes from CSV filenames, in desired output order.', nargs='*', required=True, action=SM)

    opts = parser.add_argument_group('Optional args')
    # The help text states the real default ('*.csv'); it previously advertised '*_data.csv'.
    opts.add_argument('-i', '--input', help="CSV paths or glob patterns. Default: '*.csv'", nargs='*', default=['*.csv'], action=SM)
    opts.add_argument('-o', '--outdir', help="Output directory. Default: _reshaped", default="_reshaped", action=SM)
    opts.add_argument("-vn", "--value_name", help="Name to use for the metric value column. Default: cell_density", default="cell_density", action=SM)
    opts.add_argument("-sn", "--support_name", help="Name to use for the support/count column. Default: cell_count", default="cell_count", action=SM)
    opts.add_argument("-c", "--combine", help="Optional combined per-cluster columns, e.g. drug1+drug2 or ent=drug1+drug2", nargs="*", default=[], action=SM)

    general = parser.add_argument_group('General arguments')
    general.add_argument('-v', '--verbose', help='Increase verbosity. Default: False', action='store_true', default=False)

    return parser.parse_args()
def detect_simple_metric_schema(first_df):
    """Return the name of the single metric column of a "simple metric" table.

    A simple metric table has the identifier columns ``sample`` and
    ``cluster_ID`` plus exactly one additional column, which is taken to be
    the metric.

    Args:
        first_df: DataFrame whose columns are inspected.

    Returns:
        The label of the only column that is not ``sample`` or ``cluster_ID``.

    Raises:
        ValueError: If an identifier column is missing, or if the number of
            remaining columns is not exactly one.
    """
    id_cols = {"sample", "cluster_ID"}

    missing_ids = id_cols - set(first_df.columns)
    if missing_ids:
        raise ValueError("Not a simple metric CSV.")

    candidates = [col for col in first_df.columns if col not in id_cols]
    if len(candidates) == 1:
        return candidates[0]

    raise ValueError(
        "Could not infer metric column. Expected exactly one column besides "
        f"{sorted(id_cols)}, found: {candidates}"
    )
def simple_metric_data_df(csv_files, groups, metric_col):
    """Stack simple metric CSVs into one long table tagged with their condition.

    The condition of a file is the part of its filename before the first
    ``_sample``. Files whose condition is not in ``groups``, or that lack the
    ``metric_col`` column, are skipped.

    Args:
        csv_files: Iterable of CSV paths (anything accepted by ``Path``).
        groups: Conditions to keep.
        metric_col: Name of the metric column to extract.

    Returns:
        DataFrame with columns ``condition``, ``sample``, ``cluster_ID`` and
        ``metric_col``; an empty DataFrame when no file contributes rows.
    """
    keep_cols = ["condition", "sample", "cluster_ID", metric_col]
    frames = []

    for csv_path in map(Path, csv_files):
        prefix = csv_path.name.partition("_sample")[0]
        if prefix in groups:
            table = pd.read_csv(csv_path)
            if metric_col in table.columns:
                table["condition"] = prefix
                frames.append(table[keep_cols])

    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
def _parse_combine_specs(specs):
    """Turn ``--combine`` specs into ``(column_name, [group, ...])`` pairs.

    A spec is either ``name=g1+g2`` or a bare ``g1+g2``; in the bare form the
    spec text itself is used as the column name.
    """
    combined_groups = []
    for spec in specs:
        col_name, sep, group_str = spec.partition("=")
        if not sep:
            group_str = spec
        combined_groups.append((col_name, group_str.split("+")))
    return combined_groups


def _build_wide_df(data_df, metric_col, groups):
    """Pivot the long table to one row per cluster and one column per ``<condition>_<sample>``.

    Columns are ordered by ``groups`` and then by sample name within each
    group; any remaining columns are appended after them.
    """
    wide_df = data_df.copy()
    wide_df["_wide_col"] = wide_df["condition"].astype(str) + "_" + wide_df["sample"].astype(str)

    # aggfunc="first" keeps the first value when a cluster/sample pair occurs more than once.
    wide_df = (
        wide_df.pivot_table(
            index="cluster_ID",
            columns="_wide_col",
            values=metric_col,
            aggfunc="first",
        )
        .reset_index()
        .sort_values("cluster_ID")
    )
    wide_df.columns.name = None

    ordered_cols = ["cluster_ID"]
    for group in groups:
        samples = sorted(data_df.loc[data_df["condition"] == group, "sample"].astype(str).unique())
        ordered_cols.extend(f"{group}_{sample}" for sample in samples)

    ordered_cols = [c for c in ordered_cols if c in wide_df.columns]
    remaining_cols = [c for c in wide_df.columns if c not in ordered_cols]
    return wide_df[ordered_cols + remaining_cols]


def _build_cluster_table(cluster_df, metric_col, groups, combined_groups):
    """Build the per-cluster table: one column of stacked replicate values per group.

    Combined columns (from ``--combine``) follow the group columns. Shorter
    columns are padded with NaN, and the table always has at least one row.
    """
    columns = []
    for group in groups:
        values = (
            cluster_df.loc[cluster_df["condition"] == group]
            .sort_values("sample")[metric_col]
            .reset_index(drop=True)
        )
        columns.append((group, values))

    for col_name, combo_groups in combined_groups:
        values = (
            cluster_df.loc[cluster_df["condition"].isin(combo_groups)]
            .sort_values(["condition", "sample"])[metric_col]
            .reset_index(drop=True)
        )
        columns.append((col_name, values))

    max_len = max([len(values) for _, values in columns] + [1])
    cluster_out_df = pd.DataFrame(index=range(max_len))
    for name, values in columns:
        # Assignment aligns on the 0..n-1 index, so shorter columns are NaN-padded.
        cluster_out_df[name] = values
    return cluster_out_df


@log_command
def main():
    """Reshape cluster validation (or simple metric) CSVs into long, wide and per-cluster CSVs.

    Steps:
        1. Collect input CSVs for the requested groups from the current directory.
        2. Detect the schema from the first CSV: the cluster validation schema
           (``detect_metric_schema``) or, failing that, the simple metric schema
           (``sample``, ``cluster_ID`` and one metric column).
        3. Aggregate all CSVs into one long table.
        4. Write ``<metric>_long.csv``, ``<metric>_wide.csv`` and one
           ``by_cluster/cluster_<id>__<metric>.csv`` per cluster under ``--outdir``.

    Returns early (after printing a message) when no input CSVs match, when the
    schema cannot be detected, when aggregation yields no rows, or when the
    metric column is missing from the aggregated table.
    """
    install()
    args = parse_args()
    Configuration.verbose = args.verbose
    verbose_start_msg()

    current_dir = Path.cwd()
    # NOTE(review): args.input is parsed but not forwarded here, so -i/--input does not
    # appear to influence which CSVs are selected — confirm against get_matching_input_csvs.
    csv_files = get_matching_input_csvs(current_dir, args.groups)
    if not csv_files:
        print(f" [red1]No matching input CSVs found for groups: {' '.join(args.groups)}")
        return

    # The schema is decided from the first CSV only.
    first_df = pd.read_csv(csv_files[0])
    schema = None
    metric_col = None
    try:
        schema = detect_metric_schema(first_df)
        is_simple_metric = False
    except ValueError:
        try:
            metric_col = detect_simple_metric_schema(first_df)
            is_simple_metric = True
        except ValueError as e:
            print(f"[red1]Error: {e}")
            return

    if is_simple_metric:
        data_df = simple_metric_data_df(
            csv_files=csv_files,
            groups=args.groups,
            metric_col=metric_col,
        )
    else:
        # Check if any files contain hemisphere indicators
        has_hemisphere = any(
            tag in file.name
            for file in csv_files
            for tag in ('_LH.csv', '_RH.csv')
        )

        # Aggregate the data from all .csv files and pool the data if hemispheres are present
        data_df = cluster_validation_data_df(
            metric_name=schema['metric_name'],
            value_col=schema['value_col'],
            support_col=schema['support_col'],
            support_type=schema['support_type'],
            aggregation_method=schema['aggregation_method'],
            has_hemisphere=has_hemisphere,
            csv_files=csv_files,
            groups=args.groups,
        )

    if data_df.empty:
        print(" [red1]No data rows found after aggregation.")
        return

    if not is_simple_metric:
        # Only the validation schema gets the generic "value"/"support" columns renamed.
        # For simple metric CSVs the autodetected column name is kept; overwriting it with
        # --value_name would make the metric-column check below fail.
        metric_col = args.value_name
        rename_map = {}
        if "value" in data_df.columns:
            rename_map["value"] = metric_col
        if "support" in data_df.columns:
            rename_map["support"] = args.support_name
        data_df = data_df.rename(columns=rename_map)

    # Drop columns that carry no data at all.
    data_df = data_df.dropna(axis=1, how="all")

    if metric_col not in data_df.columns:
        print(f"[red1]Metric column not found after aggregation: {metric_col}")
        print(f"Columns: {list(data_df.columns)}")
        return

    outdir = Path(args.outdir)
    cluster_outdir = outdir / "by_cluster"
    cluster_outdir.mkdir(parents=True, exist_ok=True)

    long_path = outdir / f"{metric_col}_long.csv"
    data_df.to_csv(long_path, index=False)

    # All-clusters wide format:
    # cluster_ID, saline_sample01, saline_sample02, drug_sample01, ...
    wide_df = _build_wide_df(data_df, metric_col, args.groups)
    wide_path = outdir / f"{metric_col}_wide.csv"
    wide_df.to_csv(wide_path, index=False)

    # Optional combined columns for per-cluster CSVs.
    combined_groups = _parse_combine_specs(args.combine)
    for col_name, combo_groups in combined_groups:
        unknown = [group for group in combo_groups if group not in args.groups]
        if unknown:
            # Aggregation is restricted to --groups, so these members contribute no rows.
            print(f" [yellow]Warning: --combine '{col_name}' references groups not in --groups: {' '.join(unknown)}")

    # One CSV per cluster with stacked replicates for each group and combined groups if specified, e.g.:
    # saline, drug1, drug2, drug1+drug2
    for cluster_id, cluster_df in data_df.groupby("cluster_ID", sort=True):
        cluster_out_df = _build_cluster_table(cluster_df, metric_col, args.groups, combined_groups)

        # Keep only filename-safe characters from the cluster ID.
        safe_cluster_id = "".join(
            ch if ch.isalnum() or ch in "._-" else "_"
            for ch in str(cluster_id)
        )
        cluster_path = cluster_outdir / f"cluster_{safe_cluster_id}__{metric_col}.csv"
        cluster_out_df.to_csv(cluster_path, index=False)

    if args.verbose:
        print("\n[blue]Aggregated long data:[/]")
        print(data_df)
        print("\n[blue]Wide data:[/]")
        print(wide_df)

    print(f"\n Long CSV: {long_path}")
    print(f" Wide CSV: {wide_path}")
    print(f" Per-cluster CSVs: {cluster_outdir}")

    verbose_end_msg()
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()