#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
aggregate_CFseries_to_nuts0.py

Aggregate PECD zone-level time series (P2ON, P2OF, SZON, etc.)
to NUTS0 countries using spatial weights derived from PECD v4.2 masks.

- Individual variables:

  - Wind onshore (WON) -> WON-NUTS0-<PERIOD>
  - Wind offshore (WOF) -> WOF-NUTS0-<PERIOD>
  - Wind total (WON+WOF) -> WTO-NUTS0-<PERIOD>
  - Solar photovoltaics (SPV) -> SPV-NUTS0-<PERIOD>

TXT input format:
  YYYY-MM-DD-HH <TAB> value

"""

from __future__ import annotations

import re
from pathlib import Path
from typing import Dict, Tuple, List, Optional

import numpy as np
import pandas as pd
import xarray as xr


# =============================================================================
# CONFIGURATION
# =============================================================================

# Directories containing zone-level TXT files. Multiple directories are allowed;
# they are searched in list order and the first matching file is used.
INPUT_TXT_DIRS = [
    Path("/path/to/SPV-P2ON"),
    Path("/path/to/WON-P2ON"),
    Path("/path/to/WOF-P2OF"),
]

# Root output directory. Variable-specific output folders named
# "<TAG>-NUTS0-<PERIOD_TAG>" are created inside it.
OUTPUT_ROOT = Path("/path/to/output")

# Period suffix used to name output folders.
PERIOD_TAG = "1950-2025"

# Tag used for combined onshore + offshore wind output (folder and file names).
WIND_TOTAL_TAG = "WTO"  # e.g. "WTO" (Wind Total) or "WIND"

# --- PECD v4.2 NetCDF spatial masks/weights ---
# (each descriptive comment below refers to the constant directly above it)

NC_NUTS0_MASK  = Path("/path/to/masks/ANCI_NUT0-mask_PECD4.2_fv1.nc")
# Fractional NUTS0 country mask. Each layer represents the fraction of
# each grid cell belonging to a given country.

NC_P2ON_MASK   = Path("/path/to/masks/ANCI_P2ON-mask_PECD4.2_fv1.nc")
# Fractional mask of PECD P2ON (Pan-European Onshore) regions.

NC_P2OF_MASK   = Path("/path/to/masks/ANCI_P2OF-mask_PECD4.2_fv1.nc")
# Fractional mask of PECD P2OF (Pan-European Offshore) regions.

NC_SZON_MASK   = Path("/path/to/masks/ANCI_SZON-mask_PECD4.2_fv1.nc")
# Fractional mask of PECD SZON (bidding-zone) regions.

NC_LAT_WEIGHTS = Path("/path/to/masks/ANCI_LAT-mask_PECD4.2_fv1.nc")
# Latitude-dependent grid-cell area weights used to account for the
# varying area represented by regular latitude–longitude grid cells.
# A one-dimensional field is broadcast along longitude before use.

NC_POP_DENSITY = Path("/path/to/masks/ANCI_POP-mask_PECD4.2_fv1.nc")
# Population density field used as weighting factor when aggregating
# demand-related variables. In this script it is the extra weight applied
# to every group whose zoneset is SZON.

NC_PV_MASK     = Path("/path/to/masks/ANCI_PVM-mask_PECD4.2_fv1.nc")
# Solar PV suitability mask defining grid cells considered eligible
# for photovoltaic generation. Applied as extra weight to SPV on P2ON;
# its grid also serves as the template for the all-ones weight fields.

NC_WIND_MASK   = Path("/path/to/masks/ANCI_WPM-mask_PECD4.2_fv1.nc")
# Wind-power suitability/restriction mask defining grid cells eligible
# for wind generation. Depending on the PECD version, values may need
# to be inverted (e.g. 1 = restricted, 0 = eligible).

# Set True when the wind mask encodes restricted cells as 1 and eligible cells as 0.
# When True, eligibility is computed as (1 - mask); otherwise the mask is used as is.
WIND_MASK_INVERT = True

# Optional filters. Use None to process all variables/zonesets found in INPUT_TXT_DIRS.
# Codes parsed from file names are upper-cased before comparison, so list
# upper-case values here.
ONLY_VARCODE = None   # e.g. {"SPV", "WON", "WOF"} or None
ONLY_ZONESET = None   # e.g. {"P2ON", "P2OF", "SZON"} or None

# Separator used in TXT files (both when reading zone files and writing NUTS0 files).
TXT_SEP = "\t"

# =============================================================================
# END OF CONFIGURATION
# =============================================================================

# Accepted file names (VAR, ZONESET and ZONE are alphanumeric, without
# underscores; the optional _ON/_OFF suffix must be upper-case):
#  - VAR_ZONESET_ZONE.txt
#  - VAR_ZONESET_ZONE_ON.txt
#  - VAR_ZONESET_ZONE_OFF.txt

F_TXT = re.compile(
    r"^(?P<var>[A-Za-z0-9]+)_(?P<zoneset>[A-Za-z0-9]+)_(?P<zone>[A-Za-z0-9]+)(?:_(?P<tag>OFF|ON))?\.txt$"
)

def ensure_dir(p: Path) -> None:
    """Create directory *p* with any missing parents; a no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)


