#!/usr/bin/env python3
"""
Read bangalore_raw.json (hourly data per month, each month split into year
windows such as 1985-89, 2021-25 and 2026) and compute
duration-above-threshold metrics that match how a body actually experiences
heat.  Write bangalore_comfort.json.

Outputs per month and window:
  - day_exceedance         list of (t, hours) pairs: average hours per day
                           (during 07:00-21:59) at or above threshold t,
                           for t in 24..36 (0.5C step)
  - night_exceedance       same, for night hours (22:00-06:59) and t in
                           20..28 (0.5C step)
  - month_totals           total hours over the whole month-window at or
                           above selected human-relevant thresholds
  - comfort_bands          hours per day (on average) spent in each band
  - crossings              clock times at which the hourly-mean curve
                           crosses 32C and 22C
"""

import json
from pathlib import Path
from statistics import mean

HERE = Path(__file__).resolve().parent

# JSON text is UTF-8 (RFC 8259); decode explicitly instead of relying on the
# platform's locale encoding, which differs between machines.
RAW = json.loads((HERE / "bangalore_raw.json").read_text(encoding="utf-8"))

# Local clock hours counted as "day" (07:00-21:59); every other hour of the
# 24 (22:00-06:59) is "night".
DAY_HOURS = {h for h in range(7, 22)}
NIGHT_HOURS = set(range(24)) - DAY_HOURS

# Exceedance thresholds in degC, 0.5 degC apart, both endpoints included:
# 24.0 .. 36.0 for day hours and 20.0 .. 28.0 for night hours.
DAY_THRESHOLDS = [24 + i / 2 for i in range(25)]
NIGHT_THRESHOLDS = [20 + i / 2 for i in range(17)]

# Comfort bands as (name, lower bound inclusive, upper bound exclusive);
# None means the band is open on that side.  Applied to all hours, day and
# night alike:
#   < 22 sleep_cool | 22-26 comfort | 26-28 warm | 28-30 hot |
#   30-32 very_hot  | >= 32 stress
BANDS = [
    ("sleep_cool", None, 22.0),
    ("comfort", 22.0, 26.0),
    ("warm", 26.0, 28.0),
    ("hot", 28.0, 30.0),
    ("very_hot", 30.0, 32.0),
    ("stress", 32.0, None),
]


def band_for(t):
    """Return the name of the BANDS entry whose range contains *t*.

    Bands are half-open [lo, hi); a None bound leaves that side unbounded.
    Returns None when no band matches (for example when *t* is NaN).
    """
    return next(
        (
            name
            for name, lo, hi in BANDS
            if (lo is None or lo <= t) and (hi is None or t < hi)
        ),
        None,
    )


def _exceedance_per_day(temps, thresholds, denom_days):
    """Return [(t, hours_per_day)] for each threshold t in *thresholds*.

    hours_per_day is the count of samples in *temps* that are >= t, divided
    by *denom_days*.  Every value is 0.0 when *denom_days* is 0.
    """
    if denom_days == 0:
        return [(t, 0.0) for t in thresholds]
    return [(t, sum(1 for v in temps if v >= t) / denom_days) for t in thresholds]


