#!/usr/bin/env python3
import os
import subprocess
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import yaml


def load_config(path: str) -> Dict[str, Any]:
    """Load the dashboard's YAML configuration file.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document. An empty file yields ``{}`` rather than ``None``
        (which is what ``yaml.safe_load`` returns for an empty stream), so the
        result always honours the ``Dict`` annotation and callers can use
        ``.get`` on it directly.
    """
    with open(path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    return data if data is not None else {}


def save_config(path: str, data: Dict[str, Any]) -> None:
    """Serialize ``data`` as YAML into ``path``, overwriting any existing file.

    ``sort_keys=False`` keeps the mapping's insertion order, so keys are
    written back in the order they were loaded/added rather than alphabetically.
    """
    with open(path, mode="w", encoding="utf-8") as stream:
        yaml.safe_dump(data, stream=stream, sort_keys=False)


def list_clinics(outputs_root: str) -> List[str]:
    """Return the alphabetically sorted names of the sub-directories of ``outputs_root``.

    Each sub-directory is one clinic's output folder; plain files are ignored.
    A root that does not exist (or is not a directory) yields an empty list.
    """
    if not os.path.isdir(outputs_root):
        return []
    clinic_names: List[str] = []
    for entry in os.listdir(outputs_root):
        if os.path.isdir(os.path.join(outputs_root, entry)):
            clinic_names.append(entry)
    clinic_names.sort()
    return clinic_names


def read_csv(path: str, date_cols: Optional[List[str]] = None) -> pd.DataFrame:
    """Read a CSV output file, returning an empty frame when it is unavailable.

    Args:
        path: CSV file to read.
        date_cols: Optional column names to parse as datetimes. Names not
            present in the file are ignored; unparseable values become ``NaT``.

    Returns:
        The loaded DataFrame, or an empty ``pd.DataFrame()`` when the file
        does not exist or contains no data at all (e.g. a zero-byte export).
    """
    if not os.path.exists(path):
        return pd.DataFrame()
    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header row. Treat it like a missing file so
        # callers' `.empty` checks handle it instead of the page crashing.
        return pd.DataFrame()
    for col in date_cols or []:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], errors="coerce")
    return df


def compute_bp_trend(followups: pd.DataFrame, thresholds: Dict[str, Any]) -> pd.DataFrame:
    """Monthly blood-pressure control rate based on each patient's last visit per month.

    Only the latest follow-up per (patient, calendar month) is kept. A visit
    enters the denominator when both ``sbp`` and ``dbp`` are recorded, and
    counts as controlled when both are strictly below
    ``thresholds["bp_control_sbp"]`` / ``thresholds["bp_control_dbp"]``.

    Returns:
        One row per ``visit_month`` with ``visits``, ``bp_with_values``,
        ``bp_controlled`` and ``bp_control_rate`` (NaN when no visit in the
        month had both readings); an empty frame for empty input.
    """
    if followups.empty:
        return pd.DataFrame()

    frame = followups.copy()
    frame["visit_month"] = frame["visit_date"].dt.to_period("M").dt.to_timestamp()
    ordered = frame.sort_values(["patient_id", "visit_date"])
    last_visits = ordered.groupby(["patient_id", "visit_month"], as_index=False).tail(1)

    has_both = last_visits["sbp"].notna() & last_visits["dbp"].notna()
    below_sbp = last_visits["sbp"] < thresholds["bp_control_sbp"]
    below_dbp = last_visits["dbp"] < thresholds["bp_control_dbp"]
    last_visits["bp_with_values"] = has_both
    last_visits["bp_controlled"] = has_both & below_sbp & below_dbp

    monthly = last_visits.groupby("visit_month", as_index=False).agg(
        visits=("patient_id", "count"),
        bp_with_values=("bp_with_values", "sum"),
        bp_controlled=("bp_controlled", "sum"),
    )
    # Months without any complete reading get NaN instead of a 0/0 division.
    denominator = monthly["bp_with_values"].replace(0, np.nan)
    monthly["bp_control_rate"] = monthly["bp_controlled"] / denominator
    return monthly


