#!/usr/bin/env python3
import argparse
import datetime as dt
import hashlib
import json
import logging
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import yaml
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy import create_engine

# Per-table map of source column name -> alias that column carries in the
# generated views. For example, patients."Date of Birth" is selected as
# date_of_birth_raw by build_base_join_view_sql.
# NOTE(review): presumably passed as ``column_alias_map`` to decrypt_dataframe
# so encrypted fields can be located under their view alias — confirm at the
# call site.
VIEW_ENCRYPTED_ALIAS = {
    "patients": {
        "Date of Birth": "date_of_birth_raw",
    },
    "ncd_pt_registers": {},
    "ncd_followups": {},
}

# Names of the NCD reporting views. The first two are generated by
# ensure_views from the live table columns.
# NOTE(review): the remaining two presumably come from the SQL file that
# ensure_views executes — confirm.
REQUIRED_VIEWS = [
    "v_ncd_base_join",
    "v_ncd_followups_clean",
    "v_ncd_latest_followup_per_patient",
    "v_ncd_monthly_summary_seed",
]


def setup_logging(level: str) -> None:
    """Configure the root logger.

    Args:
        level: Name of a ``logging`` level in any letter case. Names that
            are not attributes of the ``logging`` module fall back to INFO.
    """
    resolved_level = getattr(logging, level.upper(), logging.INFO)
    log_format = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(level=resolved_level, format=log_format)