def compute(window_raw):
    """
    Compute exceedance, comfort-band and whole-window totals for one window.

    window_raw: list of {year, hourly:{time, temperature_2m, ...}}.  Each
        hourly["time"] entry is split on "T"; the part before it is used as
        the day key and the first two characters after it as the hour.
        hourly["temperature_2m"][i] is the matching temperature (None =
        missing).

    Samples with a None temperature are skipped entirely, including from the
    day count.  Per-day figures divide by n_days, the number of distinct date
    strings that contributed at least one sample.

    Returns a dict with n_days, total_hours, day_hours_sampled,
    night_hours_sampled, day_exceedance / night_exceedance (lists of
    (threshold, hours_per_day) pairs), comfort_bands (band name -> hours per
    day) and month_totals (label -> total hours at or above a threshold).
    """
    day_temps = []    # temperatures sampled at an hour in DAY_HOURS
    night_temps = []  # temperatures sampled at an hour in NIGHT_HOURS
    all_temps = []    # every non-None temperature
    days_seen = set()

    for entry in window_raw:
        hourly = entry["hourly"]
        times = hourly["time"]
        temps = hourly["temperature_2m"]
        for i, t_iso in enumerate(times):
            date_part, time_part = t_iso.split("T")
            hour = int(time_part[:2])
            temp = temps[i]
            if temp is None:
                continue
            days_seen.add(date_part)
            all_temps.append(temp)
            if hour in DAY_HOURS:
                day_temps.append(temp)
            elif hour in NIGHT_HOURS:
                night_temps.append(temp)

    n_days = len(days_seen)

    day_exc = _exceedance_per_day(day_temps, DAY_THRESHOLDS, n_days)
    night_exc = _exceedance_per_day(night_temps, NIGHT_THRESHOLDS, n_days)

    # Comfort bands: hours per day in each band, averaged across days.
    band_counts = {name: 0 for name, _, _ in BANDS}
    for t in all_temps:
        b = band_for(t)
        if b:
            band_counts[b] += 1
    bands = {
        name: (count / n_days if n_days else 0.0)
        for name, count in band_counts.items()
    }

    # Whole-window totals (not per day) at a few human-relevant thresholds.
    month_totals = {}
    for label, series, thr in (
        ("day_total_hours_ge_28", day_temps, 28),
        ("day_total_hours_ge_30", day_temps, 30),
        ("day_total_hours_ge_32", day_temps, 32),
        ("night_total_hours_ge_22", night_temps, 22),
        ("night_total_hours_ge_24", night_temps, 24),
        ("night_total_hours_ge_26", night_temps, 26),
    ):
        month_totals[label] = sum(1 for v in series if v >= thr)

    return {
        "n_days": n_days,
        "total_hours": len(all_temps),
        "day_hours_sampled": len(day_temps),
        "night_hours_sampled": len(night_temps),
        "day_exceedance": day_exc,
        "night_exceedance": night_exc,
        "comfort_bands": bands,
        "month_totals": month_totals,
    }


def linear_cross(vals24, threshold, direction):
    """
    Find every time the 24-hour curve crosses *threshold* in one direction.

    vals24: 24 hourly means (index h = hour 0..23); None entries are gaps and
        any segment touching a gap is skipped.  The curve is cyclic: hour 23
        connects back to hour 0.
    threshold: the temperature we're looking for.
    direction: 'up' (rises from below threshold to >= threshold) or
        'down' (falls from >= threshold to below it).

    Returns a list of decimal hours in [0, 24) (e.g. 11.566 = 11:34), one per
    crossing, found by linear interpolation between adjacent hours.  The list
    is empty if the curve never crosses in that direction.

    Raises ValueError for any other *direction*.
    """
    if direction not in ("up", "down"):
        raise ValueError(f"direction must be 'up' or 'down', got {direction!r}")
    n = 24
    found = []
    for i in range(n):
        a = vals24[i]
        b = vals24[(i + 1) % n]
        if a is None or b is None:
            continue
        # Each strict/non-strict pair guarantees a != b, so no zero division.
        if direction == "up" and a < threshold <= b:
            frac = (threshold - a) / (b - a)
        elif direction == "down" and b < threshold <= a:
            frac = (a - threshold) / (a - b)
        else:
            continue
        # The 23->0 segment can land exactly on 24.0; wrap so every result is
        # a valid clock hour.
        found.append((i + frac) % n)
    return found