def compute_dm_control(row: pd.Series, thresholds: Dict[str, Any]) -> Optional[bool]:
    """Classify one visit's diabetes control from the first available test.

    Tests are consulted in priority order: ``hba1c`` (vs ``thresholds["hba1c"]``),
    then ``t2hpp`` (vs ``thresholds["twopp"]``), then ``fbs`` (vs
    ``thresholds["fbs"]``). The first one with a non-missing value decides:
    controlled when that value is strictly below its threshold.

    Returns:
        The comparison result, or None when none of the three fields has a
        value (fields absent from ``row`` count as missing).
    """
    priority = (("hba1c", "hba1c"), ("t2hpp", "twopp"), ("fbs", "fbs"))
    for field, threshold_key in priority:
        reading = row.get(field)
        if pd.notna(reading):
            return reading < thresholds[threshold_key]
    return None


def compute_dm_trend(followups: pd.DataFrame, thresholds: Dict[str, Any]) -> pd.DataFrame:
    """Monthly diabetes control rate from each patient's last visit in the month.

    Only the latest follow-up per (patient, calendar month) is kept and
    classified with ``compute_dm_control``. Visits for which that returns
    None (no usable test value) are left out of the denominator.

    Returns:
        One row per ``visit_month`` with ``visits``, ``dm_with_values``,
        ``dm_controlled`` and ``dm_control_rate`` (NaN when no visit in the
        month had a usable value); an empty frame for empty input.
    """
    if followups.empty:
        return pd.DataFrame()

    visits = followups.copy()
    visits["visit_month"] = visits["visit_date"].dt.to_period("M").dt.to_timestamp()
    ordered = visits.sort_values(["patient_id", "visit_date"])
    last_in_month = ordered.groupby(["patient_id", "visit_month"], as_index=False).tail(1)

    def classify(visit: pd.Series) -> Optional[bool]:
        return compute_dm_control(visit, thresholds)

    last_in_month["dm_controlled"] = last_in_month.apply(classify, axis=1)
    last_in_month["dm_with_values"] = last_in_month["dm_controlled"].notna()

    monthly = last_in_month.groupby("visit_month", as_index=False).agg(
        visits=("patient_id", "count"),
        dm_with_values=("dm_with_values", "sum"),
        dm_controlled=("dm_controlled", "sum"),
    )
    # Months with no classifiable visit get NaN instead of a 0/0 division.
    tested = monthly["dm_with_values"].replace(0, np.nan)
    monthly["dm_control_rate"] = monthly["dm_controlled"] / tested
    return monthly


def has_value(series: pd.Series) -> pd.Series:
    """Element-wise flag telling whether a real value was entered.

    An entry counts as present when it is not missing and its string form,
    after stripping whitespace, is neither empty nor the literal ``"nan"``.
    """
    text = series.astype(str).str.strip()
    is_blank = text.eq("") | text.eq("nan")
    return series.notna() & ~is_blank


def compute_med_patterns(latest: pd.DataFrame) -> pd.DataFrame:
    """Summarise medication use on each patient's latest follow-up.

    For every known dose column present in ``latest``, a patient counts as
    taking that medication when ``has_value`` flags the dose as filled in.

    Returns:
        A long frame with columns ``category``, ``count``, ``share`` and
        ``metric``:

        * ``metric == "medication"``: one row per medication column present,
          with the number of patients taking it and its share of all
          medication mentions.
        * ``metric == "regimen_size"``: distribution of how many of those
          medications each patient takes (``"<k> meds"``).

        An empty frame when none of the dose columns exist.
    """
    # Dose column -> display label.
    # NOTE(review): "f_atorvastain_dose" lacks the "t" of "atorvastatin";
    # presumably it mirrors the source column name — confirm, because columns
    # that are absent are skipped silently below.
    med_fields = {
        "f_amlodipine_dose": "Amlodipine",
        "f_enalapril_dose": "Enalapril",
        "f_atorvastain_dose": "Atorvastatin",
        "f_hydrochlorothiazide_dose": "Hydrochlorothiazide",
        "f_aspirin_dose": "Aspirin",
        "f_metformin_500_dose": "Metformin 500",
        "f_metformin_1000_dose": "Metformin 1000",
        "f_gliclazide_500_dose": "Gliclazide 500",
        "f_gliclazide_1000_dose": "Gliclazide 1000",
    }

    present = [(column, label) for column, label in med_fields.items() if column in latest.columns]
    if not present:
        return pd.DataFrame()

    # One 0/1 indicator Series per available medication column.
    indicators = [has_value(latest[column]).astype(int) for column, _ in present]

    usage = pd.DataFrame(
        {
            "category": [label for _, label in present],
            "count": [int(flag.sum()) for flag in indicators],
        }
    )
    usage["share"] = usage["count"] / usage["count"].sum()

    meds_per_patient = pd.concat(indicators, axis=1).sum(axis=1)
    sizes = meds_per_patient.value_counts().sort_index().reset_index()
    sizes.columns = ["category", "count"]
    sizes["category"] = sizes["category"].apply(lambda n: f"{int(n)} meds")
    sizes["share"] = sizes["count"] / sizes["count"].sum()

    usage["metric"] = "medication"
    sizes["metric"] = "regimen_size"
    return pd.concat([usage, sizes], ignore_index=True)


