#!/usr/bin/env python3
"""
Fetch ERA5 reanalysis data for Bangalore for March and April from the
Open-Meteo archive API, compute heat-stress metrics across three time windows,
and write both the raw hourly data and the computed summary to JSON files.

Windows (per month):
  - 1985-1989 (pre-urbanization baseline)
  - 2021-2025 (recent)
  - 2026 (March: full month; April: Apr 1 to yesterday)
"""

import json
import time
import urllib.parse
import urllib.request
import urllib.error
from datetime import date, timedelta
from pathlib import Path
from statistics import mean

# Directory containing this script.
HERE = Path(__file__).resolve().parent

# Coordinates, timezone and endpoint used for every archive request.
LAT = 12.9716
LON = 77.5946
TZ = "Asia/Kolkata"
ARCHIVE_URL = "https://archive-api.open-meteo.com/v1/archive"

# Hourly variables requested from the API.
VARS = [
    "temperature_2m",
    "apparent_temperature",
    "relative_humidity_2m",
]

# The run date is pinned to 2026-04-21 (per the task brief); the 2026 April
# window therefore ends on the day before it.
TODAY = date(2026, 4, 21)
YESTERDAY = TODAY - timedelta(days=1)

# One entry per calendar month. Each window lists the years to fetch and the
# inclusive (month, day) bounds applied within every one of those years.
MONTHS = [
    {
        "name": "march",
        "windows": [
            {
                "key": "w1985_89",
                "label": "1985-1989",
                "years": list(range(1985, 1990)),
                "start_md": (3, 1),
                "end_md": (3, 31),
            },
            {
                "key": "w2021_25",
                "label": "2021-2025",
                "years": list(range(2021, 2026)),
                "start_md": (3, 1),
                "end_md": (3, 31),
            },
            {
                "key": "w2026",
                "label": "2026 (Mar 1 - Mar 31)",
                "years": [2026],
                "start_md": (3, 1),
                "end_md": (3, 31),
            },
        ],
    },
    {
        "name": "april",
        "windows": [
            {
                "key": "w1985_89",
                "label": "1985-1989",
                "years": list(range(1985, 1990)),
                "start_md": (4, 1),
                "end_md": (4, 30),
            },
            {
                "key": "w2021_25",
                "label": "2021-2025",
                "years": list(range(2021, 2026)),
                "start_md": (4, 1),
                "end_md": (4, 30),
            },
            {
                "key": "w2026",
                "label": f"2026 (Apr 1 - Apr {YESTERDAY.day})",
                "years": [2026],
                "start_md": (4, 1),
                "end_md": (YESTERDAY.month, YESTERDAY.day),
            },
        ],
    },
]


def fetch_year(year: int, start_md, end_md):
    """Download one year's slice of hourly data from the archive API.

    Args:
        year: Calendar year to request.
        start_md: ``(month, day)`` of the first day of the slice.
        end_md: ``(month, day)`` of the last day of the slice.

    Returns:
        The response body decoded as UTF-8 and parsed as JSON.

    Nothing is caught here: an invalid date, a failure raised by ``urlopen``
    (60 second timeout) or a JSON parse error propagates to the caller.
    """
    first_day = date(year, start_md[0], start_md[1])
    last_day = date(year, end_md[0], end_md[1])

    query = urllib.parse.urlencode({
        "latitude": LAT,
        "longitude": LON,
        "start_date": first_day.isoformat(),
        "end_date": last_day.isoformat(),
        "hourly": ",".join(VARS),
        "timezone": TZ,
    })
    request = urllib.request.Request(
        ARCHIVE_URL + "?" + query,
        headers={"User-Agent": "bangalore-is-hot/1.0"},
    )

    with urllib.request.urlopen(request, timeout=60) as response:
        body = response.read()
    return json.loads(body.decode("utf-8"))


def collect_window(window):
    """Fetch every year of ``window`` on a best-effort basis.

    ``window`` is a dict providing ``years``, ``start_md`` and ``end_md``.
    Each year is requested through ``fetch_year``; any exception raised while
    handling a year, or a response without hourly timestamps, is printed and
    recorded instead of aborting the remaining years. The loop pauses for one
    second after every year, whether it succeeded or not.

    Returns:
        A 3-tuple ``(raw, succeeded, failed)``:
        ``raw`` is a list of ``{"year", "hourly"}`` dicts for the years that
        worked, ``succeeded`` lists those years, and ``failed`` is a list of
        ``{"year", "error"}`` dicts for the years that did not.
    """
    payloads, ok_years, errors = [], [], []

    for year in window["years"]:
        try:
            print(f"  fetching {year} [{window['start_md']} -> {window['end_md']}] ...", flush=True)
            response = fetch_year(year, window["start_md"], window["end_md"])
            hourly = response.get("hourly", {})
            if not hourly.get("time", []):
                raise RuntimeError(f"no hourly data returned for {year}")
        except Exception as exc:
            print(f"    !! failed for {year}: {exc}", flush=True)
            errors.append({"year": year, "error": str(exc)})
        else:
            payloads.append({"year": year, "hourly": hourly})
            ok_years.append(year)
        # Pause between consecutive requests.
        time.sleep(1.0)

    return payloads, ok_years, errors


