#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""

Compute seasonal mean sea-level pressure (MSLP) weather regimes using k-means
clustering of daily MSLP anomalies.

For each season, the script:
  1. Reads a daily MSLP field from a NetCDF file.
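  2. Converts pressure from Pa to hPa when the field magnitude indicates Pa
     (skipped with --no-hpa-conversion).
  3. Computes daily anomalies relative to the day-of-year climatology.
  4. Removes leap days (February 29) before the anomalies of step 3 are
     computed (skipped with --keep-leapday).
  5. Applies sqrt(cos(latitude)) area weighting before clustering
     (skipped with --no-area-weighting).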
  6. Applies k-means independently to each season.
  7. Reorders clusters by decreasing frequency of occurrence.
  8. Writes one NetCDF file per season and one combined NetCDF file.

The output NetCDF files contain:
  - cluster_<SEASON>: daily cluster assignment, with values from 1 to K
  - centroid_<SEASON>: cluster centroids as MSLP anomaly maps
  - frequency_<SEASON>: occurrence frequency of each cluster

Example
-------
python compute_seasonal_mslp_weather_regimes.py \
    --input-file era5_mslp_daily.nc \
    --variable msl \
    --output-dir outputs \
    --start-date 1950-01-01 \
    --end-date 2025-12-31 \
    --n-clusters 4
"""

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import xarray as xr
from sklearn.cluster import KMeans


# Seasons to cluster, keyed by their three-letter label. Each value lists the
# calendar month numbers (1 = January) that belong to the season; main()
# iterates over this mapping and runs one independent k-means per entry.
SEASONS = {
    "DJF": [12, 1, 2],
    "MAM": [3, 4, 5],
    "JJA": [6, 7, 8],
    "SON": [9, 10, 11],
}


def get_latlon_names(da: xr.DataArray) -> tuple[str, str]:
    """Return the latitude and longitude dimension names of ``da``.

    The long spellings (``latitude``/``longitude``) are preferred over the
    short ones (``lat``/``lon``) when both are present.

    Parameters
    ----------
    da
        Array whose ``dims`` are inspected.

    Returns
    -------
    tuple[str, str]
        ``(lat_name, lon_name)``.

    Raises
    ------
    ValueError
        If either the latitude or the longitude dimension cannot be found.
    """
    dims = da.dims

    # First matching candidate wins, so the long spelling takes precedence.
    lat_name = next((name for name in ("latitude", "lat") if name in dims), None)
    lon_name = next((name for name in ("longitude", "lon") if name in dims), None)

    if lat_name is not None and lon_name is not None:
        return lat_name, lon_name

    raise ValueError(
        "Latitude/longitude dimensions were not found. Expected either "
        "'lat'/'lon' or 'latitude'/'longitude'. "
        f"Available dimensions: {da.dims}"
    )


def rename_time_if_needed(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` with a ``valid_time`` dimension renamed to ``time``.

    The rename only happens when ``valid_time`` is a dimension and ``time``
    is not; in every other case the dataset is returned unchanged.
    """
    needs_rename = "valid_time" in ds.dims and "time" not in ds.dims
    return ds.rename({"valid_time": "time"}) if needs_rename else ds


def drop_leap_day(da: xr.DataArray) -> xr.DataArray:
    """Return ``da`` without the time steps that fall on February 29.

    The selection is made along the ``time`` coordinate using its month and
    day-of-month fields.
    """
    when = da["time"].dt
    keep = ~((when.month == 2) & (when.day == 29))
    return da.sel(time=keep)


def _noleap_dayofyear(time: xr.DataArray) -> xr.DataArray:
    """Return the day of year on a 365-day calendar for each entry of ``time``.

    The value is built from month and day-of-month, so a given calendar date
    (e.g. 1 March) always maps to the same number regardless of whether its
    year is a leap year. The result is named ``dayofyear`` so it can be used
    as a groupby key. Callers must have removed February 29 beforehand.
    """
    month_lengths = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
    days_before_month = np.concatenate(([0], np.cumsum(month_lengths)[:-1]))

    month_index = time.dt.month.values - 1
    day = time.dt.day

    if np.any(day.values > month_lengths[month_index]):
        # Dates that cannot exist in a 365-day calendar (e.g. February 30 of a
        # 360-day calendar): the month table does not apply, so keep the
        # calendar's own ordinal day.
        return time.dt.dayofyear

    return (day + days_before_month[month_index]).rename("dayofyear")


def daily_dayofyear_anomalies(
    da: xr.DataArray,
    climatology_start: str,
    climatology_end: str,
    remove_leapday: bool,
) -> xr.DataArray:
    """Compute daily anomalies relative to a day-of-year climatology.

    Parameters
    ----------
    da
        Daily field with a ``time`` dimension.
    climatology_start, climatology_end
        Bounds (inclusive label slice) of the reference period whose
        per-day mean defines the climatology.
    remove_leapday
        If true, February 29 is removed from both the reference period and
        ``da``, and days are grouped on a 365-day day-of-year so that every
        calendar date is compared with the same date of the other years.
        If false, the ordinal ``time.dayofyear`` (1..366) is used as is.

    Returns
    -------
    xr.DataArray
        ``da`` (without leap days when ``remove_leapday`` is true) minus the
        climatological value of the matching day of year.
    """
    reference = da.sel(time=slice(climatology_start, climatology_end))

    if not remove_leapday:
        climatology = reference.groupby("time.dayofyear").mean("time")
        return da.groupby("time.dayofyear") - climatology

    reference = drop_leap_day(reference)
    da = drop_leap_day(da)

    # Dropping February 29 does not renumber ``time.dayofyear``: in leap years
    # 1 March is still ordinal day 61 while it is day 60 otherwise. Grouping on
    # the ordinal would therefore mix neighbouring calendar dates and leave a
    # day 366 that only leap years contribute to. Group on a 365-day index
    # instead.
    climatology = reference.groupby(_noleap_dayofyear(reference["time"])).mean("time")
    anomalies = da.groupby(_noleap_dayofyear(da["time"])) - climatology
    return anomalies


def sqrt_coslat_weights(lat: xr.DataArray) -> xr.DataArray:
    """Return ``sqrt(cos(latitude))`` weights for a latitude coordinate.

    Parameters
    ----------
    lat
        Latitude values in degrees; its first dimension names the result's
        only dimension.

    Returns
    -------
    xr.DataArray
        Weights along the latitude dimension, carrying ``lat`` as coordinate.
    """
    dim = lat.dims[0]
    cos_lat = np.cos(np.deg2rad(lat))
    return xr.DataArray(
        np.sqrt(cos_lat).values,
        dims=[dim],
        coords={dim: lat},
    )


def stack_to_feature_matrix(
    da: xr.DataArray,
    lat_name: str,
    lon_name: str,
) -> tuple[np.ndarray, np.ndarray, xr.DataArray]:
    """Flatten the horizontal grid of ``da`` into a (time, feature) matrix.

    Parameters
    ----------
    da
        Field with ``time`` plus the two horizontal dimensions.
    lat_name, lon_name
        Names of the dimensions stacked into the ``feature`` dimension.

    Returns
    -------
    X
        Matrix of shape (time, valid features); grid points that are not
        finite at every time step are left out.
    valid_mask
        Boolean vector over all stacked features, true for the kept ones.
    stacked
        The stacked array, needed later to restore the grid.
    """
    stacked = da.stack(feature=(lat_name, lon_name))
    samples = stacked.transpose("time", "feature").values

    # A feature is usable only if it is finite on every single day.
    valid_mask = np.all(np.isfinite(samples), axis=0)

    return samples[:, valid_mask], valid_mask, stacked


def unstack_centroids(
    centroids: np.ndarray,
    stacked: xr.DataArray,
    valid_mask: np.ndarray,
) -> xr.DataArray:
    """Put cluster centroids back onto the original horizontal grid.

    Parameters
    ----------
    centroids
        Array of shape (cluster, valid features).
    stacked
        Stacked array providing the ``feature`` coordinate and its size.
    valid_mask
        Boolean vector over all features marking the columns that
        ``centroids`` covers.

    Returns
    -------
    xr.DataArray
        float32 centroids with a 1-based ``cluster`` coordinate and the
        ``feature`` dimension unstacked; excluded grid points are NaN.
    """
    n_clusters = centroids.shape[0]
    n_features = stacked.sizes["feature"]

    # Start from an all-NaN grid and fill in only the features that were used.
    grid = np.full((n_clusters, n_features), np.nan, dtype=np.float32)
    grid[:, valid_mask] = centroids.astype(np.float32)

    cluster_ids = np.arange(1, n_clusters + 1)
    on_features = xr.DataArray(
        grid,
        dims=("cluster", "feature"),
        coords={"cluster": cluster_ids, "feature": stacked["feature"]},
    )

    return on_features.unstack("feature")


def standardize_features(X: np.ndarray) -> np.ndarray:
    """Standardize every column of ``X`` to zero mean and unit variance.

    Columns with zero standard deviation are only centred (divided by 1), so
    they come out as zeros instead of NaN. ``X`` itself is not modified.
    """
    center = np.mean(X, axis=0)
    scale = np.std(X, axis=0)
    # Constant columns: replace the zero divisor by one.
    np.putmask(scale, scale == 0, 1.0)
    return np.divide(X - center, scale)


def reorder_by_frequency(
    labels_1based: np.ndarray,
    centroids: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Renumber clusters so that cluster 1 is the most frequent one.

    Ties in frequency are broken by the original cluster number (lower
    first).

    Parameters
    ----------
    labels_1based
        Cluster label of every sample, with values in 1..K.
    centroids
        Array of shape (K, features) in the original cluster order.

    Returns
    -------
    labels_reordered
        Labels in 1..K after renumbering.
    centroids_reordered
        Centroid rows in the new order.
    frequencies_reordered
        Fraction of samples per cluster, in the new (decreasing) order.
    original_cluster_order_1based
        For each new cluster, the 1-based number it had originally.
    """
    n_clusters = centroids.shape[0]
    original_ids = np.arange(1, n_clusters + 1)

    counts = np.array(
        [np.count_nonzero(labels_1based == cluster_id) for cluster_id in original_ids],
        dtype=int,
    )
    frequencies = counts / counts.sum()

    # Primary key: decreasing frequency; secondary key: original cluster number.
    order = np.lexsort((original_ids, -frequencies))

    # Inverse permutation: old 0-based index -> new 0-based index.
    old_to_new = np.empty(n_clusters, dtype=int)
    old_to_new[order] = np.arange(n_clusters)

    labels_reordered = old_to_new[labels_1based - 1] + 1

    return (
        labels_reordered,
        centroids[order, :],
        frequencies[order],
        order + 1,
    )


def open_mslp_field(
    input_file: Path,
    variable: str,
    start_date: str,
    end_date: str,
    convert_to_hpa: bool,
) -> xr.DataArray:
    """Open the input file and return the requested variable for a period.

    Parameters
    ----------
    input_file
        Path of the NetCDF file to read.
    variable
        Name of the data variable to extract.
    start_date, end_date
        Bounds of the label-based time slice applied to the variable.
    convert_to_hpa
        If true, the field is divided by 100 and tagged with units ``hPa``
        whenever the mean of its first time step exceeds 2000 (i.e. the
        values look like Pa rather than hPa).

    Returns
    -------
    xr.DataArray
        The selected variable, restricted to the requested period.

    Raises
    ------
    FileNotFoundError
        If ``input_file`` does not exist.
    KeyError
        If ``variable`` is not present in the file.
    RuntimeError
        If the requested period contains no time step.
    """
    if not input_file.exists():
        raise FileNotFoundError(f"Input file not found: {input_file}")

    # The dataset is intentionally left open on success: the returned array
    # may still read its values from the file.
    ds = rename_time_if_needed(xr.open_dataset(input_file))

    if variable not in ds:
        available = list(ds.data_vars)
        ds.close()
        raise KeyError(
            f"Variable '{variable}' not found in {input_file}. "
            f"Available variables: {available}"
        )

    da = ds[variable].sel(time=slice(start_date, end_date))

    if da.sizes.get("time", 0) == 0:
        ds.close()
        raise RuntimeError(
            f"No time steps found between {start_date} and {end_date}."
        )

    if convert_to_hpa:
        sample = float(da.isel(time=0).mean().values)
        if sample > 2000:
            # Arithmetic may not carry attributes over (this depends on the
            # xarray version/options), so restore the source metadata
            # explicitly and only override the units.
            source_attrs = dict(da.attrs)
            da = da / 100.0
            da.attrs.update(source_attrs)
            da.attrs["units"] = "hPa"

    return da


def compute_one_season(
    anomalies_weighted: xr.DataArray,
    anomalies_unweighted: xr.DataArray,
    season: str,
    months: list[int],
    lat_name: str,
    lon_name: str,
    n_clusters: int,
    max_iter: int,
    n_init: int,
    random_state: int,
    standardize: bool,
    area_weighting: bool,
    lat_weights: xr.DataArray | None,
) -> xr.Dataset:
    """Run k-means on one season's daily anomalies and package the result.

    Parameters
    ----------
    anomalies_weighted
        Daily anomalies used as clustering input (already multiplied by
        ``lat_weights`` when area weighting is active).
    anomalies_unweighted
        Unweighted anomalies; only their ``units`` attribute is read.
    season
        Season label used in the output variable names.
    months
        Calendar month numbers that make up the season.
    lat_name, lon_name
        Names of the horizontal dimensions.
    n_clusters, max_iter, n_init, random_state
        Passed to ``sklearn.cluster.KMeans``.
    standardize
        If true, each grid-point feature is standardized before clustering.
    area_weighting
        Whether the input was area weighted; if so (and ``lat_weights`` is
        given) the centroids are divided by the weights again.
    lat_weights
        Latitude weights applied to ``anomalies_weighted``, or ``None``.

    Returns
    -------
    xr.Dataset
        ``cluster_<season>`` (labels 1..K per day), ``centroid_<season>``
        (anomaly maps per cluster) and ``frequency_<season>`` (fraction of
        days per cluster), with clusters ordered by decreasing frequency.

    Raises
    ------
    ValueError
        If no grid point is finite on every day of the season, or if the
        season holds fewer days than ``n_clusters``.
    """
    in_season = anomalies_weighted["time"].dt.month.isin(months)
    selected = anomalies_weighted.sel(time=in_season)

    X_raw, valid_mask, stacked = stack_to_feature_matrix(selected, lat_name, lon_name)

    # Fail with a season-specific message instead of an opaque estimator error.
    n_samples, n_features = X_raw.shape
    if n_features == 0:
        raise ValueError(
            f"{season}: no grid point has finite values at every time step; "
            "cannot build a feature matrix for clustering."
        )
    if n_samples < n_clusters:
        raise ValueError(
            f"{season}: only {n_samples} daily fields are available for months "
            f"{months}, but {n_clusters} clusters were requested."
        )

    X = standardize_features(X_raw) if standardize else X_raw

    kmeans = KMeans(
        n_clusters=n_clusters,
        init="k-means++",
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
        algorithm="lloyd",
    )

    labels = kmeans.fit_predict(X) + 1

    if standardize:
        # The fitted centres live in standardized (dimensionless) space, but the
        # output is documented as anomaly maps in pressure units. Use the mean
        # of the unstandardized features over each cluster's members instead.
        cluster_centers = np.vstack(
            [X_raw[labels == k].mean(axis=0) for k in range(1, n_clusters + 1)]
        )
    else:
        cluster_centers = kmeans.cluster_centers_

    labels, centroids, frequencies, original_order = reorder_by_frequency(
        labels_1based=labels,
        centroids=cluster_centers,
    )

    labels_da = xr.DataArray(
        labels.astype(np.int16),
        dims=("time",),
        coords={"time": selected["time"]},
        name=f"cluster_{season}",
    )
    labels_da.attrs["description"] = (
        f"K-means cluster assignment for {season}. Clusters are reordered by "
        "decreasing frequency of occurrence."
    )
    labels_da.attrs["values"] = f"1..{n_clusters}"
    # tolist() yields plain Python ints, so the attribute reads "[3, 1, 2, 4]"
    # rather than a list of NumPy scalar reprs.
    labels_da.attrs["reordering"] = (
        "New clusters 1..K correspond to the following original k-means "
        f"clusters: {original_order.tolist()}"
    )

    centroids_da = unstack_centroids(centroids, stacked, valid_mask)
    centroids_da = centroids_da.rename(f"centroid_{season}")

    if area_weighting and lat_weights is not None:
        # Undo the latitude weighting so the maps are plain anomalies again.
        centroids_da = (centroids_da / lat_weights).rename(f"centroid_{season}")

    centroids_da.attrs["description"] = (
        f"K-means centroid anomalies for {season}, reordered by decreasing "
        "frequency of occurrence."
    )
    centroids_da.attrs["units"] = anomalies_unweighted.attrs.get("units", "")

    frequency_da = xr.DataArray(
        frequencies.astype(np.float32),
        dims=("cluster",),
        coords={"cluster": np.arange(1, n_clusters + 1)},
        name=f"frequency_{season}",
    )
    frequency_da.attrs["description"] = (
        f"Cluster occurrence frequency for {season}, expressed as a fraction."
    )

    return xr.Dataset(
        {
            labels_da.name: labels_da,
            centroids_da.name: centroids_da,
            frequency_da.name: frequency_da,
        }
    )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Parameters
    ----------
    argv
        Argument list to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, which is the behaviour of a plain ``parse_args()``
        call.

    Returns
    -------
    argparse.Namespace
        Parsed options. ``input_file`` and ``output_dir`` are ``Path``
        objects; the ``no_*``/``keep_*``/``standardize_*`` flags are booleans
        that default to false.

    Raises
    ------
    SystemExit
        Raised by argparse on invalid input, including a non-positive
        ``--n-clusters``, ``--max-iter`` or ``--n-init``.
    """
    parser = argparse.ArgumentParser(
        description="Compute seasonal MSLP weather regimes using k-means clustering."
    )

    parser.add_argument(
        "--input-file",
        required=True,
        type=Path,
        help="Input NetCDF file containing daily MSLP fields.",
    )
    parser.add_argument(
        "--variable",
        default="msl",
        help="Name of the MSLP variable in the input NetCDF file.",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        type=Path,
        help="Directory where seasonal and combined NetCDF outputs will be written.",
    )
    parser.add_argument(
        "--start-date",
        default="1950-01-01",
        help="Start date used to subset the input data and compute climatology.",
    )
    parser.add_argument(
        "--end-date",
        default="2025-12-31",
        help="End date used to subset the input data and compute climatology.",
    )
    parser.add_argument(
        "--n-clusters",
        type=int,
        default=4,
        help="Number of k-means clusters per season.",
    )
    parser.add_argument(
        "--max-iter",
        type=int,
        default=1000,
        help="Maximum number of k-means iterations.",
    )
    parser.add_argument(
        "--n-init",
        type=int,
        default=50,
        help="Number of k-means initializations.",
    )
    parser.add_argument(
        "--random-state",
        type=int,
        default=42,
        help="Random seed used by k-means.",
    )
    parser.add_argument(
        "--no-hpa-conversion",
        action="store_true",
        help="Do not convert pressure values from Pa to hPa.",
    )
    parser.add_argument(
        "--no-area-weighting",
        action="store_true",
        help="Do not apply sqrt(cos(latitude)) area weighting before clustering.",
    )
    parser.add_argument(
        "--keep-leapday",
        action="store_true",
        help="Keep February 29 in the anomaly calculation.",
    )
    parser.add_argument(
        "--standardize-features",
        action="store_true",
        help="Standardize each spatial feature before applying k-means.",
    )

    args = parser.parse_args(argv)

    # Reject impossible k-means settings up front, before any data is read.
    for option in ("n_clusters", "max_iter", "n_init"):
        if getattr(args, option) < 1:
            parser.error(f"--{option.replace('_', '-')} must be a positive integer.")

    return args


def main() -> None:
    """Run the full workflow: read, compute anomalies, cluster, write files.

    One NetCDF file is written per entry of ``SEASONS`` and one more file
    gathers the variables of all seasons. The path of every written file is
    printed.
    """
    args = parse_args()
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    use_area_weighting = not args.no_area_weighting

    da = open_mslp_field(
        input_file=args.input_file,
        variable=args.variable,
        start_date=args.start_date,
        end_date=args.end_date,
        convert_to_hpa=not args.no_hpa_conversion,
    )
    lat_name, lon_name = get_latlon_names(da)

    # The climatology period is the same as the analysis period.
    anomalies = daily_dayofyear_anomalies(
        da,
        climatology_start=args.start_date,
        climatology_end=args.end_date,
        remove_leapday=not args.keep_leapday,
    ).rename(f"{args.variable}_anom")
    anomalies.attrs["description"] = (
        "Daily anomalies relative to the day-of-year climatology."
    )
    anomalies.attrs["units"] = da.attrs.get("units", "")

    if use_area_weighting:
        lat_weights = sqrt_coslat_weights(anomalies[lat_name])
        anomalies_weighted = anomalies * lat_weights
    else:
        lat_weights = None
        anomalies_weighted = anomalies

    all_outputs: dict[str, xr.DataArray] = {}

    for season, months in SEASONS.items():
        ds_season = compute_one_season(
            anomalies_weighted=anomalies_weighted,
            anomalies_unweighted=anomalies,
            season=season,
            months=months,
            lat_name=lat_name,
            lon_name=lon_name,
            n_clusters=args.n_clusters,
            max_iter=args.max_iter,
            n_init=args.n_init,
            random_state=args.random_state,
            standardize=args.standardize_features,
            area_weighting=use_area_weighting,
            lat_weights=lat_weights,
        )

        out_file = output_dir / f"kmeans_mslp_daily_{season}_K{args.n_clusters}.nc"
        ds_season.to_netcdf(out_file)
        print(f"[OK] {season}: {out_file}")

        # Variable names carry the season suffix, so they never collide.
        all_outputs.update({name: ds_season[name] for name in ds_season.data_vars})

    # NOTE(review): the per-season label series cover different days; check how
    # the combined dataset represents days outside a variable's season.
    combined_file = (
        output_dir / f"kmeans_mslp_daily_ALLSEASONS_K{args.n_clusters}.nc"
    )
    xr.Dataset(all_outputs).to_netcdf(combined_file)
    print(f"[OK] Combined output: {combined_file}")


# Run the command-line workflow only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