def compute_visit_interval_distribution(followups: pd.DataFrame) -> pd.DataFrame:
    """Bin the gaps (in days) between consecutive visits of the same patient.

    Args:
        followups: Follow-up rows; needs ``patient_id`` and a datetime
            ``visit_date`` column. Rows missing either value are ignored.

    Returns:
        One row per interval bin with ``interval_bin``, ``count`` and
        ``share``. Every bin is listed, including empty ones. The overall
        ``median_days`` and ``mean_days`` are repeated on every row.
        An empty frame is returned when the required columns are missing or
        no patient has two dated visits.
    """
    if followups.empty or not {"patient_id", "visit_date"}.issubset(followups.columns):
        return pd.DataFrame()

    visits = followups[["patient_id", "visit_date"]].dropna().sort_values(["patient_id", "visit_date"])
    previous_visit = visits.groupby("patient_id")["visit_date"].shift(1)
    # Each patient's first visit has no predecessor -> NaN interval, dropped here.
    intervals = (visits["visit_date"] - previous_visit).dt.days.dropna()

    if intervals.empty:
        return pd.DataFrame()

    # The top bin is open-ended so "731+" really covers every longer gap. A
    # finite upper edge silently dropped very long gaps from the counts while
    # they still influenced the median/mean below.
    bins = [0, 30, 60, 90, 180, 365, 730, np.inf]
    labels = ["0-30", "31-60", "61-90", "91-180", "181-365", "366-730", "731+"]
    interval_bins = pd.cut(intervals, bins=bins, labels=labels, include_lowest=True)
    dist = interval_bins.value_counts().sort_index().reset_index()
    dist.columns = ["interval_bin", "count"]
    dist["share"] = dist["count"] / dist["count"].sum()
    dist["median_days"] = intervals.median()
    dist["mean_days"] = intervals.mean()
    return dist


def _format_rate(value: Optional[float]) -> str:
    """Render a proportion as a percentage; ``None``/NaN become ``"-"``."""
    if value is None or pd.isna(value):
        return "-"
    return f"{value:.1%}"