NIGHT_HOURS = set(range(0, 7))        # 00:00-06:59 local: body's recovery window
DAY_HOURS = set(range(12, 17))        # 12:00-16:59 local: solar peak window


def _mean_or_none(values):
    """Return the arithmetic mean of ``values``, or None when it is empty."""
    return mean(values) if values else None


def compute_metrics(window_raw):
    """
    Aggregate one window's hourly readings into heat-stress statistics.

    window_raw: list of {year, hourly:{time, temperature_2m, apparent_temperature, relative_humidity_2m}}
        The four hourly lists are read in parallel by index. Each time entry
        is split on "T"; the text before it identifies the day and the first
        two characters after it are taken as the hour of day. A reading of
        None is treated as missing.

    Apparent temperature and relative humidity feed only the hour-of-day
    profiles and are counted whenever they are present, even if the
    temperature for the same hour is missing. All daily statistics are derived
    from temperature_2m alone, and a day is counted only if it has at least
    one temperature reading.

    Returns dict of computed stats for this window. ``n_days`` is the number
    of days counted, the three ``hourly_mean_*`` entries are 24-element lists
    indexed by hour of day, and every average or percentage is None when it
    has no samples.
    """
    # Per hour-of-day accumulators.
    by_hour_t = {h: [] for h in range(24)}
    by_hour_at = {h: [] for h in range(24)}
    by_hour_rh = {h: [] for h in range(24)}

    # Per-day temperature accumulators keyed by the date part of the timestamp.
    per_day = {}

    for entry in window_raw:
        hourly = entry["hourly"]
        times = hourly["time"]
        temps = hourly["temperature_2m"]
        app = hourly["apparent_temperature"]
        rh = hourly["relative_humidity_2m"]
        for i, t_iso in enumerate(times):
            date_part, time_part = t_iso.split("T")
            hour = int(time_part[:2])
            temp = temps[i]
            at = app[i]
            rh_v = rh[i]

            # Record these before the temperature check so a missing
            # temperature does not discard otherwise valid readings.
            if at is not None:
                by_hour_at[hour].append(at)
            if rh_v is not None:
                by_hour_rh[hour].append(rh_v)

            if temp is None:
                continue
            by_hour_t[hour].append(temp)

            d = per_day.setdefault(date_part, {"temps": [], "night": [], "day": []})
            d["temps"].append(temp)
            if hour in NIGHT_HOURS:
                d["night"].append(temp)
            if hour in DAY_HOURS:
                d["day"].append(temp)

    daily_mins = []
    daily_maxs = []
    daily_means = []
    daily_night_means = []
    daily_day_means = []
    hours_ge_26_per_day = []
    hours_ge_28_per_day = []
    deg_hours_above_26_per_day = []
    nights_min_ge_22_days = 0  # "tropical night" analogue - body can't recover well
    nights_min_ge_24_days = 0

    for d in per_day.values():
        # Never empty: a day entry is created only when a temperature is stored.
        temps = d["temps"]
        dmin = min(temps)
        dmax = max(temps)
        daily_mins.append(dmin)
        daily_maxs.append(dmax)
        daily_means.append(mean(temps))
        if d["night"]:
            daily_night_means.append(mean(d["night"]))
        if d["day"]:
            daily_day_means.append(mean(d["day"]))
        hours_ge_26_per_day.append(sum(1 for t in temps if t >= 26))
        hours_ge_28_per_day.append(sum(1 for t in temps if t >= 28))
        deg_hours_above_26_per_day.append(sum(max(0.0, t - 26.0) for t in temps))
        # The "night" thresholds are applied to the whole-day minimum.
        if dmin >= 22:
            nights_min_ge_22_days += 1
        if dmin >= 24:
            nights_min_ge_24_days += 1

    n_days = len(per_day)
    return {
        "n_days": n_days,
        "hourly_mean_temperature_2m": [_mean_or_none(by_hour_t[h]) for h in range(24)],
        "hourly_mean_apparent_temperature": [_mean_or_none(by_hour_at[h]) for h in range(24)],
        "hourly_mean_relative_humidity_2m": [_mean_or_none(by_hour_rh[h]) for h in range(24)],
        "avg_daily_min": _mean_or_none(daily_mins),
        "avg_daily_max": _mean_or_none(daily_maxs),
        "avg_daily_mean": _mean_or_none(daily_means),
        "avg_night_mean_00_07": _mean_or_none(daily_night_means),
        "avg_day_mean_12_17": _mean_or_none(daily_day_means),
        "avg_diurnal_range": (mean(daily_maxs) - mean(daily_mins)) if daily_mins and daily_maxs else None,
        "avg_hours_per_day_ge_26": _mean_or_none(hours_ge_26_per_day),
        "avg_hours_per_day_ge_28": _mean_or_none(hours_ge_28_per_day),
        "avg_degree_hours_above_26_per_day": _mean_or_none(deg_hours_above_26_per_day),
        "pct_nights_min_ge_22": (nights_min_ge_22_days / n_days * 100) if n_days else None,
        "pct_nights_min_ge_24": (nights_min_ge_24_days / n_days * 100) if n_days else None,
    }