def outdir_for(var_tag: str) -> Path:
    """Return the output folder for *var_tag*: ``OUTPUT_ROOT/<var_tag>-NUTS0-<PERIOD_TAG>``."""
    folder_name = f"{var_tag}-NUTS0-{PERIOD_TAG}"
    return OUTPUT_ROOT / folder_name


def first_data_var(ds: xr.Dataset) -> str:
    """Return the name of the first data variable in *ds*.

    Raises:
        IndexError: if *ds* has no data variables. The exception type is the
            one the previous ``list(...)[0]`` lookup raised; only the message
            is now descriptive.
    """
    for name in ds.data_vars:
        return name
    raise IndexError("Dataset contains no data variables")


def load_da(path: Path, varname: Optional[str] = None) -> xr.DataArray:
    """Read one variable from a NetCDF file and return it fully loaded in memory.

    Args:
        path: NetCDF file to open.
        varname: Variable to extract. When omitted (or empty), the first data
            variable of the dataset is used.

    Returns:
        The selected variable as an in-memory ``DataArray``.
    """
    # The dataset is opened as a context manager so the underlying file handle
    # is released on return; ``load()`` pulls the values into memory first so
    # the returned array stays usable after the file is closed.
    with xr.open_dataset(path) as ds:
        name = varname or first_data_var(ds)
        return ds[name].load()


def get_region_dim(da: xr.DataArray) -> Optional[str]:
    """Return the single dimension of *da* that is neither latitude nor longitude.

    A dimension counts as spatial when its lower-cased name starts with
    ``"lat"`` or ``"lon"``. ``None`` is returned unless exactly one
    non-spatial dimension exists.
    """
    non_spatial = [
        name
        for name in da.dims
        if not name.lower().startswith(("lat", "lon"))
    ]
    return non_spatial[0] if len(non_spatial) == 1 else None


def align_grids(*das: xr.DataArray) -> List[xr.DataArray]:
    """Align all given arrays with ``xr.align(join="inner")`` and return them as a list."""
    return [*xr.align(*das, join="inner")]