def load_config(path: str) -> Dict[str, Any]:
    """Load the YAML configuration file at ``path``.

    Args:
        path: Location of a UTF-8 encoded YAML file.

    Returns:
        The parsed document. An empty file yields an empty dict so callers
        can use ``config.get(...)`` without a ``None`` check.
    """
    with open(path, "r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    # safe_load returns None for an empty document; normalise that single
    # case and leave every other parsed value untouched.
    return {} if loaded is None else loaded


def apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay environment-variable settings onto ``config``.

    Non-empty ``DB_*`` / ``NCD_DB_*`` variables replace the matching keys of
    ``config["mysql"]``; when both forms are set, the ``NCD_DB_*`` variable
    wins because it is applied later. ``NCD_LARAVEL_PATH`` replaces
    ``config["laravel_path"]``. Values are stored as the raw strings read
    from the environment.

    Args:
        config: Parsed configuration. It is updated in place.

    Returns:
        The same ``config`` object, for convenience.
    """
    # A bare ``mysql:`` key in YAML parses to None, which dict() cannot
    # copy; treat it like a missing section.
    mysql_cfg = dict(config.get("mysql") or {})
    # Insertion order matters: later entries override earlier ones.
    env_map = {
        "DB_HOST": "host",
        "DB_PORT": "port",
        "DB_USERNAME": "user",
        "DB_PASSWORD": "password",
        "NCD_DB_HOST": "host",
        "NCD_DB_PORT": "port",
        "NCD_DB_USER": "user",
        "NCD_DB_PASSWORD": "password",
    }
    for env_key, cfg_key in env_map.items():
        value = os.getenv(env_key)
        if value:
            mysql_cfg[cfg_key] = value

    config["mysql"] = mysql_cfg

    laravel_path = os.getenv("NCD_LARAVEL_PATH")
    if laravel_path:
        config["laravel_path"] = laravel_path

    return config


def build_engine(mysql_cfg: Dict[str, Any], database: str) -> Engine:
    """Create a SQLAlchemy engine for ``database`` over the PyMySQL driver.

    Args:
        mysql_cfg: Connection settings. ``user``, ``password`` and ``host``
            are required; ``port`` defaults to 3306 and ``connect_timeout``
            to 10.
        database: Name of the schema to connect to.

    Returns:
        An engine with pre-ping enabled and ``pool_recycle`` set to 3600.
    """
    url_parts = {
        "drivername": "mysql+pymysql",
        "username": mysql_cfg["user"],
        "password": mysql_cfg["password"],
        "host": mysql_cfg["host"],
        "port": int(mysql_cfg.get("port", 3306)),
        "database": database,
    }
    connection_url = URL.create(**url_parts)

    engine_options = {
        "pool_pre_ping": True,
        "pool_recycle": 3600,
        "connect_args": {"connect_timeout": int(mysql_cfg.get("connect_timeout", 10))},
    }
    return create_engine(connection_url, **engine_options)


def read_sql_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as sql_file:
        contents = sql_file.read()
    return contents


def split_sql_statements(sql_text: str) -> List[str]:
    """Split a SQL script into statements, line by line.

    Blank lines and lines whose first non-space characters are ``--`` are
    dropped. A statement ends at a line whose stripped text ends with ``;``;
    any trailing lines without a terminator form a final statement. The
    split looks at line endings only, not at SQL syntax.

    Args:
        sql_text: Full text of the script.

    Returns:
        Statements in file order, each with its kept lines joined by
        newlines.
    """
    collected: List[str] = []
    pending: List[str] = []

    def flush() -> None:
        collected.append("\n".join(pending))
        pending.clear()

    for raw_line in sql_text.splitlines():
        content = raw_line.strip()
        if not content or content.startswith("--"):
            continue
        pending.append(raw_line)
        if content[-1] == ";":
            flush()

    if pending:
        flush()
    return collected


def fetch_table_columns(conn, table_name: str) -> List[str]:
    """Return the column names of ``table_name`` via ``SHOW COLUMNS``.

    Args:
        conn: Open connection exposing ``execute``.
        table_name: Table to inspect; it is interpolated into the statement
            inside backticks, so it must be a trusted identifier.

    Returns:
        The first field of every result row, in the order returned.
    """
    query = text(f"SHOW COLUMNS FROM `{table_name}`")
    names = []
    for record in conn.execute(query).fetchall():
        names.append(record[0])
    return names


def build_base_join_view_sql(register_cols: List[str], patient_cols: List[str]) -> str:
    """Build the ``CREATE OR REPLACE VIEW v_ncd_base_join`` statement.

    Each output column is taken from its source table when that table has
    the column, and is emitted as ``NULL AS <alias>`` otherwise, so the view
    keeps the same shape whatever columns exist.

    Args:
        register_cols: Column names present in ``ncd_pt_registers`` (alias
            ``r``).
        patient_cols: Column names present in ``patients`` (alias ``p``).

    Returns:
        The complete SQL statement, terminated with ``;``.
    """
    available = {"r": register_cols, "p": patient_cols}

    # (table alias, source column, output alias), in output order.
    column_specs = [
        ("r", "Clinic_code", "clinic_code"),
        ("r", "Pid", "pid"),
        ("r", "FuchiaID", "fuchia_id"),
        ("r", "Gender", "gender"),
        ("r", "Reg_Date", "reg_date"),
        ("r", "Area_Division", "area_division"),
        ("r", "Township", "township"),
        ("r", "visit_Age", "visit_age"),
        ("r", "Current_Age", "current_age"),
        ("r", "1stBP", "first_bp"),
        ("r", "1stBP_date", "first_bp_date"),
        ("r", "2ndBP", "second_bp"),
        ("r", "2ndBP_date", "second_bp_date"),
        ("r", "3rdBP", "third_bp"),
        ("r", "3rdBP_date", "third_bp_date"),
        ("r", "1stHypertension", "first_hypertension"),
        ("r", "1st_DiagDate", "first_diag_date"),
        ("r", "staging_Hypertension", "staging_hypertension"),
        ("r", "1st_tot_Diabetes", "first_dm_test_type"),
        ("r", "1st_RBS", "first_dm_value"),
        ("r", "1st_RBS_date", "first_dm_date"),
        ("r", "2nd_tot_Diabetes", "second_dm_test_type"),
        ("r", "2nd_RBS", "second_dm_value"),
        ("r", "2nd_RBS_date", "second_dm_date"),
        ("r", "2nd_Hypertension", "second_hypertension"),
        ("r", "2nd_DiagDate", "second_diag_date"),
        ("p", "Date of Birth", "date_of_birth_raw"),
    ]

    def render(table_alias: str, column: str, alias: str) -> str:
        if column in available[table_alias]:
            return f"{table_alias}.`{column}` AS {alias}"
        return f"NULL AS {alias}"

    select_list = ",\n    ".join(render(*spec) for spec in column_specs)
    return (
        "CREATE OR REPLACE VIEW v_ncd_base_join AS\n"
        f"SELECT\n    {select_list}\n"
        "FROM ncd_pt_registers r\n"
        "LEFT JOIN patients p ON p.`Pid` = r.`Pid`;"
    )


def build_followups_view_sql(followup_cols: List[str]) -> str:
    """Build the ``CREATE OR REPLACE VIEW v_ncd_followups_clean`` statement.

    The inner query selects every expected column from ``ncd_followups``
    (alias ``f``), emitting ``NULL AS <alias>`` for columns the table does
    not have. ``own_clinic_Bp`` is additionally split into ``sbp_raw`` and
    ``dbp_raw`` when it matches a ``NNN/NNN`` pattern. The outer query adds
    ``sbp`` / ``dbp``, which are NULL unless both raw values are within
    50-300 and 30-200 respectively.

    Args:
        followup_cols: Column names present in ``ncd_followups``.

    Returns:
        The complete SQL statement, terminated with ``;``.
    """

    def render(column: str, alias: str) -> str:
        if column in followup_cols:
            return f"f.`{column}` AS {alias}"
        return f"NULL AS {alias}"

    bp_pattern = "^[[:space:]]*[0-9]{2,3}[[:space:]]*/[[:space:]]*[0-9]{2,3}[[:space:]]*$"

    def bp_component(position: int, alias: str) -> str:
        # position 1 picks the text before the slash, -1 the text after it.
        if "own_clinic_Bp" not in followup_cols:
            return f"NULL AS {alias}"
        return (
            "CASE\n"
            f"            WHEN f.`own_clinic_Bp` REGEXP '{bp_pattern}'\n"
            f"                THEN CAST(SUBSTRING_INDEX(TRIM(f.`own_clinic_Bp`), '/', {position}) AS UNSIGNED)\n"
            "            ELSE NULL\n"
            f"        END AS {alias}"
        )

    # The RBS column exists under two spellings; the spaced one wins.
    rbs_source = next(
        (name for name in ("RBS result", "RBS_result") if name in followup_cols),
        None,
    )
    if rbs_source is None:
        rbs_expr = "NULL AS rbs_result"
    else:
        rbs_expr = f"f.`{rbs_source}` AS rbs_result"

    # (source column, output alias), in output order.
    column_specs = [
        ("id", "followup_id"),
        ("Clinic_code", "clinic_code"),
        ("Pid", "pid"),
        ("FuchiaID", "fuchia_id"),
        ("Visit_date", "visit_date"),
        ("Reg_Date", "reg_date"),
        ("Agey", "visit_age"),
        ("Gender", "gender"),
        ("Area_Division", "area_division"),
        ("Township", "township"),
        ("NCD_Diagnosis", "ncd_diagnosis"),
        ("Type_cur_visit", "type_cur_visit"),
        ("Late_visit", "late_visit"),
        ("Late_duration", "late_duration"),
        ("Late_duration_unit", "late_duration_unit"),
        ("Late_follow", "late_follow"),
        ("Late_fol_duration", "late_fol_duration"),
        ("Time", "visit_time"),
        ("own_clinic_Bp", "bp_raw"),
        ("own_Bp_Stage", "bp_stage"),
        ("ncdV_1st_tot_Diabetes", "dm_1st_total"),
        ("FBS", "fbs"),
        ("FBS_test_date", "fbs_test_date"),
        ("Loaction_test", "fbs_test_location"),
        ("ncdV_2nd_tot_Diabetes", "dm_2nd_total"),
        ("2HPP", "t2hpp"),
        ("2HPP_test_date", "t2hpp_test_date"),
        ("Loaction_Test2", "t2hpp_test_location"),
        ("Lab_res_Date", "lab_res_date"),
        ("Alt", "alt"),
        ("HBA1C", "hba1c"),
        ("Uring_AC_ratio", "uring_ac_ratio"),
        ("Glucose_res", "glucose_res"),
        ("Protein_res", "protein_res"),
        ("Creatinine", "creatinine"),
        ("Creat_unit", "creat_unit"),
        ("CRCL", "crcl"),
        ("Total_cholesterol", "total_cholesterol"),
        ("Total_cho_Unit", "total_cholesterol_unit"),
        ("CVD_Risk", "cvd_risk"),
        ("HDL", "hdl"),
        ("HDL_unit", "hdl_unit"),
        ("LDL", "ldl"),
        ("LDL_unit", "ldl_unit"),
        ("Triglyceride", "triglyceride"),
        ("Triglyceride_unit", "triglyceride_unit"),
        ("Pulse", "pulse"),
        ("Pulse_rate", "pulse_rate"),
        ("Diabetic_foot", "diabetic_foot"),
        ("Diabetic_Neuropathy", "diabetic_neuropathy"),
        ("Lifestyle advice", "lifestyle_advice"),
        ("Medication changed", "medication_changed"),
        ("Patient_adhe medic", "patient_adherence"),
        ("Drug_Supply", "drug_supply"),
        ("F_Amlodipine_dose", "f_amlodipine_dose"),
        ("F_Enalapril_dose", "f_enalapril_dose"),
        ("F_Atorvastain_dose", "f_atorvastain_dose"),
        ("F_Hydrochlorothiazide_dose", "f_hydrochlorothiazide_dose"),
        ("F_Aspirin_dose", "f_aspirin_dose"),
        ("F_Metformin(500)_dose", "f_metformin_500_dose"),
        ("F_Metformin(1000)_dose", "f_metformin_1000_dose"),
        ("F_Gliclazide(500)_dose", "f_gliclazide_500_dose"),
        ("F_Gliclazide(1000)_dose", "f_gliclazide_1000_dose"),
        ("Symptom hypoglycemia", "symptom_hypoglycemia"),
        ("Foth_medi", "other_medication"),
        ("Foth_medi_spec", "other_medication_spec"),
        ("Out_come", "outcome"),
        ("Tout_mam_clinic", "tout_mam_clinic"),
        ("death_date", "death_date"),
        ("Tout_physician_data", "tout_physician_data"),
        ("Ncd_Tout_icmv_location", "ncd_tout_icmv_location"),
        ("Cause_of_death", "cause_of_death"),
        ("Fup_doc_initial", "fup_doc_initial"),
        ("Next_Appointment", "next_appointment"),
        ("visit_type", "visit_type"),
    ]

    select_parts = [render(column, alias) for column, alias in column_specs]
    select_parts.append(rbs_expr)
    select_parts.append(bp_component(1, "sbp_raw"))
    select_parts.append(bp_component(-1, "dbp_raw"))

    plausible = "f_raw.sbp_raw BETWEEN 50 AND 300 AND f_raw.dbp_raw BETWEEN 30 AND 200"
    sql_lines = [
        "CREATE OR REPLACE VIEW v_ncd_followups_clean AS",
        "SELECT",
        "    f_raw.*,",
        "    CASE",
        f"        WHEN {plausible} THEN f_raw.sbp_raw",
        "        ELSE NULL",
        "    END AS sbp,",
        "    CASE",
        f"        WHEN {plausible} THEN f_raw.dbp_raw",
        "        ELSE NULL",
        "    END AS dbp",
        "FROM (",
        "    SELECT",
        "        " + ",\n        ".join(select_parts),
        "    FROM ncd_followups f",
        ") AS f_raw;",
    ]
    return "\n".join(sql_lines)


def ensure_views(engine: Engine, sql_path: str) -> None:
    """Create or replace the reporting views inside one transaction.

    ``v_ncd_base_join`` and ``v_ncd_followups_clean`` are generated from the
    columns the live tables actually have. Every other statement comes from
    the SQL file; statements in that file that mention either generated view
    after the word ``VIEW`` are skipped.

    Args:
        engine: Engine connected to the target database.
        sql_path: Path of the SQL script holding the remaining statements.
    """
    file_statements = split_sql_statements(read_sql_file(sql_path))
    generated_markers = ("VIEW v_ncd_base_join", "VIEW v_ncd_followups_clean")

    with engine.begin() as conn:
        register_cols = fetch_table_columns(conn, "ncd_pt_registers")
        patient_cols = fetch_table_columns(conn, "patients")
        followup_cols = fetch_table_columns(conn, "ncd_followups")

        # The generated views go first, before the statements from the file.
        base_join_sql = build_base_join_view_sql(register_cols, patient_cols)
        conn.execute(text(base_join_sql))
        followups_sql = build_followups_view_sql(followup_cols)
        conn.execute(text(followups_sql))

        for statement in file_statements:
            if any(marker in statement for marker in generated_markers):
                continue
            conn.execute(text(statement))


def read_view(engine: Engine, view_name: str) -> pd.DataFrame:
    """Load every row of ``view_name`` into a DataFrame.

    ``view_name`` is interpolated into the query unquoted, so it must be a
    trusted identifier.
    """
    query = f"SELECT * FROM {view_name}"
    return pd.read_sql_query(query, engine)


def parse_date_series(series: pd.Series) -> pd.Series:
    """Convert ``series`` to datetimes; unparseable values become ``NaT``."""
    parsed = pd.to_datetime(series, errors="coerce")
    return parsed


def to_numeric(series: pd.Series) -> pd.Series:
    """Convert ``series`` to numbers; unparseable values become ``NaN``."""
    converted = pd.to_numeric(series, errors="coerce")
    return converted


def is_yes(value: Any) -> bool:
    """Report whether ``value`` is an affirmative flag.

    ``None`` is never affirmative. Python ints and floats (including bools)
    are affirmative only when equal to 1. Anything else is compared, after
    ``str()``, trimming and lower-casing, against a fixed set of words.
    """
    affirmative = ("yes", "y", "true", "1", "late", "on")
    if value is None:
        verdict = False
    elif isinstance(value, (int, float)):
        verdict = value == 1
    else:
        verdict = str(value).strip().lower() in affirmative
    return verdict


def compute_age_years(dob: Optional[pd.Timestamp], ref_date: Optional[pd.Timestamp], divisor: float) -> Optional[int]:
    """Return whole years between ``dob`` and ``ref_date``.

    Strings and ``numpy.datetime64`` values are coerced with
    ``pd.to_datetime`` first (unparseable input becomes ``NaT``).

    Args:
        dob: Date of birth.
        ref_date: Date at which the age is measured.
        divisor: Number of days counted as one year.

    Returns:
        ``floor(days / divisor)`` as an int, or ``None`` when either date is
        missing or ``ref_date`` precedes ``dob``.
    """

    def _coerce(moment: Any) -> Any:
        if isinstance(moment, (str, np.datetime64)):
            return pd.to_datetime(moment, errors="coerce")
        return moment

    birth = _coerce(dob)
    reference = _coerce(ref_date)

    if birth is None or pd.isna(birth) or reference is None or pd.isna(reference):
        return None

    elapsed_days = (reference - birth).days
    if elapsed_days < 0:
        return None
    return int(np.floor(elapsed_days / divisor))


def assign_age_band(age: Optional[float], bands: List[List[int]]) -> Optional[str]:
    """Return the label of the first band that contains ``age``.

    Args:
        age: Age to classify; ``None`` or NaN yields ``None``.
        bands: ``[low, high]`` pairs, both bounds inclusive.

    Returns:
        ``"low-high"`` for the first matching band, else ``None``.
    """
    if age is None or pd.isna(age):
        return None
    labels = (f"{low}-{high}" for low, high in bands if low <= age <= high)
    return next(labels, None)


def hash_id(value: Any, salt: str) -> Optional[str]:
    """Return the SHA-256 hex digest of ``"<salt>:<value>"``.

    Missing values (``None`` or anything ``pd.isna`` reports as missing)
    yield ``None``.
    """
    if value is None or pd.isna(value):
        return None
    digest = hashlib.sha256()
    digest.update(f"{salt}:{value}".encode("utf-8"))
    return digest.hexdigest()


def mask_ids(df: pd.DataFrame, fields: List[str], salt: str) -> pd.DataFrame:
    """Return a copy of ``df`` with the given identifier columns hashed.

    Args:
        df: Source frame; it is not modified.
        fields: Column names to hash. Names absent from ``df`` are ignored.
        salt: Salt passed to ``hash_id`` for every value.

    Returns:
        A new DataFrame in which each listed column holds ``hash_id`` output.
    """
    masked = df.copy()
    present = [name for name in fields if name in masked.columns]
    for name in present:
        masked[name] = masked[name].apply(lambda raw: hash_id(raw, salt))
    return masked


def make_json_safe(value: Any) -> Any:
    """Convert a single cell value into something ``json.dumps`` accepts.

    Conversions applied:
      * ``None`` and missing markers (NaN, ``NaT``, ``pd.NA``) become ``None``.
      * Timestamps, datetimes and dates become ISO-8601 strings.
      * NumPy integers, floats and booleans become native Python values.
      * Non-finite floats (NaN, +/-inf), Python or NumPy, become ``None``.
      * Anything else, including containers, is returned unchanged.

    Args:
        value: The value to convert.

    Returns:
        The converted value.
    """
    if value is None:
        return None
    if isinstance(value, (pd.Timestamp, dt.datetime, dt.date)):
        return None if pd.isna(value) else value.isoformat()
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.integer):
        return value.item()
    if isinstance(value, np.floating):
        # np.float64 subclasses float, so it must be checked here for
        # infinities too; otherwise inf would be serialized as `Infinity`.
        return value.item() if np.isfinite(value) else None
    if isinstance(value, float):
        return value if np.isfinite(value) else None
    # pd.isna returns an array for list-likes, which cannot be used as a
    # condition; only scalars can be a missing marker.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


def decrypt_rows_via_laravel(
    rows: List[Dict[str, Any]],
    fields_to_decrypt: List[str],
    laravel_path: str,
    batch_size: int = 200,
) -> Tuple[List[Dict[str, Any]], int]:
    """Decrypt selected fields of ``rows`` through the PHP decryptor script.

    Rows are sent in batches as JSON on the stdin of
    ``php <laravel_path>/tools/laravel_decryptor.php``; the script's stdout
    is parsed as JSON and its ``rows`` list is collected.

    Decryption is best effort: if the script or the ``php`` binary is
    missing, the script exits non-zero, or its output is not a JSON object,
    the problem is logged and the ORIGINAL ``rows`` are returned with an
    error count of 0. Batches decrypted before the failure are discarded.

    Args:
        rows: Records to decrypt.
        fields_to_decrypt: Keys whose values the decryptor should decrypt.
        laravel_path: Root folder containing ``tools/laravel_decryptor.php``.
        batch_size: Maximum number of rows per PHP invocation.

    Returns:
        ``(rows, error_count)`` where ``error_count`` is the number of
        returned rows carrying a truthy ``_decrypt_errors`` entry.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if not rows or not fields_to_decrypt:
        return rows, 0

    decryptor_path = os.path.join(laravel_path, "tools", "laravel_decryptor.php")
    if not os.path.exists(decryptor_path):
        logging.warning("Decryptor not found at %s", decryptor_path)
        return rows, 0

    # A negative step would make the loop below run zero times and silently
    # return an empty result, dropping every row.
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    decrypted_rows: List[Dict[str, Any]] = []
    error_count = 0

    for i in range(0, len(rows), batch_size):
        batch = rows[i : i + batch_size]
        safe_batch = [
            {key: make_json_safe(val) for key, val in row.items()} if isinstance(row, dict) else row
            for row in batch
        ]
        payload = json.dumps({"rows": safe_batch, "fields_to_decrypt": fields_to_decrypt})

        try:
            result = subprocess.run(
                ["php", decryptor_path],
                input=payload.encode("utf-8"),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=False,
            )
        except FileNotFoundError:
            logging.error("PHP binary not found; skipping decryption")
            return rows, 0

        if result.returncode != 0:
            logging.error("Decryptor failed: %s", result.stderr.decode("utf-8", errors="ignore"))
            return rows, 0

        # Stray PHP notices on stdout make the output unparseable; handle
        # that like every other decryptor failure instead of crashing.
        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        try:
            output = json.loads(result.stdout.decode("utf-8"))
        except ValueError as exc:
            logging.error("Decryptor returned invalid JSON: %s", exc)
            return rows, 0
        if not isinstance(output, dict):
            logging.error("Decryptor returned unexpected JSON of type %s", type(output).__name__)
            return rows, 0

        batch_rows = output.get("rows", [])
        for row in batch_rows:
            if isinstance(row, dict) and row.get("_decrypt_errors"):
                error_count += 1
        decrypted_rows.extend(batch_rows)

    return decrypted_rows, error_count


def decrypt_dataframe(
    df: pd.DataFrame,
    table_key: str,
    encrypted_columns_cfg: Dict[str, List[str]],
    laravel_path: str,
    column_alias_map: Dict[str, Dict[str, str]],
) -> Tuple[pd.DataFrame, int]:
    """Decrypt the configured encrypted columns of ``df``.

    Args:
        df: Frame to decrypt.
        table_key: Key used to look up both the encrypted column list and
            the alias map for this frame.
        encrypted_columns_cfg: Table key -> encrypted column names.
        laravel_path: Passed through to ``decrypt_rows_via_laravel``.
        column_alias_map: Table key -> {column name: alias}. The alias is
            used when the column itself is not in ``df`` but the alias is.

    Returns:
        ``(frame, error_count)``. ``df`` itself is returned with a count of
        0 when it is empty or none of the configured columns are present;
        otherwise a new frame is built from the decrypted records.
    """
    encrypted_fields = encrypted_columns_cfg.get(table_key, [])
    if df.empty or not encrypted_fields:
        return df, 0

    alias_map = column_alias_map.get(table_key, {})

    def resolve(field: str) -> Optional[str]:
        # Prefer the column under its own name, then under its alias.
        if field in df.columns:
            return field
        if field in alias_map and alias_map[field] in df.columns:
            return alias_map[field]
        return None

    fields_to_decrypt = [name for name in map(resolve, encrypted_fields) if name is not None]
    if not fields_to_decrypt:
        return df, 0

    records = df.to_dict(orient="records")
    decrypted_rows, error_count = decrypt_rows_via_laravel(records, fields_to_decrypt, laravel_path)
    return pd.DataFrame(decrypted_rows), error_count


def compute_dm_control(row: pd.Series, thresholds: Dict[str, Any]) -> Tuple[Optional[bool], Optional[str], Optional[float], Optional[pd.Timestamp], Optional[str]]:
    """Classify one visit against the first available glucose test.

    Tests are tried in the order HbA1c, 2HPP, FBS, RBS; the first one with a
    non-missing value decides the result. A visit is "controlled" when that
    value is strictly below its threshold.

    Args:
        row: Visit record exposing ``.get`` (a Series or a dict). Keys read:
            ``hba1c``, ``t2hpp``, ``fbs``, ``rbs_result``, ``lab_res_date``,
            ``t2hpp_test_date``, ``t2hpp_test_location``, ``fbs_test_date``,
            ``fbs_test_location``, ``visit_date``.
        thresholds: Must contain ``hba1c``, ``twopp`` and ``fbs``; ``rbs``
            is optional and defaults to 200.

    Returns:
        ``(controlled, test_name, value, test_date, test_location)``, or
        five ``None`` values when no test has a value. For RBS the date is
        ``lab_res_date``, falling back to ``visit_date`` when that is
        missing.
    """
    hba1c = row.get("hba1c")
    t2hpp = row.get("t2hpp")
    fbs = row.get("fbs")
    rbs = row.get("rbs_result")

    if pd.notna(hba1c):
        return (
            hba1c < thresholds["hba1c"],
            "HbA1c",
            float(hba1c),
            row.get("lab_res_date"),
            None,
        )
    if pd.notna(t2hpp):
        return (
            t2hpp < thresholds["twopp"],
            "2HPP",
            float(t2hpp),
            row.get("t2hpp_test_date"),
            row.get("t2hpp_test_location"),
        )
    if pd.notna(fbs):
        return (
            fbs < thresholds["fbs"],
            "FBS",
            float(fbs),
            row.get("fbs_test_date"),
            row.get("fbs_test_location"),
        )
    if pd.notna(rbs):
        rbs_threshold = thresholds.get("rbs", 200)
        lab_date = row.get("lab_res_date")
        # NaT and NaN are truthy, so a plain `or` never reaches visit_date
        # for a missing lab date; test for missingness explicitly.
        if pd.isna(lab_date) or not lab_date:
            test_date = row.get("visit_date")
        else:
            test_date = lab_date
        return (
            rbs < rbs_threshold,
            "RBS",
            float(rbs),
            test_date,
            None,
        )
    return (None, None, None, None, None)


def prepare_followups(df: pd.DataFrame) -> pd.DataFrame:
    """Return a typed copy of the follow-up frame.

    Known date columns are parsed to datetimes and known measurement columns
    to numbers (only for the columns present), then ``visit_month`` is added
    as the first day of the month of ``visit_date``.

    Args:
        df: Raw follow-up rows. Must contain ``visit_date`` unless empty.

    Returns:
        ``df`` itself when it is empty, otherwise a converted copy.
    """
    if df.empty:
        return df

    date_columns = (
        "visit_date",
        "reg_date",
        "fbs_test_date",
        "t2hpp_test_date",
        "lab_res_date",
        "next_appointment",
        "death_date",
    )
    numeric_columns = (
        "sbp",
        "dbp",
        "sbp_raw",
        "dbp_raw",
        "fbs",
        "t2hpp",
        "hba1c",
        "rbs_result",
        "uring_ac_ratio",
        "creatinine",
        "crcl",
        "cvd_risk",
    )

    prepared = df.copy()
    for convert, names in ((parse_date_series, date_columns), (to_numeric, numeric_columns)):
        for name in names:
            if name in prepared.columns:
                prepared[name] = convert(prepared[name])

    month_period = prepared["visit_date"].dt.to_period("M")
    prepared["visit_month"] = month_period.dt.to_timestamp()
    return prepared


def prepare_registers(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the register frame with its date columns parsed.

    ``reg_date`` and ``first_bp_date`` are converted when present. An empty
    frame is returned as is.
    """
    if df.empty:
        return df

    prepared = df.copy()
    for name in ("reg_date", "first_bp_date"):
        if name in prepared.columns:
            prepared[name] = parse_date_series(prepared[name])
    return prepared


def compute_patient_ids(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``patient_id`` column added.

    ``patient_id`` is ``pid`` where that is present, otherwise the row's
    ``fuchia_id`` (if the frame has such a column).
    """
    result = df.copy()
    pid = result["pid"]
    fallback = result.get("fuchia_id")
    result["patient_id"] = pid.where(pid.notna(), fallback)
    return result


def compute_latest_followups(followups: pd.DataFrame) -> pd.DataFrame:
    """Keep one row per ``pid``: the most recent visit.

    Rows are ordered by ``pid`` ascending, then ``visit_date`` and
    ``followup_id`` descending, so ties on the date go to the highest
    ``followup_id``. An empty frame is returned as is.
    """
    if followups.empty:
        return followups

    ordered = followups.sort_values(
        by=["pid", "visit_date", "followup_id"],
        ascending=[True, False, False],
    )
    repeated_pid = ordered.duplicated(subset=["pid"], keep="first")
    return ordered[~repeated_pid]


def compute_monthly_summary(seed_df: pd.DataFrame, start_date: pd.Timestamp, end_date: pd.Timestamp) -> pd.DataFrame:
    """Sum ``new_regs`` and ``followups`` per month within a date window.

    Note: the parsed ``month_start`` column is written back into
    ``seed_df``, so the caller's frame is modified.

    Args:
        seed_df: Rows with ``month_start``, ``new_regs`` and ``followups``.
        start_date: First month start to include.
        end_date: Last month start to include.

    Returns:
        One row per ``month_start`` inside the inclusive window, sorted by
        month. An empty ``seed_df`` is returned as is.
    """
    if seed_df.empty:
        return seed_df

    seed_df["month_start"] = parse_date_series(seed_df["month_start"])
    by_month = seed_df.groupby("month_start", as_index=False)
    totals = by_month[["new_regs", "followups"]].sum()

    in_window = totals["month_start"].between(start_date, end_date)
    return totals.loc[in_window].sort_values("month_start")


def compute_bp_control_trend(followups: pd.DataFrame, thresholds: Dict[str, Any]) -> pd.DataFrame:
    """Compute the monthly blood-pressure control rate.

    For every patient and month only the last visit (by ``visit_date``) is
    counted. A visit is controlled when both ``sbp`` and ``dbp`` are present
    and strictly below ``bp_control_sbp`` / ``bp_control_dbp``.

    Args:
        followups: Visits with ``pid``, ``visit_date``, ``visit_month``,
            ``sbp`` and ``dbp``.
        thresholds: Must contain ``bp_control_sbp`` and ``bp_control_dbp``.

    Returns:
        One row per ``visit_month`` with ``visits``, ``bp_with_values``,
        ``bp_controlled`` and ``bp_control_rate`` (NaN when no visit in the
        month has both values). An empty input is returned as is.
    """
    if followups.empty:
        return followups

    ordered = followups.sort_values(["pid", "visit_date"])
    last_visits = ordered.groupby(["pid", "visit_month"], as_index=False).tail(1)

    has_bp = last_visits["sbp"].notna() & last_visits["dbp"].notna()
    sbp_ok = last_visits["sbp"] < thresholds["bp_control_sbp"]
    dbp_ok = last_visits["dbp"] < thresholds["bp_control_dbp"]
    last_visits = last_visits.assign(
        bp_with_values=has_bp,
        bp_controlled=has_bp & sbp_ok & dbp_ok,
    )

    trend = last_visits.groupby("visit_month", as_index=False).agg(
        visits=("pid", "count"),
        bp_with_values=("bp_with_values", "sum"),
        bp_controlled=("bp_controlled", "sum"),
    )
    # Avoid dividing by zero for months without any usable reading.
    denominator = trend["bp_with_values"].replace(0, np.nan)
    trend["bp_control_rate"] = trend["bp_controlled"] / denominator
    return trend


def compute_dm_trend(followups: pd.DataFrame, thresholds: Dict[str, Any]) -> pd.DataFrame:
    """Compute the monthly diabetes control rate.

    For every patient and month only the last visit (by ``visit_date``) is
    counted and classified with ``compute_dm_control``.

    Args:
        followups: Visits with ``pid``, ``visit_date``, ``visit_month`` and
            the columns ``compute_dm_control`` reads.
        thresholds: Passed through to ``compute_dm_control``.

    Returns:
        One row per ``visit_month`` with ``visits``, ``dm_with_values``,
        ``dm_controlled`` and ``dm_control_rate`` (NaN when no visit in the
        month has a test value). An empty input is returned as is.
    """
    if followups.empty:
        return followups

    ordered = followups.sort_values(["pid", "visit_date"])
    per_month = ordered.groupby(["pid", "visit_month"], as_index=False).tail(1)

    dm_columns = ["dm_controlled", "dm_test_used", "dm_test_value", "dm_test_date", "dm_test_location"]
    per_month[dm_columns] = per_month.apply(
        lambda visit: compute_dm_control(visit, thresholds),
        axis=1,
        result_type="expand",
    )
    per_month["dm_with_values"] = per_month["dm_controlled"].notna()

    trend = per_month.groupby("visit_month", as_index=False).agg(
        visits=("pid", "count"),
        dm_with_values=("dm_with_values", "sum"),
        dm_controlled=("dm_controlled", "sum"),
    )
    # Avoid dividing by zero for months without any test value.
    denominator = trend["dm_with_values"].replace(0, np.nan)
    trend["dm_control_rate"] = trend["dm_controlled"] / denominator
    return trend


def compute_quality_metrics(
    followups: pd.DataFrame,
    report_end: pd.Timestamp,
    thresholds: Dict[str, Any],
) -> pd.DataFrame:
    """Count patients with a recent result for each monitored test.

    A patient has a "recent" result when at least one of their visits on or
    after the cutoff has a non-missing value in the test column. Cutoffs are
    ``report_end`` minus the configured number of months.

    Args:
        followups: Visits with ``pid``, ``visit_date``, ``hba1c``,
            ``creatinine``, ``crcl``, ``uring_ac_ratio`` and ``cvd_risk``.
        report_end: End of the reporting period.
        thresholds: Must contain ``hba1c_lookback_months``,
            ``kidney_lookback_months`` and ``cvd_risk_lookback_months``.

    Returns:
        Five rows with ``metric``, ``count``, ``total_patients`` and
        ``rate``; an empty DataFrame when ``followups`` is empty.
    """
    if followups.empty:
        return pd.DataFrame()

    def cutoff_for(key: str) -> pd.Timestamp:
        return report_end - pd.DateOffset(months=int(thresholds[key]))

    hb_cutoff = cutoff_for("hba1c_lookback_months")
    kidney_cutoff = cutoff_for("kidney_lookback_months")
    cvd_cutoff = cutoff_for("cvd_risk_lookback_months")

    # (metric name, source column, earliest visit date that counts)
    checks = [
        ("hba1c_recent", "hba1c", hb_cutoff),
        ("creatinine_recent", "creatinine", kidney_cutoff),
        ("crcl_recent", "crcl", kidney_cutoff),
        ("uring_ac_ratio_recent", "uring_ac_ratio", kidney_cutoff),
        ("cvd_risk_recent", "cvd_risk", cvd_cutoff),
    ]

    def patient_flags(visits: pd.DataFrame) -> pd.Series:
        flags = {}
        for metric, column, cutoff in checks:
            recent_with_value = (visits["visit_date"] >= cutoff) & visits[column].notna()
            flags[metric] = bool(recent_with_value.any())
        return pd.Series(flags)

    quality = followups.groupby("pid").apply(patient_flags)

    metric_names = [metric for metric, _, _ in checks]
    summary = pd.DataFrame(
        {
            "metric": metric_names,
            "count": [quality[metric].sum() for metric in metric_names],
            "total_patients": [quality.shape[0]] * 5,
        }
    )
    summary["rate"] = summary["count"] / summary["total_patients"].replace(0, np.nan)
    return summary


def compute_data_quality(
    registers: pd.DataFrame,
    followups: pd.DataFrame,
    decrypt_failures: Dict[str, int],
) -> pd.DataFrame:
    """Collect simple data-quality counters as ``metric`` / ``value`` rows.

    Register counters (row count, missing date of birth, duplicate ``pid``)
    are emitted only when ``registers`` is non-empty; follow-up counters
    (row count, rows missing ``sbp`` or ``dbp``, rows whose raw systolic
    value was rejected) only when ``followups`` is non-empty. One
    ``decrypt_failures_<key>`` row is added per entry of
    ``decrypt_failures``.

    Args:
        registers: Register rows.
        followups: Follow-up rows; when non-empty it must contain ``sbp``,
            ``dbp`` and ``sbp_raw``.
        decrypt_failures: Table key -> number of rows that failed to decrypt.

    Returns:
        A DataFrame with one row per counter.
    """
    rows = []

    def record(metric: str, value: Any) -> None:
        rows.append({"metric": metric, "value": value})

    if not registers.empty:
        record("register_rows", len(registers))
        if "date_of_birth_raw" in registers.columns:
            missing_dob = registers["date_of_birth_raw"].isna().sum()
        else:
            missing_dob = 0
        record("register_missing_dob", missing_dob)
        if "pid" in registers.columns:
            duplicate_pids = registers["pid"].duplicated().sum()
        else:
            duplicate_pids = 0
        record("register_duplicate_pid", duplicate_pids)

    if not followups.empty:
        record("followup_rows", len(followups))
        incomplete_bp = followups[["sbp", "dbp"]].isna().any(axis=1)
        record("followup_missing_bp", incomplete_bp.sum())
        # Raw readings that parsed but did not survive the validated column.
        parsed = followups["sbp_raw"].notna().sum()
        accepted = followups["sbp"].notna().sum()
        record("followup_invalid_bp", parsed - accepted)

    for table_key, failures in decrypt_failures.items():
        record(f"decrypt_failures_{table_key}", failures)

    return pd.DataFrame(rows)


def normalize_medication_label(
    text: str,
    aliases: Dict[str, str],
    stopwords: Iterable[str],
) -> Optional[str]:
    """Reduce a free-text medication entry to a single label.

    The entry is lower-cased; the first run of letters that is not a
    stopword is the candidate, which is then mapped through ``aliases``.

    Args:
        text: One medication entry.
        aliases: Lower-case word -> canonical label.
        stopwords: Words to skip; compared case-insensitively.

    Returns:
        The label, or ``None`` for placeholder entries ("nan", "none",
        "nil", "no") and entries without a usable word.
    """
    lowered = text.lower().strip()
    if lowered in {"nan", "none", "nil", "no"}:
        return None

    ignored = {word.lower() for word in stopwords}
    candidates = (word for word in re.findall(r"[a-z]+", lowered) if word not in ignored)
    first_word = next(candidates, None)
    if first_word is None:
        return None
    return aliases.get(first_word, first_word)


def extract_medication_tokens(
    value: Any,
    aliases: Dict[str, str],
    stopwords: Iterable[str],
) -> List[str]:
    """Split a free-text medication field into normalised labels.

    Line breaks and the characters ``;``, ``/`` and ``|`` are treated like
    commas. Each non-blank piece is passed to
    ``normalize_medication_label``; pieces that yield no label are dropped.

    Args:
        value: Raw cell value; ``None`` / NaN gives an empty list.
        aliases: Passed through to ``normalize_medication_label``.
        stopwords: Passed through to ``normalize_medication_label``.

    Returns:
        Labels in the order they appear in ``value``.
    """
    if value is None or pd.isna(value):
        return []

    raw = str(value).replace("\r", ",").replace("\n", ",")
    raw = re.sub(r"[;/|]+", ",", raw)

    labels = []
    for piece in raw.split(","):
        cleaned = re.sub(r"\s+", " ", piece).strip()
        if not cleaned:
            continue
        label = normalize_medication_label(cleaned, aliases, stopwords)
        if label:
            labels.append(label)
    return labels


def compute_other_med_distribution(
    followups: pd.DataFrame,
    other_med_cfg: Dict[str, Any],
) -> pd.DataFrame:
    """Count how often each medication appears in ``other_medication_spec``.

    Args:
        followups: Visits; the free-text column is ``other_medication_spec``.
        other_med_cfg: Optional ``aliases`` mapping (keys and values are
            lower-cased, empty ones ignored) and optional ``stopwords`` list.

    Returns:
        Columns ``medication``, ``count`` and ``share_of_mentions`` (count
        divided by the total number of mentions), most frequent first. The
        frame is empty when the column is absent or yields no labels.
    """

    def no_mentions() -> pd.DataFrame:
        return pd.DataFrame(columns=["medication", "count", "share_of_mentions"])

    if followups.empty or "other_medication_spec" not in followups.columns:
        return no_mentions()

    configured_aliases = other_med_cfg.get("aliases") or {}
    aliases = {}
    for source, target in configured_aliases.items():
        if source and target:
            aliases[str(source).lower()] = str(target).lower()
    stopwords = other_med_cfg.get("stopwords") or []

    tokens: List[str] = [
        token
        for entry in followups["other_medication_spec"].dropna()
        for token in extract_medication_tokens(entry, aliases, stopwords)
    ]
    if not tokens:
        return no_mentions()

    counts = pd.Series(tokens).value_counts().reset_index()
    counts.columns = ["medication", "count"]
    total_mentions = counts["count"].sum()
    counts["share_of_mentions"] = counts["count"] / total_mentions
    return counts


def compute_metrics_for_clinic(
    registers: pd.DataFrame,
    followups: pd.DataFrame,
    latest_followups: pd.DataFrame,
    monthly_seed: pd.DataFrame,
    config: Dict[str, Any],
    report_start: pd.Timestamp,
    report_end: pd.Timestamp,
) -> Dict[str, pd.DataFrame]:
    """Derive patient-level flags and summary tables for one set of frames.

    Args:
        registers: Register rows. Reads ``reg_date``, ``current_age``,
            ``visit_age``, optionally ``date_of_birth_raw``, and the
            demographic columns merged into ``latest_followups`` below.
        followups: All follow-up rows; used for the trend, quality and
            other-medication tables.
        latest_followups: One row per patient (presumably the latest visit,
            judging by the name; confirm against the source view). Most flags
            and rates are computed on this frame.
        monthly_seed: Input forwarded to ``compute_monthly_summary``.
        config: Must provide ``thresholds`` and ``age`` (``year_divisor``,
            ``bands``); ``other_medication`` is optional.
        report_start: Start of the reporting window (only forwarded to
            ``compute_monthly_summary``).
        report_end: End of the reporting window; reference date for the
            active, lost-to-follow-up and missed-appointment flags.

    Returns:
        A dict with the enriched ``registers``, ``followups`` and
        ``latest_followups`` frames plus the tables ``monthly_summary``,
        ``bp_trend``, ``dm_trend``, ``continuity``, ``quality``, ``risk``,
        ``operations``, ``operations_dist``, ``other_med_dist``,
        ``equity_bp``, ``equity_dm`` and ``kpi_summary``.
    """
    thresholds = config["thresholds"]
    age_cfg = config["age"]

    registers = compute_patient_ids(registers)
    followups = compute_patient_ids(followups)
    latest_followups = compute_patient_ids(latest_followups)

    # --- Age at registration -------------------------------------------------
    if "date_of_birth_raw" in registers.columns:
        registers["date_of_birth"] = parse_date_series(registers["date_of_birth_raw"])
    else:
        registers["date_of_birth"] = pd.NaT

    registers["age_at_reg"] = registers.apply(
        lambda r: compute_age_years(r.get("date_of_birth"), r.get("reg_date"), age_cfg["year_divisor"]), axis=1
    )
    registers["age_source"] = np.where(registers["date_of_birth"].notna(), "dob", None)
    # Fall back to recorded ages, in priority order. The mask is computed once
    # per fallback and reused for both assignments: recomputing it after
    # age_at_reg has been filled would select no rows, so age_source would
    # never record which fallback was used.
    for fallback_col in ("current_age", "visit_age"):
        needs_fallback = registers["age_at_reg"].isna() & registers[fallback_col].notna()
        registers.loc[needs_fallback, "age_at_reg"] = registers[fallback_col]
        registers.loc[needs_fallback, "age_source"] = fallback_col

    # --- Attach register demographics to the per-patient rows ----------------
    latest_followups = latest_followups.merge(
        registers[["pid", "date_of_birth", "gender", "township", "area_division", "clinic_code", "age_at_reg"]],
        on="pid",
        how="left",
        suffixes=("", "_reg"),
    )

    latest_followups["age_at_visit"] = latest_followups.apply(
        lambda r: compute_age_years(r.get("date_of_birth"), r.get("visit_date"), age_cfg["year_divisor"]), axis=1
    )
    latest_followups.loc[latest_followups["age_at_visit"].isna() & latest_followups["visit_age"].notna(), "age_at_visit"] = latest_followups["visit_age"]
    latest_followups["age_at_report_end"] = latest_followups.apply(
        lambda r: compute_age_years(r.get("date_of_birth"), report_end, age_cfg["year_divisor"]), axis=1
    )
    # Without a usable date of birth, the age at registration stands in for
    # the age at the end of the reporting window.
    latest_followups.loc[latest_followups["age_at_report_end"].isna() & latest_followups["age_at_reg"].notna(), "age_at_report_end"] = latest_followups["age_at_reg"]

    latest_followups["age_band"] = latest_followups["age_at_report_end"].apply(lambda a: assign_age_band(a, age_cfg["bands"]))

    # --- Blood pressure and diabetes control ----------------------------------
    # BP counts as controlled only when both readings exist and both are
    # strictly below their thresholds.
    latest_followups["bp_with_values"] = latest_followups["sbp"].notna() & latest_followups["dbp"].notna()
    latest_followups["bp_controlled"] = (
        latest_followups["bp_with_values"]
        & (latest_followups["sbp"] < thresholds["bp_control_sbp"])
        & (latest_followups["dbp"] < thresholds["bp_control_dbp"])
    )

    # NOTE(review): on an empty frame, apply(..., result_type="expand") appears
    # to return a copy of the frame rather than five columns, which would make
    # the assignment below fail -- confirm callers never pass an empty frame.
    dm_results = latest_followups.apply(lambda row: compute_dm_control(row, thresholds), axis=1, result_type="expand")
    latest_followups[["dm_controlled", "dm_test_used", "dm_test_value", "dm_test_date", "dm_test_location"]] = dm_results
    latest_followups["dm_with_values"] = latest_followups["dm_controlled"].notna()

    monthly_summary = compute_monthly_summary(monthly_seed, report_start, report_end)
    bp_trend = compute_bp_control_trend(followups, thresholds)
    dm_trend = compute_dm_trend(followups, thresholds)

    # --- Continuity of care ---------------------------------------------------
    # Active: visit on or after report_end - active_days.
    active_cutoff = report_end - pd.Timedelta(days=int(thresholds["active_days"]))
    latest_followups["active_patient"] = latest_followups["visit_date"].notna() & (latest_followups["visit_date"] >= active_cutoff)

    # Lost to follow-up: no visit date, or visit before report_end - ltfu_days.
    ltfu_cutoff = report_end - pd.Timedelta(days=int(thresholds["ltfu_days"]))
    latest_followups["ltfu"] = latest_followups["visit_date"].isna() | (latest_followups["visit_date"] < ltfu_cutoff)

    # Missed: the recorded next appointment lies before report_end.
    latest_followups["missed_appointment"] = (
        latest_followups["next_appointment"].notna()
        & (latest_followups["next_appointment"] < report_end)
    )

    latest_followups["late_visit_flag"] = latest_followups["late_visit"].apply(is_yes) | latest_followups["late_follow"].apply(is_yes)

    continuity = pd.DataFrame(
        {
            "metric": ["active_caseload", "ltfu", "missed_appointments", "late_visit"],
            "count": [
                latest_followups["active_patient"].sum(),
                latest_followups["ltfu"].sum(),
                latest_followups["missed_appointment"].sum(),
                latest_followups["late_visit_flag"].sum(),
            ],
            "total_patients": [latest_followups.shape[0]] * 4,
        }
    )
    # A zero denominator becomes NaN so the rate is NaN instead of inf.
    continuity["rate"] = continuity["count"] / continuity["total_patients"].replace(0, np.nan)

    quality = compute_quality_metrics(followups, report_end, thresholds)

    # --- Risk markers ---------------------------------------------------------
    def cvd_high(val: Any) -> bool:
        """Flag a CVD risk value that is numeric and at/above the threshold, or text containing "high"."""
        if pd.isna(val):
            return False
        try:
            return float(val) >= float(thresholds["cvd_risk_high"])
        except (TypeError, ValueError):
            return "high" in str(val).lower()

    latest_followups["cvd_risk_high"] = latest_followups["cvd_risk"].apply(cvd_high)
    # Any one of the three kidney indicators beyond its threshold sets the marker.
    latest_followups["ckd_marker"] = (
        (latest_followups["creatinine"].notna() & (latest_followups["creatinine"] > thresholds["creatinine_high"]))
        | (latest_followups["crcl"].notna() & (latest_followups["crcl"] < thresholds["crcl_low"]))
        | (latest_followups["uring_ac_ratio"].notna() & (latest_followups["uring_ac_ratio"] >= thresholds["uring_ac_ratio_high"]))
    )

    latest_followups["diabetic_foot_flag"] = latest_followups["diabetic_foot"].apply(is_yes)
    latest_followups["neuropathy_flag"] = latest_followups["diabetic_neuropathy"].apply(is_yes)
    latest_followups["hypoglycemia_flag"] = latest_followups["symptom_hypoglycemia"].apply(is_yes)

    risk = pd.DataFrame(
        {
            "metric": ["cvd_risk_high", "ckd_marker", "diabetic_foot", "neuropathy", "hypoglycemia"],
            "count": [
                latest_followups["cvd_risk_high"].sum(),
                latest_followups["ckd_marker"].sum(),
                latest_followups["diabetic_foot_flag"].sum(),
                latest_followups["neuropathy_flag"].sum(),
                latest_followups["hypoglycemia_flag"].sum(),
            ],
            "total_patients": [latest_followups.shape[0]] * 5,
        }
    )
    risk["rate"] = risk["count"] / risk["total_patients"].replace(0, np.nan)

    # --- Operations -----------------------------------------------------------
    latest_followups["med_changed_flag"] = latest_followups["medication_changed"].apply(is_yes)

    operations = pd.DataFrame(
        {
            "metric": ["medication_changed"],
            "count": [latest_followups["med_changed_flag"].sum()],
            "total_patients": [latest_followups.shape[0]],
        }
    )
    operations["rate"] = operations["count"] / operations["total_patients"].replace(0, np.nan)

    # Category counts keep missing values as their own category (dropna=False).
    adherence_dist = latest_followups["patient_adherence"].value_counts(dropna=False).reset_index()
    adherence_dist.columns = ["category", "count"]
    adherence_dist["metric"] = "patient_adherence"

    drug_supply_dist = latest_followups["drug_supply"].value_counts(dropna=False).reset_index()
    drug_supply_dist.columns = ["category", "count"]
    drug_supply_dist["metric"] = "drug_supply"

    operations_dist = pd.concat([adherence_dist, drug_supply_dist], ignore_index=True)
    other_med_cfg = config.get("other_medication", {})
    # The medication distribution is built from all follow-ups, not only the
    # per-patient rows.
    other_med_dist = compute_other_med_distribution(followups, other_med_cfg)

    # --- Equity breakdowns ----------------------------------------------------
    # Control rates use the number of patients with values as denominator, not
    # the group size; groups with missing keys are kept (dropna=False).
    equity_bp = latest_followups.groupby(["gender", "age_band", "township"], dropna=False).agg(
        patients=("pid", "count"),
        bp_controlled=("bp_controlled", "sum"),
        bp_with_values=("bp_with_values", "sum"),
    ).reset_index()
    equity_bp["bp_control_rate"] = equity_bp["bp_controlled"] / equity_bp["bp_with_values"].replace(0, np.nan)

    equity_dm = latest_followups.groupby(["gender", "age_band", "township"], dropna=False).agg(
        patients=("pid", "count"),
        dm_controlled=("dm_controlled", "sum"),
        dm_with_values=("dm_with_values", "sum"),
    ).reset_index()
    equity_dm["dm_control_rate"] = equity_dm["dm_controlled"] / equity_dm["dm_with_values"].replace(0, np.nan)

    # --- Headline KPIs --------------------------------------------------------
    bp_rate = latest_followups.loc[latest_followups["bp_with_values"], "bp_controlled"].mean()
    dm_rate = latest_followups.loc[latest_followups["dm_with_values"], "dm_controlled"].mean()

    kpi_summary = pd.DataFrame(
        {
            "metric": [
                "patients",
                "active_caseload",
                "bp_control_rate",
                "dm_control_rate",
                "ltfu_rate",
            ],
            "value": [
                latest_followups.shape[0],
                latest_followups["active_patient"].sum(),
                bp_rate,
                dm_rate,
                latest_followups["ltfu"].mean(),
            ],
        }
    )

    return {
        "registers": registers,
        "followups": followups,
        "latest_followups": latest_followups,
        "monthly_summary": monthly_summary,
        "bp_trend": bp_trend,
        "dm_trend": dm_trend,
        "continuity": continuity,
        "quality": quality,
        "risk": risk,
        "operations": operations,
        "operations_dist": operations_dist,
        "other_med_dist": other_med_dist,
        "equity_bp": equity_bp,
        "equity_dm": equity_dm,
        "kpi_summary": kpi_summary,
    }


def export_csv(df: pd.DataFrame, path: str) -> None:
    """Write ``df`` to ``path`` as CSV without the index column."""
    df.to_csv(path_or_buf=path, index=False)


def _export_metric_tables(
    metrics: Dict[str, pd.DataFrame],
    out_dir: str,
    mask_ids_enabled: bool,
    id_salt: str,
) -> pd.DataFrame:
    """Write the CSV set shared by the per-clinic and overall outputs.

    The three patient-level frames are passed through ``mask_ids`` first when
    ``mask_ids_enabled`` is true; the aggregate tables are written unchanged.

    Args:
        metrics: Result of ``compute_metrics_for_clinic``.
        out_dir: Existing directory to write into.
        mask_ids_enabled: Whether to mask identifier columns before export.
        id_salt: Salt forwarded to ``mask_ids``.

    Returns:
        The ``latest_followups`` frame as it was exported (masked or not), so
        callers can reuse it.
    """

    def for_stakeholders(frame: pd.DataFrame) -> pd.DataFrame:
        if mask_ids_enabled:
            return mask_ids(frame, ["pid", "patient_id", "fuchia_id"], id_salt)
        return frame

    stake_registers = for_stakeholders(metrics["registers"])
    stake_followups = for_stakeholders(metrics["followups"])
    stake_latest = for_stakeholders(metrics["latest_followups"])

    export_csv(stake_registers, os.path.join(out_dir, "registers_clean.csv"))
    export_csv(stake_followups, os.path.join(out_dir, "followups_clean.csv"))
    export_csv(stake_latest, os.path.join(out_dir, "patient_latest.csv"))

    aggregate_files = [
        ("monthly_summary", "monthly_summary.csv"),
        ("bp_trend", "bp_control_trend.csv"),
        ("dm_trend", "dm_control_trend.csv"),
        ("continuity", "continuity_metrics.csv"),
        ("quality", "quality_metrics.csv"),
        ("risk", "risk_metrics.csv"),
        ("operations", "operations_metrics.csv"),
        ("operations_dist", "operations_distributions.csv"),
        ("other_med_dist", "other_medications.csv"),
        ("equity_bp", "equity_bp_control.csv"),
        ("equity_dm", "equity_dm_control.csv"),
        ("kpi_summary", "kpi_summary.csv"),
    ]
    for table_key, filename in aggregate_files:
        export_csv(metrics[table_key], os.path.join(out_dir, filename))

    return stake_latest


def main() -> int:
    """Run the pipeline for every configured database and write the CSV outputs.

    For each entry in ``config["databases"]`` the four metric views are read,
    encrypted columns are decrypted, metrics are computed and written to
    ``<outputs root>/<database>``, including a ``doctor`` sub-directory with
    the action list. Afterwards the per-clinic frames are concatenated and the
    same metrics are written to ``<outputs root>/overall``.

    Returns:
        ``0`` on success, ``1`` when no database was processed.
    """
    parser = argparse.ArgumentParser(description="NCD analytics pipeline")
    parser.add_argument("--config", required=True, help="Path to config.yaml")
    parser.add_argument("--views", default="metrics_views.sql", help="Path to metrics_views.sql")
    args = parser.parse_args()

    config = apply_env_overrides(load_config(args.config))
    setup_logging(config.get("logging", {}).get("level", "INFO"))

    report_start = pd.to_datetime(config["date_range"]["start_date"])
    report_end = pd.to_datetime(config["date_range"]["end_date"])

    outputs_root = config.get("outputs", {}).get("root", "outputs")
    os.makedirs(outputs_root, exist_ok=True)

    encrypted_cfg = config.get("encrypted_columns", {})
    laravel_path = config["laravel_path"]

    # Privacy settings do not change between databases, so read them once.
    privacy_cfg = config.get("privacy", {})
    mask_ids_enabled = privacy_cfg.get("mask_ids", True)
    id_salt = privacy_cfg.get("id_salt", "")
    keep_full_ids = privacy_cfg.get("keep_full_ids_in_doctor_lists", True)

    all_registers = []
    all_followups = []
    all_latest = []
    all_monthly = []
    all_kpis = []

    for db_name in config["databases"]:
        logging.info("Processing %s", db_name)
        engine = build_engine(config["mysql"], db_name)
        try:
            ensure_views(engine, args.views)

            registers = prepare_registers(read_view(engine, "v_ncd_base_join"))
            followups = prepare_followups(read_view(engine, "v_ncd_followups_clean"))
            latest_followups = prepare_followups(read_view(engine, "v_ncd_latest_followup_per_patient"))
            monthly_seed = read_view(engine, "v_ncd_monthly_summary_seed")
        finally:
            # The engine is only needed for the reads above; release its
            # connection pool before moving on to the next database.
            engine.dispose()

        decrypt_failures = {"patients": 0, "ncd_pt_registers": 0, "ncd_followups": 0}

        registers, decrypt_failures["patients"] = decrypt_dataframe(
            registers,
            "patients",
            encrypted_cfg,
            laravel_path,
            VIEW_ENCRYPTED_ALIAS,
        )

        followups, decrypt_failures["ncd_followups"] = decrypt_dataframe(
            followups,
            "ncd_followups",
            encrypted_cfg,
            laravel_path,
            VIEW_ENCRYPTED_ALIAS,
        )

        # The failure count for the per-patient view is deliberately not
        # recorded; only the full follow-up frame feeds the quality report.
        latest_followups, _ = decrypt_dataframe(
            latest_followups,
            "ncd_followups",
            encrypted_cfg,
            laravel_path,
            VIEW_ENCRYPTED_ALIAS,
        )

        registers = compute_patient_ids(registers)
        followups = compute_patient_ids(followups)
        latest_followups = compute_patient_ids(latest_followups)

        registers["source_db"] = db_name
        followups["source_db"] = db_name
        latest_followups["source_db"] = db_name

        metrics = compute_metrics_for_clinic(
            registers,
            followups,
            latest_followups,
            monthly_seed,
            config,
            report_start,
            report_end,
        )

        quality_report = compute_data_quality(registers, followups, decrypt_failures)

        clinic_dir = os.path.join(outputs_root, db_name)
        os.makedirs(clinic_dir, exist_ok=True)

        stake_latest = _export_metric_tables(metrics, clinic_dir, mask_ids_enabled, id_salt)
        export_csv(quality_report, os.path.join(clinic_dir, "data_quality_report.csv"))

        doctor_dir = os.path.join(clinic_dir, "doctor")
        os.makedirs(doctor_dir, exist_ok=True)
        # Doctor lists keep unmasked identifiers unless the config opts out.
        doctor_latest = metrics["latest_followups"] if keep_full_ids else stake_latest

        doctor_lists = doctor_latest.loc[
            doctor_latest["ltfu"] | doctor_latest["missed_appointment"],
            [
                "pid",
                "patient_id",
                "fuchia_id",
                "clinic_code",
                "gender",
                "township",
                "visit_date",
                "next_appointment",
                "ltfu",
                "missed_appointment",
                "fup_doc_initial",
            ],
        ]
        export_csv(doctor_lists, os.path.join(doctor_dir, "doctor_action_lists.csv"))

        all_registers.append(metrics["registers"])
        all_followups.append(metrics["followups"])
        all_latest.append(metrics["latest_followups"])
        all_monthly.append(metrics["monthly_summary"].assign(source_db=db_name))
        all_kpis.append(metrics["kpi_summary"].assign(source_db=db_name))

    if not all_registers:
        logging.error("No data processed. Check configuration and database connectivity.")
        return 1

    combined_registers = pd.concat(all_registers, ignore_index=True)
    combined_followups = pd.concat(all_followups, ignore_index=True)
    combined_latest = pd.concat(all_latest, ignore_index=True)

    # The overall metrics are recomputed from the concatenated per-clinic
    # frames rather than by combining the per-clinic tables.
    combined_metrics = compute_metrics_for_clinic(
        combined_registers,
        combined_followups,
        combined_latest,
        pd.concat(all_monthly, ignore_index=True),
        config,
        report_start,
        report_end,
    )

    overall_dir = os.path.join(outputs_root, "overall")
    os.makedirs(overall_dir, exist_ok=True)

    _export_metric_tables(combined_metrics, overall_dir, mask_ids_enabled, id_salt)
    export_csv(pd.concat(all_kpis, ignore_index=True), os.path.join(overall_dir, "kpi_summary_by_clinic.csv"))

    logging.info("Done. Outputs in %s", outputs_root)
    return 0


# Script entry point: run the pipeline and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