def _fmt(value, spec=".2f", suffix=""):
    """Format ``value`` with ``spec`` plus ``suffix``, or "n/a" when it is None."""
    if value is None:
        return "n/a"
    return format(value, spec) + suffix


def main():
    """Fetch every configured window, compute its stats and write both JSON files.

    For each month in ``MONTHS`` and each of its windows, the raw payloads go
    into ``bangalore_raw.json`` and the computed statistics into
    ``bangalore_data.json``, both written next to this script once all windows
    have been processed. A window for which no year could be fetched is
    stored with ``"stats": None``.

    Individual statistics can be None (for example when the fetched payloads
    contain no usable temperatures), so the progress output formats them
    through ``_fmt`` rather than applying a float format directly; otherwise
    a single empty window would raise before anything is written.
    """
    raw_out = {"location": {"lat": LAT, "lon": LON, "timezone": TZ}, "months": {}}
    summary = {
        "location": {"lat": LAT, "lon": LON, "timezone": TZ},
        "generated_at": TODAY.isoformat(),
        "months": {},
    }

    for m in MONTHS:
        print(f"\n======= MONTH: {m['name'].upper()} =======")
        raw_windows = {}
        summary_windows = {}
        raw_out["months"][m["name"]] = {"windows": raw_windows}
        summary["months"][m["name"]] = {"windows": summary_windows}

        for w in m["windows"]:
            print(f"\n[{m['name']} / {w['label']}] window key={w['key']}")
            raw, ok, failed = collect_window(w)
            raw_windows[w["key"]] = {
                "label": w["label"],
                "years_succeeded": ok,
                "years_failed": failed,
                "raw": raw,
            }

            stats = compute_metrics(raw) if raw else None
            summary_windows[w["key"]] = {
                "label": w["label"],
                "years_succeeded": ok,
                "years_failed": failed,
                "stats": stats,
            }
            if stats is None:
                print(f"  !! no years succeeded for {w['label']}, skipping stats")
                continue

            print(f"  ok years: {ok}; failed: {[f['year'] for f in failed]}")
            print(f"  n_days={stats['n_days']}, min={_fmt(stats['avg_daily_min'])}, "
                  f"max={_fmt(stats['avg_daily_max'])}, mean={_fmt(stats['avg_daily_mean'])}, "
                  f"range={_fmt(stats['avg_diurnal_range'])}")
            print(f"  night(00-07)={_fmt(stats['avg_night_mean_00_07'])}, "
                  f"day(12-17)={_fmt(stats['avg_day_mean_12_17'])}")
            print(f"  hrs>=26={_fmt(stats['avg_hours_per_day_ge_26'])}, "
                  f"hrs>=28={_fmt(stats['avg_hours_per_day_ge_28'])}, "
                  f"dh>26={_fmt(stats['avg_degree_hours_above_26_per_day'])}")
            print(f"  pct_nights min>=22: {_fmt(stats['pct_nights_min_ge_22'], '.1f', '%')}, "
                  f"min>=24: {_fmt(stats['pct_nights_min_ge_24'], '.1f', '%')}")

    raw_path = HERE / "bangalore_raw.json"
    summary_path = HERE / "bangalore_data.json"
    raw_path.write_text(json.dumps(raw_out, indent=2))
    summary_path.write_text(json.dumps(summary, indent=2))
    print(f"\nwrote {raw_path}")
    print(f"wrote {summary_path}")


if __name__ == "__main__":
    main()