def _config_section(config: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``config[key]`` as a dict, inserting an empty one if absent or not a mapping.

    A minimal or hand-edited config.yaml may omit a section, or leave it blank,
    which YAML loads as ``None``. Indexing such a section directly would raise.
    The returned dict is the one stored in ``config``, so writes to it are
    saved along with ``config``.
    """
    section = config.get(key)
    if not isinstance(section, dict):
        section = {}
        config[key] = section
    return section


def _render_analysis_controls(config: Dict[str, Any], config_path: str) -> None:
    """Draw the sidebar form that edits the config file and the "Run analysis now" button."""
    date_cfg = _config_section(config, "date_range")
    threshold_cfg = _config_section(config, "thresholds")
    privacy_cfg = _config_section(config, "privacy")

    st.sidebar.subheader("Analysis Controls")
    with st.sidebar.expander("Configure and Run", expanded=False):
        start_default = pd.to_datetime(date_cfg.get("start_date"), errors="coerce")
        end_default = pd.to_datetime(date_cfg.get("end_date"), errors="coerce")
        if pd.isna(start_default):
            start_default = pd.Timestamp.today().normalize()
        if pd.isna(end_default):
            end_default = pd.Timestamp.today().normalize()

        with st.form("analysis_config_form"):
            db_options = config.get("databases", [])
            selected_dbs = st.multiselect("Databases", db_options, default=db_options)
            start_date = st.date_input(
                "Start date",
                value=start_default.date(),
            )
            end_date = st.date_input(
                "End date",
                value=end_default.date(),
            )
            active_days_input = st.number_input(
                "Active days", min_value=1, max_value=3650, value=int(threshold_cfg.get("active_days", 180))
            )
            ltfu_days_input = st.number_input(
                "LTFU days", min_value=1, max_value=3650, value=int(threshold_cfg.get("ltfu_days", 90))
            )
            hba1c_months_input = st.number_input(
                "HbA1c lookback months",
                min_value=1,
                max_value=60,
                value=int(threshold_cfg.get("hba1c_lookback_months", 6)),
            )
            kidney_months_input = st.number_input(
                "Kidney lookback months",
                min_value=1,
                max_value=60,
                value=int(threshold_cfg.get("kidney_lookback_months", 12)),
            )
            cvd_months_input = st.number_input(
                "CVD risk lookback months",
                min_value=1,
                max_value=60,
                value=int(threshold_cfg.get("cvd_risk_lookback_months", 12)),
            )
            mask_ids_input = st.checkbox("Mask IDs in exports", value=bool(privacy_cfg.get("mask_ids", True)))
            submitted = st.form_submit_button("Save config")

        if submitted:
            config["databases"] = selected_dbs
            date_cfg["start_date"] = start_date.strftime("%Y-%m-%d")
            date_cfg["end_date"] = end_date.strftime("%Y-%m-%d")
            threshold_cfg["active_days"] = int(active_days_input)
            threshold_cfg["ltfu_days"] = int(ltfu_days_input)
            threshold_cfg["hba1c_lookback_months"] = int(hba1c_months_input)
            threshold_cfg["kidney_lookback_months"] = int(kidney_months_input)
            threshold_cfg["cvd_risk_lookback_months"] = int(cvd_months_input)
            privacy_cfg["mask_ids"] = bool(mask_ids_input)
            save_config(config_path, config)
            st.success("Config saved. Run analysis to refresh outputs.")

        if st.button("Run analysis now"):
            succeeded = False
            with st.spinner("Running analysis..."):
                try:
                    result = subprocess.run(
                        ["php", "artisan", "ncd:analyze"],
                        cwd=os.path.dirname(os.path.abspath(__file__)),
                        capture_output=True,
                        text=True,
                    )
                except OSError as exc:
                    # e.g. `php` is not installed / not on PATH. Report it in
                    # the output box instead of crashing the whole page.
                    output = f"Could not start the analysis command: {exc}"
                else:
                    output = result.stdout + result.stderr
                    succeeded = result.returncode == 0
            st.session_state["last_analysis_output"] = output
            if succeeded:
                st.success("Analysis completed. Refreshing data...")
                st.rerun()
            else:
                st.error("Analysis failed. See output below.")

        if "last_analysis_output" in st.session_state:
            st.text_area("Last analysis output", st.session_state["last_analysis_output"], height=200)


def _restrict_to_patients(frame: pd.DataFrame, id_col: str, patient_ids: set) -> pd.DataFrame:
    """Keep only the rows of ``frame`` whose ``id_col`` value is in ``patient_ids``.

    A frame without the id column, such as a missing CSV loaded as an empty
    frame, is returned as an unfiltered copy. An empty ``patient_ids`` yields
    an empty result, so a filter combination matching no patients does not
    fall back to showing everyone.
    """
    if id_col not in frame.columns:
        return frame.copy()
    return frame[frame[id_col].isin(patient_ids)].copy()


def _controlled_share(patients: pd.DataFrame, with_values_col: str, controlled_col: str) -> Optional[float]:
    """Mean of ``controlled_col`` over rows where ``with_values_col`` is True.

    Returns None when ``with_values_col`` is absent (metric not produced) and
    NaN when no row has values. ``.eq(True)`` tolerates an object column with
    missing entries, which cannot be used as a boolean mask directly.
    """
    if with_values_col not in patients.columns:
        return None
    has_values = patients[with_values_col].eq(True)
    return patients.loc[has_values, controlled_col].mean()


def _quality_coverage(
    followups: pd.DataFrame, id_col: str, thresholds: Dict[str, Any], report_end: pd.Timestamp
) -> pd.DataFrame:
    """Share of patients with at least one recorded result inside each lookback window.

    A patient counts as covered for a test when any of their follow-ups dated
    on or after ``report_end`` minus the configured lookback has a non-missing
    value in that test's column. Labels are built from the configured lookback
    months so they stay accurate when the thresholds change. The rate is NaN
    when the test column is absent from ``followups``.
    """
    # Defaults mirror the config form's defaults.
    hba1c_months = int(thresholds.get("hba1c_lookback_months", 6))
    kidney_months = int(thresholds.get("kidney_lookback_months", 12))
    cvd_months = int(thresholds.get("cvd_risk_lookback_months", 12))
    checks = [
        (f"HbA1c last {hba1c_months}m", "hba1c", hba1c_months),
        (f"Creatinine last {kidney_months}m", "creatinine", kidney_months),
        (f"CRCL last {kidney_months}m", "crcl", kidney_months),
        (f"Urine A/C ratio last {kidney_months}m", "uring_ac_ratio", kidney_months),
        (f"CVD risk last {cvd_months}m", "cvd_risk", cvd_months),
    ]
    patient_keys = followups[id_col]
    labels = []
    rates = []
    for label, column, months in checks:
        labels.append(label)
        if column not in followups.columns:
            rates.append(float("nan"))
            continue
        cutoff = report_end - pd.DateOffset(months=months)
        recent = (followups["visit_date"] >= cutoff) & followups[column].notna()
        # Per patient: covered if any qualifying row exists; then share of patients.
        rates.append(recent.groupby(patient_keys).any().mean())
    return pd.DataFrame({"metric": labels, "rate": rates})


def _outcome_breakdown(
    patients: pd.DataFrame, group_col: str, id_col: str, labels: Dict[Any, str]
) -> pd.DataFrame:
    """Patient count and mean BP/DM control per value of ``group_col``, with display labels.

    Values not in ``labels`` (including missing ones) are shown as "Unknown".
    """
    table = patients.groupby(group_col, dropna=False).agg(
        patients=(id_col, "count"),
        bp_control_rate=("bp_controlled", "mean"),
        dm_control_rate=("dm_controlled", "mean"),
    ).reset_index()
    table[group_col] = table[group_col].map(labels).fillna("Unknown")
    return table


def _value_count_bar(patients: pd.DataFrame, column: str, category_name: str, title: str) -> None:
    """Plot value frequencies of ``patients[column]`` (missing values included) as a bar chart."""
    counts = patients[column].value_counts(dropna=False).reset_index()
    counts.columns = [category_name, "count"]
    if counts.empty:
        return
    fig = px.bar(counts, x=category_name, y="count", title=title)
    fig.update_layout(xaxis_tickangle=-45)
    st.plotly_chart(fig, use_container_width=True)


def main() -> None:
    """Streamlit entry point: config controls, clinic/patient filters, KPI row and analysis tabs."""
    st.set_page_config(page_title="NCD Analytics Dashboard", layout="wide")
    st.title("NCD Program Analytics")

    config_path = "config.yaml"
    # An empty config.yaml may load as None; fall back to {} so `.get` works.
    config = load_config(config_path) or {}
    outputs_root = (config.get("outputs") or {}).get("root", "outputs")
    clinics = list_clinics(outputs_root)

    _render_analysis_controls(config, config_path)

    if not clinics:
        st.warning("No outputs found. Run ncd_analysis.py first.")
        return

    clinic = st.sidebar.selectbox("Clinic", clinics, index=clinics.index("overall") if "overall" in clinics else 0)

    base_dir = os.path.join(outputs_root, clinic)
    patient_latest = read_csv(
        os.path.join(base_dir, "patient_latest.csv"),
        date_cols=["visit_date", "next_appointment", "dm_test_date"],
    )
    followups = read_csv(
        os.path.join(base_dir, "followups_clean.csv"),
        date_cols=["visit_date", "next_appointment", "lab_res_date", "fbs_test_date", "t2hpp_test_date"],
    )
    registers = read_csv(
        os.path.join(base_dir, "registers_clean.csv"),
        date_cols=["reg_date"],
    )
    other_medications = read_csv(os.path.join(base_dir, "other_medications.csv"))

    if patient_latest.empty:
        st.warning("patient_latest.csv is missing or empty.")
        return

    id_col = "patient_id" if "patient_id" in patient_latest.columns else "pid"

    min_date = followups["visit_date"].min() if not followups.empty else patient_latest["visit_date"].min()
    max_date = followups["visit_date"].max() if not followups.empty else patient_latest["visit_date"].max()

    # Default the visit-date picker to the data range, narrowed by the configured range.
    date_cfg = _config_section(config, "date_range")
    cfg_start = pd.to_datetime(date_cfg.get("start_date"), errors="coerce")
    cfg_end = pd.to_datetime(date_cfg.get("end_date"), errors="coerce")
    default_start = min_date if pd.notna(min_date) else None
    default_end = max_date if pd.notna(max_date) else None
    if pd.notna(cfg_start) and default_start is not None:
        default_start = max(default_start, cfg_start)
    if pd.notna(cfg_end) and default_end is not None:
        default_end = min(default_end, cfg_end)

    date_range = st.sidebar.date_input(
        "Visit date range",
        value=(default_start.date() if default_start is not None else None, default_end.date() if default_end is not None else None),
    )

    genders = sorted(patient_latest.get("gender", pd.Series(dtype=object)).dropna().unique())
    townships = sorted(patient_latest.get("township", pd.Series(dtype=object)).dropna().unique())
    age_bands = sorted(patient_latest.get("age_band", pd.Series(dtype=object)).dropna().unique())

    selected_gender = st.sidebar.multiselect("Sex", genders, default=genders)
    selected_age = st.sidebar.multiselect("Age bands", age_bands, default=age_bands)
    selected_township = st.sidebar.multiselect("Township", townships, default=townships)

    # An empty multiselect means "no restriction" for that dimension.
    filtered_patients = patient_latest.copy()
    if selected_gender:
        filtered_patients = filtered_patients[filtered_patients["gender"].isin(selected_gender)]
    if selected_age:
        filtered_patients = filtered_patients[filtered_patients["age_band"].isin(selected_age)]
    if selected_township:
        filtered_patients = filtered_patients[filtered_patients["township"].isin(selected_township)]

    patient_ids = set(filtered_patients[id_col].dropna())
    filtered_followups = _restrict_to_patients(followups, id_col, patient_ids)
    filtered_registers = _restrict_to_patients(registers, id_col, patient_ids)

    selected_start = None
    selected_end = None
    if isinstance(date_range, tuple) and len(date_range) == 2:
        selected_start, selected_end = date_range
        if selected_start and selected_end and "visit_date" in filtered_followups.columns:
            # Compare against the start of the day *after* the end date so that
            # visits timestamped later on the end date are still included.
            visit_dates = filtered_followups["visit_date"]
            day_after_end = pd.Timestamp(selected_end) + pd.Timedelta(days=1)
            filtered_followups = filtered_followups[
                (visit_dates >= pd.Timestamp(selected_start)) & (visit_dates < day_after_end)
            ]

    thresholds = _config_section(config, "thresholds")
    # Same default as the "Active days" form field, so the KPI label and the
    # form agree when the key is unset.
    active_days = int(thresholds.get("active_days", 180))

    bp_rate = _controlled_share(filtered_patients, "bp_with_values", "bp_controlled")
    dm_rate = _controlled_share(filtered_patients, "dm_with_values", "dm_controlled")
    ltfu_rate = filtered_patients["ltfu"].mean()

    kpi1, kpi2, kpi3, kpi4, kpi5 = st.columns(5)
    kpi1.metric("Patients", f"{filtered_patients.shape[0]:,}")
    kpi2.metric(f"Active caseload (last {active_days}d)", f"{filtered_patients['active_patient'].sum():,}")
    kpi3.metric("BP control", _format_rate(bp_rate))
    kpi4.metric("DM control", _format_rate(dm_rate))
    kpi5.metric("LTFU", _format_rate(ltfu_rate))

    tabs = st.tabs(
        [
            "Executive Summary",
            "Workload Trends",
            "BP Outcomes",
            "Diabetes Outcomes",
            "Continuity",
            "Quality of Care",
            "Other Medications",
            "Program Insights",
        ]
    )

    with tabs[0]:
        st.subheader("Summary")
        summary_rows = [
            {"metric": "bp_control_rate", "rate": bp_rate},
            {"metric": "dm_control_rate", "rate": dm_rate},
            {"metric": "ltfu_rate", "rate": ltfu_rate},
            {"metric": f"active_caseload_last_{active_days}d", "rate": filtered_patients["active_patient"].mean()},
        ]
        st.dataframe(pd.DataFrame(summary_rows))

        if not filtered_patients.empty and "age_at_report_end" in filtered_patients.columns:
            fig = px.histogram(filtered_patients, x="age_at_report_end", nbins=20, title="Age distribution")
            st.plotly_chart(fig, use_container_width=True)

    with tabs[1]:
        st.subheader("Workload")
        if not filtered_registers.empty:
            reg_month = filtered_registers.copy()
            reg_month["month"] = reg_month["reg_date"].dt.to_period("M").dt.to_timestamp()
            reg_summary = reg_month.groupby("month", as_index=False).size().rename(columns={"size": "new_regs"})
        else:
            reg_summary = pd.DataFrame(columns=["month", "new_regs"])

        if not filtered_followups.empty:
            visit_month = filtered_followups.copy()
            visit_month["month"] = visit_month["visit_date"].dt.to_period("M").dt.to_timestamp()
            visit_summary = visit_month.groupby("month", as_index=False).size().rename(columns={"size": "followups"})
        else:
            visit_summary = pd.DataFrame(columns=["month", "followups"])

        combined = pd.merge(reg_summary, visit_summary, on="month", how="outer").fillna(0).sort_values("month")
        fig = px.line(combined, x="month", y=["new_regs", "followups"], markers=True, title="Monthly registrations and follow-ups")
        st.plotly_chart(fig, use_container_width=True)

    with tabs[2]:
        st.subheader("BP Outcomes")
        trend = compute_bp_trend(filtered_followups, thresholds)
        if not trend.empty:
            fig = px.line(trend, x="visit_month", y="bp_control_rate", markers=True, title="BP control rate trend")
            st.plotly_chart(fig, use_container_width=True)
        if not filtered_followups.empty:
            fig2 = px.histogram(filtered_followups, x="sbp", nbins=30, title="SBP distribution")
            st.plotly_chart(fig2, use_container_width=True)

    with tabs[3]:
        st.subheader("Diabetes Outcomes")
        dm_trend = compute_dm_trend(filtered_followups, thresholds)
        if not dm_trend.empty:
            fig = px.line(dm_trend, x="visit_month", y="dm_control_rate", markers=True, title="Diabetes control rate trend")
            st.plotly_chart(fig, use_container_width=True)

        if "dm_test_used" in filtered_patients.columns:
            test_counts = filtered_patients["dm_test_used"].value_counts(dropna=False).reset_index()
            test_counts.columns = ["test", "count"]
            fig2 = px.bar(test_counts, x="test", y="count", title="Latest diabetes test used")
            st.plotly_chart(fig2, use_container_width=True)

    with tabs[4]:
        st.subheader("Continuity")
        with st.expander("How continuity metrics are calculated"):
            st.markdown(
                f"""
**Missed appointment**  
Latest follow-up per patient: `Next_Appointment` is present **and** earlier than the report end date.

**Active caseload**  
Latest follow-up per patient: `visit_date` is within the last **{active_days} days** of the report end date.

**LTFU**  
Latest follow-up per patient: no visit recorded **or** `visit_date` is earlier than the report end date minus **{int(thresholds.get('ltfu_days', 90))} days**.

**Late visit**  
Latest follow-up per patient: `Late_visit` **or** `Late_follow` is marked yes.
"""
            )
        continuity = pd.DataFrame(
            {
                "metric": ["Late visits", "Missed appointments", "LTFU"],
                "count": [
                    filtered_patients["late_visit_flag"].sum(),
                    filtered_patients["missed_appointment"].sum(),
                    filtered_patients["ltfu"].sum(),
                ],
            }
        )
        fig = px.bar(continuity, x="metric", y="count", title="Continuity flags")
        st.plotly_chart(fig, use_container_width=True)

    with tabs[5]:
        st.subheader("Quality of Care")
        if not filtered_followups.empty and id_col in filtered_followups.columns:
            report_end = pd.Timestamp(selected_end) if selected_end else filtered_followups["visit_date"].max()
            q_summary = _quality_coverage(filtered_followups, id_col, thresholds, report_end)
            fig = px.bar(q_summary, x="metric", y="rate", title="Quality coverage")
            fig.update_yaxes(tickformat=".0%")
            st.plotly_chart(fig, use_container_width=True)

    with tabs[6]:
        st.subheader("Other Medications (Foth_medi_spec)")
        if other_medications.empty:
            st.info("No other medication text found in this timeframe.")
        else:
            top_n = st.slider("Top N medications", min_value=5, max_value=50, value=20, step=5)
            display = other_medications.head(top_n)
            fig = px.bar(display, x="medication", y="count", title="Most common other medications")
            fig.update_layout(xaxis_tickangle=-45)
            st.plotly_chart(fig, use_container_width=True)
            st.dataframe(display)

    with tabs[7]:
        st.subheader("Therapy Patterns")
        med_patterns = compute_med_patterns(filtered_patients)
        if med_patterns.empty:
            st.info("Medication pattern fields are not available for this dataset.")
        else:
            meds = med_patterns[med_patterns["metric"] == "medication"].sort_values("count", ascending=False)
            fig = px.bar(meds, x="category", y="count", title="Most common medications (latest follow-up)")
            fig.update_layout(xaxis_tickangle=-45)
            st.plotly_chart(fig, use_container_width=True)
            regimens = med_patterns[med_patterns["metric"] == "regimen_size"]
            fig2 = px.bar(regimens, x="category", y="count", title="Number of medications per patient (latest follow-up)")
            st.plotly_chart(fig2, use_container_width=True)

        st.subheader("Visit Interval Distribution")
        interval_dist = compute_visit_interval_distribution(filtered_followups)
        if interval_dist.empty:
            st.info("Not enough visit dates to compute intervals.")
        else:
            fig = px.bar(interval_dist, x="interval_bin", y="count", title="Days between visits")
            st.plotly_chart(fig, use_container_width=True)
            st.caption(
                f"Median interval: {interval_dist['median_days'].iloc[0]:.0f} days | "
                f"Mean interval: {interval_dist['mean_days'].iloc[0]:.0f} days"
            )

        st.subheader("Risk-Outcome Linkage")
        if "cvd_risk_high" in filtered_patients.columns:
            st.dataframe(
                _outcome_breakdown(filtered_patients, "cvd_risk_high", id_col, {True: "High", False: "Not high"})
            )
        if "ckd_marker" in filtered_patients.columns:
            st.dataframe(
                _outcome_breakdown(filtered_patients, "ckd_marker", id_col, {True: "CKD marker", False: "No marker"})
            )

        st.subheader("Medication Change vs Control")
        if "med_changed_flag" in filtered_patients.columns:
            st.dataframe(
                _outcome_breakdown(filtered_patients, "med_changed_flag", id_col, {True: "Changed", False: "Not changed"})
            )

        st.subheader("Equity Deep Dive")
        for heading, group_col in (("By sex", "gender"), ("By age band", "age_band")):
            if group_col not in filtered_patients.columns:
                continue
            equity = filtered_patients.groupby(group_col, dropna=False).agg(
                patients=(id_col, "count"),
                bp_control_rate=("bp_controlled", "mean"),
                dm_control_rate=("dm_controlled", "mean"),
                ltfu_rate=("ltfu", "mean"),
            ).reset_index()
            st.write(heading)
            st.dataframe(equity)

        st.subheader("Referral Outcomes")
        if "outcome" in filtered_patients.columns:
            _value_count_bar(filtered_patients, "outcome", "outcome", "Outcomes (latest follow-up)")
        else:
            st.info("Outcome field not available in this dataset.")

        if "ncd_tout_icmv_location" in filtered_patients.columns:
            _value_count_bar(filtered_patients, "ncd_tout_icmv_location", "referral_location", "Referral locations")

        if "tout_mam_clinic" in filtered_patients.columns:
            _value_count_bar(filtered_patients, "tout_mam_clinic", "tout_mam_clinic", "Transfers out (Tout_mam_clinic)")


# Script entry point: build the dashboard when this file is executed directly.
if __name__ == "__main__":
    main()