def fmt_time(h_dec):
    """
    Format a decimal hour as "HH:MM" (e.g. 11.566 -> "11:34").

    Minutes are rounded to the nearest whole minute.  A value that reaches
    24:00 (e.g. 23.9999 after rounding, or exactly 24.0) wraps to "00:00" so
    the result is always a valid clock time.  Returns None when h_dec is None.
    """
    if h_dec is None:
        return None
    h = int(h_dec)
    m = int(round((h_dec - h) * 60))
    if m == 60:
        # Rounding carried into the next hour.
        h, m = h + 1, 0
    return f"{h % 24:02d}:{m:02d}"


def crossings_for(vals24):
    """
    Compute key threshold-crossing times on a 24-value hourly mean curve.

    vals24: 24 hourly means (index = hour 0..23); None marks an hour with no
    data.  Returns a dict with:
      heatstress_32_starts   earliest upward 32C crossing, "HH:MM" or None
      heatstress_32_ends     latest downward 32C crossing, "HH:MM" or None
      night_cool_22_starts   earliest downward 22C crossing, "HH:MM" or None
      night_cool_22_ends     earliest upward 22C crossing, "HH:MM" or None
      night_min_c            minimum of the non-None values over all 24 hours
                             (None if every hour is None)
      night_ever_below_22    True if any non-None hourly value is below 22C
    """
    out = {}

    # Heat stress: 32 degC going up & down.
    ups = linear_cross(vals24, 32.0, "up")
    downs = linear_cross(vals24, 32.0, "down")
    out["heatstress_32_starts"] = fmt_time(min(ups)) if ups else None
    out["heatstress_32_ends"] = fmt_time(max(downs)) if downs else None

    # Night cool: 22 degC going down (cool window start) & up (cool window end).
    ups22 = linear_cross(vals24, 22.0, "up")
    downs22 = linear_cross(vals24, 22.0, "down")
    out["night_cool_22_starts"] = fmt_time(min(downs22)) if downs22 else None
    out["night_cool_22_ends"] = fmt_time(min(ups22)) if ups22 else None

    # Hours with no samples are None, and min() over a mix of None and float
    # raises TypeError, so drop them first.
    present = [v for v in vals24 if v is not None]
    night_min = min(present) if present else None
    out["night_min_c"] = night_min
    # Check the curve itself: a crossing-based test reports False when every
    # hour is already below 22C, because then no downward crossing exists.
    out["night_ever_below_22"] = night_min is not None and night_min < 22.0

    return out


# Window keys compared side by side in the printed summary, in column order.
_SUMMARY_WINDOWS = ("w1985_89", "w2021_25", "w2026")

# month_totals keys shown in the printed summary, in row order.
_SUMMARY_TOTALS = (
    "day_total_hours_ge_28", "day_total_hours_ge_30", "day_total_hours_ge_32",
    "night_total_hours_ge_22", "night_total_hours_ge_24", "night_total_hours_ge_26",
)

# (crossings key, row label) pairs shown in the printed summary.
_CROSSING_ROWS = (
    ("heatstress_32_starts",   "Heat stress starts (32C up)"),
    ("heatstress_32_ends",     "Heat stress ends   (32C dn)"),
    ("night_cool_22_starts",   "Night cools to 22C (dn)    "),
    ("night_cool_22_ends",     "Night returns to 22C (up)  "),
    ("night_min_c",            "Night minimum temp         "),
    ("night_ever_below_22",    "Ever below 22C?            "),
)


def _hourly_mean_curve(raw):
    """Return 24 per-hour-of-day mean temperatures for one window.

    Index h is the mean of every non-None temperature_2m sample whose
    timestamp hour is h, across all entries in *raw*.  Hours with no samples
    are None.
    """
    by_hour = {h: [] for h in range(24)}
    for entry in raw:
        hourly = entry["hourly"]
        for i, t_iso in enumerate(hourly["time"]):
            hour = int(t_iso.split("T")[1][:2])
            temp = hourly["temperature_2m"][i]
            if temp is not None:
                by_hour[hour].append(temp)
    return [mean(by_hour[h]) if by_hour[h] else None for h in range(24)]