def build_weights_zone_to_country(
    country_mask: xr.DataArray,
    zone_mask: xr.DataArray,
    latw: xr.DataArray,
    extra_weight: xr.DataArray,
    countries: List[str],
    zones: List[str],
) -> np.ndarray:
    """Build the unnormalised country-by-zone weight matrix.

    Entry ``[i, j]`` is the sum over all grid cells of
    ``country_mask[i] * latw * extra_weight * zone_mask[j]``, computed after
    aligning the four fields on their common coordinates.

    Args:
        country_mask: Country mask with one region dimension plus lat/lon.
        zone_mask: Zone mask with one region dimension plus lat/lon.
        latw: Grid-cell weights multiplied into every cell.
        extra_weight: Additional per-cell weight (e.g. a suitability mask).
        countries: Country labels to select, in output row order.
        zones: Zone labels to select, in output column order.

    Returns:
        ``float64`` array with one row per entry of *countries* and one column
        per entry of *zones*.

    Raises:
        RuntimeError: if a region dimension cannot be detected, or if a
            requested zone is absent from *zone_mask*.
        ValueError: if a requested country is absent from *country_mask*.
    """
    cdim0 = get_region_dim(country_mask)
    zdim0 = get_region_dim(zone_mask)
    if cdim0 is None or zdim0 is None:
        raise RuntimeError("Could not detect the region dimension in the NUTS0 or zone mask.")

    ccoord0 = [str(x) for x in country_mask[cdim0].values]
    zcoord0 = [str(x) for x in zone_mask[zdim0].values]

    # Label -> position lookups. Enumerating in reverse makes the first
    # occurrence of a duplicated label win, matching ``list.index``.
    c_pos = {label: i for i, label in reversed(list(enumerate(ccoord0)))}
    z_pos = {label: i for i, label in reversed(list(enumerate(zcoord0)))}

    missing = [z for z in zones if z not in z_pos]
    if missing:
        raise RuntimeError(
            f"Zones not found in mask {zdim0}: {missing[:10]} (total {len(missing)})"
        )

    # Same validation for countries. ValueError is kept because that is what
    # the earlier ``list.index`` lookup raised for an unknown country.
    missing_countries = [c for c in countries if c not in c_pos]
    if missing_countries:
        raise ValueError(
            f"Countries not found in mask {cdim0}: {missing_countries[:10]} "
            f"(total {len(missing_countries)})"
        )

    c_idx = [c_pos[c] for c in countries]
    z_idx = [z_pos[z] for z in zones]

    C0 = country_mask.isel({cdim0: c_idx})
    Z0 = zone_mask.isel({zdim0: z_idx})

    C = C0.rename({cdim0: "country"}).astype("float64")
    Z = Z0.rename({zdim0: "zone"}).astype("float64")

    # Inner alignment keeps only the coordinates shared by all four fields.
    C, Z, latw2, ew = align_grids(C, Z, latw, extra_weight)
    latw2 = latw2.astype("float64")
    ew = ew.astype("float64")

    lat_dim = [d for d in C.dims if d.lower().startswith("lat")][0]
    lon_dim = [d for d in C.dims if d.lower().startswith("lon")][0]

    # Weighted country field, then the full (country, zone, lat, lon) product.
    # NOTE(review): this product is materialised in memory before the sum, and
    # xarray's default ``sum`` is expected to skip NaN cells -- confirm both
    # are acceptable for the mask sizes in use.
    CW = C * latw2 * ew
    prod = CW.expand_dims(zone=Z["zone"]) * Z.expand_dims(country=C["country"])

    W = prod.sum(dim=(lat_dim, lon_dim)).transpose("country", "zone").values
    return np.asarray(W, dtype=np.float64)


def normalize_rows(W: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Scale each row of *W* so that it sums to one.

    Args:
        W: Two-dimensional weight matrix (any numeric dtype).

    Returns:
        ``(Wn, valid)`` where ``Wn`` is a new ``float64`` array of the same
        shape and ``valid`` is a boolean vector marking rows whose sum is
        strictly positive. Rows that are not valid are left as all zeros.
    """
    # Work in float64 explicitly: with an integer input, a result buffer that
    # inherits the input dtype would truncate every fractional weight to zero.
    weights = np.asarray(W, dtype=np.float64)
    row_sum = weights.sum(axis=1)
    valid = row_sum > 0

    Wn = np.zeros_like(weights)
    Wn[valid] = weights[valid] / row_sum[valid, None]
    return Wn, valid


def list_zone_txt_groups(indirs: List[Path]) -> Dict[Tuple[str, str], List[str]]:
    """Scan *indirs* for zone TXT files and group zone codes by (VAR, ZONESET).

    Only ``*.txt`` files whose name matches ``F_TXT`` are considered. VAR,
    ZONESET and ZONE are upper-cased; an optional ON/OFF suffix is accepted by
    the pattern but plays no role in the grouping. ``ONLY_VARCODE`` and
    ``ONLY_ZONESET`` restrict the result when they are not ``None``.

    Returns:
        Dictionary ordered by key, mapping ``(var, zoneset)`` to the sorted
        list of distinct zone codes.
    """
    found: Dict[Tuple[str, str], set] = {}
    for folder in indirs:
        for txt_path in Path(folder).glob("*.txt"):
            match = F_TXT.match(txt_path.name)
            if match is None:
                continue

            var, zoneset, zone = (
                match.group(part).upper() for part in ("var", "zoneset", "zone")
            )
            var_ok = ONLY_VARCODE is None or var in ONLY_VARCODE
            zoneset_ok = ONLY_ZONESET is None or zoneset in ONLY_ZONESET
            if var_ok and zoneset_ok:
                found.setdefault((var, zoneset), set()).add(zone)

    return {key: sorted(found[key]) for key in sorted(found)}


def find_zone_file(indirs: List[Path], var: str, zoneset: str, zone: str) -> Path:
    """Locate ``<var>_<zoneset>_<zone>.txt`` in the first directory that has it.

    Args:
        indirs: Directories to search, in priority order.
        var: Variable code used in the file name.
        zoneset: Zoneset code used in the file name.
        zone: Zone code used in the file name.

    Returns:
        Path of the first existing match.

    Raises:
        FileNotFoundError: if no directory contains the file.
    """
    # NOTE(review): only the un-suffixed name is searched. File names carrying
    # the optional _ON/_OFF suffix accepted by F_TXT are not looked up here --
    # confirm whether such files are expected as input.
    fname = f"{var}_{zoneset}_{zone}.txt"

    for d in indirs:
        p = Path(d) / fname
        if p.exists():
            return p

    raise FileNotFoundError(f"Could not find {fname} in any INPUT_TXT_DIRS")


def read_zone_series(input_dirs: List[Path], var: str, zoneset: str, zone: str) -> pd.Series:
    """Read one zone TXT file into a float64 Series indexed by its date strings.

    The file is located with ``find_zone_file`` and parsed as a header-less,
    two-column table (date, value) separated by ``TXT_SEP``. Dates are kept as
    strings; the index is named ``"date"``.
    """
    source = find_zone_file(input_dirs, var, zoneset, zone)
    table = pd.read_csv(
        source,
        sep=TXT_SEP,
        header=None,
        names=["date", "value"],
        dtype={"date": str},
    )
    values = table["value"].to_numpy(dtype=np.float64)
    labels = table["date"].to_numpy()

    series = pd.Series(values, index=labels)
    series.index.name = "date"
    return series


def build_X_matrix(input_dirs: List[Path], var: str, zoneset: str, zones: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble the (time x zone) data matrix for one variable/zoneset group.

    Each zone's series becomes one column, in the order given by *zones*.
    Only dates present in every zone are kept (inner join), sorted ascending.

    Returns:
        ``(dates, X)``: the date labels as a string array and the values as a
        float64 matrix with one column per zone.
    """
    columns = [
        read_zone_series(input_dirs, var, zoneset, zone).rename(zone)
        for zone in zones
    ]
    table = pd.concat(columns, axis=1, join="inner").sort_index()
    return table.index.to_numpy(dtype=str), table.to_numpy(dtype=np.float64)


def write_country_series(output_dir: Path, fname: str, dates: np.ndarray, values: np.ndarray) -> None:
    """Write one ``date<TXT_SEP>value`` line per time step to ``output_dir/fname``.

    The output directory is created if needed and an existing file is
    overwritten. *dates* and *values* are paired positionally.
    """
    ensure_dir(output_dir)
    target = output_dir / fname
    with target.open("w", encoding="utf-8") as handle:
        handle.writelines(
            f"{stamp}{TXT_SEP}{number}\n" for stamp, number in zip(dates, values)
        )


def aggregate_group_to_nuts0(
    input_dirs: List[Path],
    output_dir: Path,
    var: str,
    zoneset: str,
    zones: List[str],
    countries: List[str],
    W: np.ndarray,
    prefix: Optional[str] = None,
) -> None:
    """Aggregate one (var, zoneset) group to countries and write one file per country.

    The rows of *W* (country x zone) are normalised to sum to one, and the
    zone matrix is multiplied by the transposed result. Countries whose weight
    row has no positive sum are skipped. Files are named
    ``<prefix>_NUTS0_<country>.txt``, or ``<var>_<zoneset>_NUTS0_<country>.txt``
    when *prefix* is ``None``.
    """
    Wn, valid = normalize_rows(W)
    dates, X = build_X_matrix(input_dirs, var, zoneset, zones)
    Y = X @ Wn.T

    stem = prefix if prefix is not None else f"{var}_{zoneset}"
    for col, country in enumerate(countries):
        if valid[col]:
            write_country_series(output_dir, f"{stem}_NUTS0_{country}.txt", dates, Y[:, col])

def aggregate_wind_on_off_total(
    input_dirs: List[Path],
    outdir_won: Path,
    outdir_wof: Path,
    outdir_wto: Path,
    countries: List[str],
    zones_on: List[str],
    zones_off: List[str],
    W_on: np.ndarray,
    W_off: np.ndarray,
) -> None:
    """Aggregate onshore, offshore and combined wind to countries and write them.

    Onshore (WON/P2ON) and offshore (WOF/P2OF) zone matrices are placed on the
    union of their dates, with missing values filled with 0.0. Three products
    are then computed per country:

    - onshore: zone values weighted by the row-normalised *W_on*;
    - offshore: zone values weighted by the row-normalised *W_off*;
    - total: onshore and offshore zones concatenated and weighted by the
      row-normalised concatenation of *W_on* and *W_off*.

    A product is written for a country only when its weight row has a
    positive sum.
    """
    dates_on, X_on = build_X_matrix(input_dirs, "WON", "P2ON", zones_on)
    dates_off, X_off = build_X_matrix(input_dirs, "WOF", "P2OF", zones_off)

    index_on = pd.Index(dates_on, name="date")
    index_off = pd.Index(dates_off, name="date")
    timeline = index_on.union(index_off).sort_values()

    def on_timeline(values: np.ndarray, index: pd.Index, columns: List[str]) -> np.ndarray:
        # Re-index onto the common timeline; gaps become 0.0.
        frame = pd.DataFrame(values, index=index, columns=columns)
        return frame.reindex(timeline).fillna(0.0).to_numpy(dtype=np.float64)

    on_values = on_timeline(X_on, index_on, zones_on)
    off_values = on_timeline(X_off, index_off, zones_off)
    dates = timeline.to_numpy(dtype=str)

    Wn_on, valid_on = normalize_rows(W_on)
    Wn_off, valid_off = normalize_rows(W_off)
    Wn_tot, valid_tot = normalize_rows(np.concatenate([W_on, W_off], axis=1))

    Y_on = on_values @ Wn_on.T
    Y_off = off_values @ Wn_off.T
    Y_tot = np.concatenate([on_values, off_values], axis=1) @ Wn_tot.T

    products = (
        (valid_on, outdir_won, "WON", Y_on),
        (valid_off, outdir_wof, "WOF", Y_off),
        (valid_tot, outdir_wto, WIND_TOTAL_TAG, Y_tot),
    )
    for col, country in enumerate(countries):
        for valid, outdir, tag, Y in products:
            if valid[col]:
                write_country_series(outdir, f"{tag}_NUTS0_{country}.txt", dates, Y[:, col])


def main() -> None:
    """Aggregate every detected (VAR, ZONESET) group to NUTS0 and write the results.

    Steps:
      1. Check the input directories and discover the available groups.
      2. Load the country/zone masks and the weighting fields.
      3. Aggregate each group into its own output folder.
      4. When both WON_P2ON and WOF_P2OF are present, aggregate them together
         to produce onshore, offshore and combined wind on a common timeline.

    Raises:
        SystemExit: if an input directory is missing or no input file matches.
        RuntimeError: if the country dimension cannot be detected or a group
            uses an unsupported zoneset.
    """
    missing_dirs = [str(d) for d in INPUT_TXT_DIRS if not Path(d).exists()]
    if missing_dirs:
        raise SystemExit("These INPUT_TXT_DIRS do not exist:\n  - " + "\n  - ".join(missing_dirs))

    groups = list_zone_txt_groups(INPUT_TXT_DIRS)
    if not groups:
        raise SystemExit("No *_*_*.txt files found in any INPUT_TXT_DIRS")

    print(f"Detected groups (VAR_ZONESET): {len(groups)}")
    for k, z in groups.items():
        print(f"  - {k[0]}_{k[1]}: {len(z)} zones")

    print("\nLoading masks/weights...")

    nuts0 = load_da(NC_NUTS0_MASK, varname="mask")
    p2on  = load_da(NC_P2ON_MASK,  varname="mask")
    p2of  = load_da(NC_P2OF_MASK,  varname="mask")
    szon  = load_da(NC_SZON_MASK,  varname="mask")

    latw  = load_da(NC_LAT_WEIGHTS)
    pop   = load_da(NC_POP_DENSITY)
    pv_mask = load_da(NC_PV_MASK)
    wind_mask_raw = load_da(NC_WIND_MASK)
    wind_elig = (1.0 - wind_mask_raw) if WIND_MASK_INVERT else wind_mask_raw

    cdim = get_region_dim(nuts0)
    if cdim is None:
        raise RuntimeError("Could not detect the country dimension in the NUTS0 mask.")
    countries = [str(x) for x in nuts0[cdim].values]

    # One-dimensional latitude weights are broadcast along the PV mask's
    # longitude so they can be multiplied with the 2-D fields.
    if len(latw.dims) == 1:
        lon_dim = [d for d in pv_mask.dims if d.lower().startswith("lon")][0]
        latw = latw.expand_dims({lon_dim: pv_mask[lon_dim]})

    have_on = ("WON", "P2ON") in groups
    have_off = ("WOF", "P2OF") in groups
    # The dedicated wind block needs both inputs. When only one is available
    # it must go through the generic loop, otherwise it would never be written.
    combine_wind = have_on and have_off

    # --- Aggregate individual variables, each into its own output folder. ---
    for (var, zoneset), zones in groups.items():

        # Avoid duplicates: WON and WOF are handled in the dedicated wind block.
        if combine_wind and (var, zoneset) in {("WON", "P2ON"), ("WOF", "P2OF")}:
            print(f"\n== {var}_{zoneset} -> skipped; it will be handled in the wind block")
            continue

        print(f"\n== Aggregating {var}_{zoneset} -> NUTS0")

        if zoneset == "P2ON":
            zone_mask = p2on
            if var == "SPV":
                extra = pv_mask
            elif var == "WON":
                # Same eligibility weighting as the combined wind block.
                extra = wind_elig
            else:
                extra = xr.ones_like(pv_mask)

        elif zoneset == "P2OF":
            zone_mask = p2of
            extra = xr.ones_like(pv_mask)

        elif zoneset == "SZON":
            zone_mask = szon
            extra = pop

        else:
            raise RuntimeError(f"Unsupported ZONESET: {zoneset}")

        outdir = outdir_for(var)

        print("   - Building weights...")
        W = build_weights_zone_to_country(
            country_mask=nuts0,
            zone_mask=zone_mask,
            latw=latw,
            extra_weight=extra,
            countries=countries,
            zones=zones,
        )

        print("   - Aggregating time series...")
        aggregate_group_to_nuts0(
            input_dirs=INPUT_TXT_DIRS,
            output_dir=outdir,
            var=var,
            zoneset=zoneset,
            zones=zones,
            countries=countries,
            W=W,
            prefix=var,   # Force names such as SPV_NUTS0_ES.txt.
        )

        print(f"   -> OK ({outdir})")

    if combine_wind:
        zones_on = groups[("WON", "P2ON")]
        zones_off = groups[("WOF", "P2OF")]

        outdir_won = outdir_for("WON")
        outdir_wof = outdir_for("WOF")
        outdir_wto = outdir_for(WIND_TOTAL_TAG)

        print(f"\n== Aggregating combined wind -> {outdir_wto.name}")

        print("   - Onshore weights (P2ON) with WPM eligibility...")
        W_on = build_weights_zone_to_country(
            country_mask=nuts0,
            zone_mask=p2on,
            latw=latw,
            extra_weight=wind_elig,
            countries=countries,
            zones=zones_on,
        )

        print("   - Offshore weights (P2OF) without restriction mask...")
        W_off = build_weights_zone_to_country(
            country_mask=nuts0,
            zone_mask=p2of,
            latw=latw,
            extra_weight=xr.ones_like(pv_mask),
            countries=countries,
            zones=zones_off,
        )

        print("   - Aggregating and writing onshore/offshore/total wind...")
        aggregate_wind_on_off_total(
            input_dirs=INPUT_TXT_DIRS,
            outdir_won=outdir_won,
            outdir_wof=outdir_wof,
            outdir_wto=outdir_wto,
            countries=countries,
            zones_on=zones_on,
            zones_off=zones_off,
            W_on=W_on,
            W_off=W_off,
        )

        print("   -> OK (combined wind)")
    else:
        print("\nCombined wind was not generated because WON_P2ON or WOF_P2OF TXT files are missing.")

    print("\nDone.")
    print("Variable-specific outputs written under:", OUTPUT_ROOT)


# Run the aggregation only when this file is executed as a script.
if __name__ == "__main__":
    main()