def _per_day(total, n_days):
    """Return total / n_days, or 0.0 when n_days is 0."""
    return total / n_days if n_days else 0.0


def _pretty(v):
    """Format one crossings value for the summary table."""
    if v is None:
        return "   —   "
    if isinstance(v, bool):
        return " yes " if v else " no  "
    if isinstance(v, float):
        return f"{v:6.2f}"
    return f"{v:>6}"


def _print_month_summary(month, windows):
    """Print the side-by-side comparison of one month's summary windows.

    Prints a notice and returns early if any window in _SUMMARY_WINDOWS is
    missing or None.
    """
    print(f"\n======= {month.upper()} =======")
    selected = [windows.get(k) for k in _SUMMARY_WINDOWS]
    if any(w is None for w in selected):
        missing = ", ".join(k for k, w in zip(_SUMMARY_WINDOWS, selected) if w is None)
        print(f"(summary skipped: no data for {missing})")
        return
    w85, w21, w26 = selected

    # Each data cell below is 13 characters wide ("nnnn (dd.d/d)").
    header = f"{'metric':35s} {'1985-89':>13s} {'2021-25':>13s} {'2026':>13s}   delta 85->26"
    print(header)
    print("-" * len(header))
    for k in _SUMMARY_TOTALS:
        a = w85["month_totals"][k]
        b = w21["month_totals"][k]
        c = w26["month_totals"][k]
        # Per-day rate for fair comparison (2026 April is partial).
        ad = _per_day(a, w85["n_days"])
        bd = _per_day(b, w21["n_days"])
        cd = _per_day(c, w26["n_days"])
        ratio = cd / ad if ad > 0 else float("inf")
        # The "+" format flag signs the delta, so a drop prints as "-1.2".
        print(f"{k:35s} {a:>4d} ({ad:4.1f}/d) {b:>4d} ({bd:4.1f}/d) {c:>4d} ({cd:4.1f}/d)  "
              f"{cd - ad:+.1f} hrs/day x{ratio:.1f}")

    print("\nComfort-band hours per average day:")
    for name, _, _ in BANDS:
        print(f"  {name:12s}  {w85['comfort_bands'][name]:6.2f}  "
              f"{w21['comfort_bands'][name]:6.2f}  {w26['comfort_bands'][name]:6.2f}")

    print("\nThreshold crossings (when does the day cross 32C / night cross 22C?):")
    for key, label in _CROSSING_ROWS:
        cells = "  ".join(_pretty(w["crossings"].get(key)) for w in selected)
        print(f"  {label}  {cells}")


def main():
    """Compute stats for every month/window in RAW and write them to
    bangalore_comfort.json next to this script.  Then print a March/April
    summary comparing the 1985-89, 2021-25 and 2026 windows."""
    out = {"bands_def": BANDS, "day_window": "07:00-21:59", "night_window": "22:00-06:59",
           "months": {}}
    for month_name, month_data in RAW["months"].items():
        windows_out = {}
        out["months"][month_name] = {"windows": windows_out}
        for wkey, wdata in month_data["windows"].items():
            raw = wdata.get("raw", [])
            if not raw:
                # Record an empty window as None rather than dropping it.
                windows_out[wkey] = None
                continue
            stats = compute(raw)
            # Hourly mean curve, used for threshold-crossing analysis.
            stats["crossings"] = crossings_for(_hourly_mean_curve(raw))
            windows_out[wkey] = {
                "label": wdata["label"],
                "years": wdata.get("years_succeeded", []),
                **stats,
            }

    (HERE / "bangalore_comfort.json").write_text(json.dumps(out, indent=2))
    print("wrote bangalore_comfort.json")

    # Human-readable summary.
    for month in ("march", "april"):
        month_out = out["months"].get(month)
        if month_out is None:
            print(f"\n(no data for {month}; summary skipped)")
            continue
        _print_month_summary(month, month_out["windows"])


if __name__ == "__main__":
    # Run as a script: main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
