#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
optimize_solar_wind_mix.py

Find the solar–wind mix that maximizes an IDF-based objective function for
renewable-energy droughts.

For each country and target severity rank, the script scans a user-defined
solar-share grid and computes:

    J(S) = sum_{D=dmin..dmax} I(D, rank=target_severity)

where I(D, rank) is the drought intensity obtained from the normalized daily
mixed capacity-factor series using the same event-selection algorithm as the
IDF-curve calculation:

  - D-day rolling mean
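  - event date assigned using shift(-D // 2), i.e. Python floor division of -D,
    which moves each rolling mean ceil(D / 2) rows towards earlier dates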
  - iterative selection of the N lowest non-overlapping events

The mixed series is defined as:

    MIX = S * SPV + W * WTO,     W = 1 - S

where SPV is solar photovoltaic capacity factor and WTO is total wind capacity
factor, including onshore and offshore wind. The mixed series is normalized by
its long-term mean before the drought analysis.

The optimal solution is defined as the solar share that maximizes J(S),
thereby maximizing drought intensity values (i.e. reducing the severity
of renewable-energy droughts) over the selected duration range.

Expected input files
--------------------
For each country code, the input directory must contain:

    SPV-NUTS0-<period_tag>/DAILY_SPV_NUTS0_<COUNTRY>.txt
    WTO-NUTS0-<period_tag>/DAILY_WTO_NUTS0_<COUNTRY>.txt

Each file must be a tab-separated daily time series without header:

    YYYY-MM-DD    value

Output files
------------
For each country, severity rank and optimized duration range, the script writes
the following files to <output_root>/<COUNTRY>/Severity<rank>/D<dmin>-<dmax>/:

    optimal_mix_summary_<COUNTRY>.txt
    scan_objective_<COUNTRY>.txt
    DAILY_MIX_NUTS0_<COUNTRY>.txt
    DAILY_NORMMEAN_MIX_NUTS0_<COUNTRY>.txt
    RANKING_MIX_NUTS0_<COUNTRY>.txt
    
Example
-------
python optimize_solar_wind_mix.py \
    --input-root data \
    --output-root results \
    --countries ES,FR,IT \
    --severities 1,2,3,4,5,6,7,8,9,10 \
    --dmin 1 \
    --dmax 10  
"""

from __future__ import annotations

import argparse
from pathlib import Path
from datetime import timedelta
import numpy as np
import pandas as pd


TXT_SEP = "\t"


def parse_int_list(text: str) -> list[int]:
    """Convert a comma-separated string into a list of integers.

    Whitespace around each item is stripped and items that are empty after
    stripping (for example from a trailing comma) are skipped. ``int()``
    raises ``ValueError`` for any remaining item that is not an integer.
    """
    numbers: list[int] = []
    for raw_item in text.split(","):
        item = raw_item.strip()
        if item:
            numbers.append(int(item))
    return numbers


def parse_str_list(text: str) -> list[str]:
    """Split a comma-separated string into its non-empty, stripped items."""
    stripped_items = (piece.strip() for piece in text.split(","))
    # filter(None, ...) drops the items that are empty after stripping.
    return list(filter(None, stripped_items))


def read_daily_series(path: Path) -> pd.Series:
    """Load a header-less, two-column daily time series file.

    The first column is parsed strictly as ``YYYY-MM-DD`` dates (a malformed
    date raises), the second is converted to ``float64``. The returned series
    is sorted by date and its index is named ``"date"``.
    """
    table = pd.read_csv(
        path,
        sep=TXT_SEP,
        header=None,
        names=["date", "value"],
        # Keep the dates as text so that the explicit format below is enforced.
        dtype={"date": str},
    )

    timestamps = pd.to_datetime(table["date"], format="%Y-%m-%d", errors="raise")
    values = table["value"].to_numpy(dtype=np.float64)

    return pd.Series(values, index=timestamps).sort_index().rename_axis("date")


def find_input_files(input_root: Path, country: str, period_tag: str) -> tuple[Path, Path]:
    """Locate the daily SPV and WTO input files of one country.

    The files are expected at
    ``<input_root>/<TECH>-NUTS0-<period_tag>/DAILY_<TECH>_NUTS0_<country>.txt``
    for ``TECH`` in ``SPV`` and ``WTO``.

    Returns:
        The ``(spv_file, wto_file)`` paths.

    Raises:
        FileNotFoundError: if a file does not exist; the SPV file is checked
            before the WTO file.
    """
    located: list[Path] = []

    for technology in ("SPV", "WTO"):
        folder = input_root / f"{technology}-NUTS0-{period_tag}"
        candidate = folder / f"DAILY_{technology}_NUTS0_{country}.txt"
        if not candidate.exists():
            raise FileNotFoundError(f"Missing input file: {candidate}")
        located.append(candidate)

    spv_file, wto_file = located
    return spv_file, wto_file


def ranking_for_duration(
    series_norm: pd.Series,
    duration: int,
    n_events: int,
) -> pd.DataFrame:
    """Rank the lowest ``duration``-row rolling-mean events of a series.

    The series is sorted by its index, averaged over a rolling window of
    ``duration`` rows and shifted by ``-duration // 2`` rows. The lowest
    remaining value is then selected repeatedly (at most ``n_events`` times);
    after each selection a ``duration``-day block of candidate dates around the
    selected date is removed.

    Returns:
        One row per selected event with the columns ``duration``, ``value``,
        ``date`` (ISO string), ``severity`` (1 = lowest value) and
        ``return_period`` computed as ``(n_years + 1) / severity``. The frame
        has no columns when no event could be selected.
    """
    ordered = series_norm.sort_index()

    # Record length in whole years, used for the return period below.
    record_span = ordered.index.max() - ordered.index.min()
    n_years = round(record_span.days / 365.25)

    window_mean = ordered.rolling(window=duration, min_periods=duration).mean()
    # NOTE(review): ``-duration // 2`` parses as ``(-duration) // 2``, i.e. a
    # shift of ceil(duration / 2) rows, whereas the masking below uses
    # ``duration // 2`` days. Kept exactly as is because the module docstring
    # describes this as the reference IDF event selection — confirm the
    # asymmetry for odd durations is intended.
    # NOTE(review): ``shift`` moves by rows, the masking by calendar days;
    # presumably the index is a gap-free daily series — verify.
    candidates = window_mean.shift(-duration // 2)

    half_window = timedelta(days=duration // 2)
    block_length = timedelta(days=duration - 1)
    records: list[dict] = []

    for rank in range(1, n_events + 1):
        # Stop as soon as no candidate date is left.
        if candidates.count() == 0:
            break

        event_date = candidates.idxmin()
        event_value = candidates.loc[event_date]
        if pd.isna(event_value):
            break

        records.append(
            {
                "duration": duration,
                "value": float(event_value),
                "date": event_date.date().isoformat(),
                "severity": rank,
                "return_period": (n_years + 1) / rank,
            }
        )

        # Remove the block of dates around the selected event from the pool.
        block_start = event_date - half_window
        candidates.loc[block_start : block_start + block_length] = np.nan

    return pd.DataFrame(records)


def full_ranking(
    series_norm: pd.Series,
    dmin: int,
    dmax: int,
    n_events: int,
) -> pd.DataFrame:
    """Stack the event rankings of every duration from ``dmin`` to ``dmax``.

    ``ranking_for_duration`` is called once per duration (both bounds
    included). The per-duration tables are concatenated in increasing duration
    order under a fresh integer index, and ``return_period`` is rounded to one
    decimal.
    """
    per_duration: list[pd.DataFrame] = []
    for n_days in range(dmin, dmax + 1):
        per_duration.append(ranking_for_duration(series_norm, n_days, n_events))

    stacked = pd.concat(per_duration, axis=0, ignore_index=True)
    return stacked.assign(return_period=stacked["return_period"].round(1))


def objective_from_ranking(
    ranking: pd.DataFrame,
    target_severity: int,
    dmin: int,
    dmax: int,
) -> float:
    """Sum the ranked values of one severity over a duration range.

    Only rows with ``severity == target_severity`` and
    ``dmin <= duration <= dmax`` contribute to the sum of the ``value`` column.

    Raises:
        RuntimeError: if the number of matching rows differs from
            ``dmax - dmin + 1``, i.e. one row per duration.
    """
    has_target_rank = ranking["severity"].eq(target_severity)
    in_duration_range = ranking["duration"].between(dmin, dmax)
    selected = ranking.loc[has_target_rank & in_duration_range]

    n_found = len(selected)
    n_expected = dmax - dmin + 1
    if n_found != n_expected:
        raise RuntimeError(
            f"Missing IDF values for severity={target_severity} "
            f"between D={dmin} and D={dmax}: found {n_found}, expected {n_expected}."
        )

    return float(selected["value"].sum())


def build_mix(
    spv: pd.Series,
    wto: pd.Series,
    solar_share: float,
) -> tuple[pd.Series, pd.Series]:
    """Blend the SPV and WTO series on their common dates.

    The mix is ``solar_share * spv + (1 - solar_share) * wto``, restricted to
    the index values present in both inputs.

    Returns:
        ``(mix, mix_norm)`` where ``mix_norm`` is ``mix`` divided by its
        NaN-ignoring mean.

    Raises:
        RuntimeError: if the inputs share no dates, or if the mean of the mix
            is zero or not finite.
    """
    shared_dates = spv.index.intersection(wto.index)
    if shared_dates.empty:
        raise RuntimeError("SPV and WTO series have no common dates.")

    wind_share = 1.0 - solar_share
    solar_part = solar_share * spv.loc[shared_dates]
    wind_part = wind_share * wto.loc[shared_dates]
    mix = solar_part + wind_part

    mean_cf = float(np.nanmean(mix.to_numpy(dtype=np.float64)))
    usable_mean = np.isfinite(mean_cf) and mean_cf != 0.0
    if not usable_mean:
        raise RuntimeError("The mean mixed CF is not finite or is zero.")

    return mix, mix / mean_cf


def write_series(path: Path, series: pd.Series) -> None:
    """Write a date-indexed series as ``YYYY-MM-DD<sep>value`` lines.

    Missing parent directories are created first. Entries whose value is NaN
    are left out of the file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    with path.open("w", encoding="utf-8") as handle:
        # The generator is consumed while the file is open, so lines are
        # produced and written one after the other.
        handle.writelines(
            f"{stamp.strftime('%Y-%m-%d')}{TXT_SEP}{number}\n"
            for stamp, number in series.items()
            if not pd.isna(number)
        )


def write_ranking(path: Path, ranking: pd.DataFrame) -> None:
    """Write a ranking table as a tab-separated text file with a header line.

    Missing parent directories are created first. ``duration`` and
    ``severity`` are written as integers, ``value`` with six decimals and
    ``return_period`` with one decimal; ``date`` is written as stored.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    with path.open("w", encoding="utf-8") as handle:
        handle.write("duration\tvalue\tdate\tseverity\treturn_period\n")
        handle.writelines(
            f"{int(event.duration)}\t{event.value:.6f}\t{event.date}\t"
            f"{int(event.severity)}\t{event.return_period:.1f}\n"
            for event in ranking.itertuples(index=False)
        )


def optimize_country_severity(
    input_root: Path,
    output_root: Path,
    country: str,
    target_severity: int,
    period_tag: str,
    solar_min: float,
    solar_max: float,
    solar_step: float,
    dmin: int,
    dmax: int,
    full_dmin: int,
    full_dmax: int,
    n_output_events: int,
) -> None:
    """Scan the solar-share grid for one country and severity, then write results.

    For every solar share on the grid the SPV/WTO mix is built and the
    objective (sum of the rank-``target_severity`` values over durations
    ``dmin..dmax``) is evaluated. The share with the highest objective is
    kept; on ties the smallest share wins because only a strictly larger
    objective replaces the current best.

    Files written to ``<output_root>/<country>/Severity<target_severity>/D<dmin>-<dmax>/``:
    the summary, the objective scan, the raw and mean-normalized optimal mix
    and the ranking of the optimal mix for durations ``full_dmin..full_dmax``
    (ranks up to ``n_output_events``).

    Args:
        input_root: Directory holding the SPV/WTO input sub-directories.
        output_root: Root directory of the outputs.
        country: Country code used in file and directory names.
        target_severity: Event rank entering the objective.
        period_tag: Period tag of the input sub-directory names.
        solar_min: First solar share of the grid.
        solar_max: Upper bound of the grid; no grid point exceeds it.
        solar_step: Grid spacing, must be positive.
        dmin: Shortest duration in the objective.
        dmax: Longest duration in the objective.
        full_dmin: Shortest duration in the written ranking.
        full_dmax: Longest duration in the written ranking.
        n_output_events: Number of ranks kept per duration in the written ranking.

    Raises:
        ValueError: if ``solar_step`` is not positive.
        FileNotFoundError: if an input file is missing.
        RuntimeError: if the grid is empty or the objective cannot be evaluated.
    """
    if solar_step <= 0:
        raise ValueError("solar_step must be > 0.")

    spv_file, wto_file = find_input_files(input_root, country, period_tag)

    print(f"\n=== {country} | Severity {target_severity} | D={dmin}-{dmax} ===")
    spv = read_daily_series(spv_file)
    wto = read_daily_series(wto_file)

    # Create the output directory only once the inputs are known to be usable,
    # so a missing or malformed input does not leave empty directories behind.
    output_dir = output_root / country / f"Severity{target_severity}" / f"D{dmin}-{dmax}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # The half-step padding keeps solar_max on the grid despite float rounding,
    # but it lets np.arange step past solar_max when the range is not a
    # multiple of the step (e.g. 0..1 with step 0.6 would yield 1.2, i.e. a
    # negative wind share). Drop such points and cap rounding residue.
    solar_values = np.arange(solar_min, solar_max + 0.5 * solar_step, solar_step)
    solar_values = np.minimum(solar_values[solar_values <= solar_max + 1e-9], solar_max)

    scan_results: list[tuple[float, float, float]] = []
    best: dict | None = None

    # The objective only reads the rank-`target_severity` event. The selection
    # is sequential, so ranking further events cannot change that value.
    n_events_for_optimization = max(target_severity, 1)

    for solar_share in solar_values:
        solar_share = float(solar_share)
        wind_share = 1.0 - solar_share

        mix, mix_norm = build_mix(spv, wto, solar_share)
        ranking_opt = full_ranking(mix_norm, dmin, dmax, n_events_for_optimization)
        objective = objective_from_ranking(ranking_opt, target_severity, dmin, dmax)

        scan_results.append((solar_share, wind_share, objective))

        if best is None or objective > best["objective"]:
            # build_mix returns fresh objects on every call, so no copy is needed.
            best = {
                "S": solar_share,
                "W": wind_share,
                "objective": float(objective),
                "mix": mix,
                "mix_norm": mix_norm,
            }

    if best is None:
        raise RuntimeError("No optimal solution was found.")

    ranking_full = full_ranking(
        best["mix_norm"],
        full_dmin,
        full_dmax,
        max(target_severity, n_output_events),
    )
    ranking_out = ranking_full[ranking_full["severity"] <= n_output_events].copy()

    summary_lines = [
        f"country{TXT_SEP}{country}\n",
        f"target_severity{TXT_SEP}{target_severity}\n",
        f"duration_min{TXT_SEP}{dmin}\n",
        f"duration_max{TXT_SEP}{dmax}\n",
        f"n_output_events{TXT_SEP}{n_output_events}\n",
        f"S_opt{TXT_SEP}{best['S']:.6f}\n",
        f"W_opt{TXT_SEP}{best['W']:.6f}\n",
        f"objective_sum{TXT_SEP}{best['objective']:.6f}\n",
        f"s_grid_min{TXT_SEP}{solar_min}\n",
        f"s_grid_max{TXT_SEP}{solar_max}\n",
        f"s_grid_step{TXT_SEP}{solar_step}\n",
    ]
    summary_file = output_dir / f"optimal_mix_summary_{country}.txt"
    with summary_file.open("w", encoding="utf-8") as file:
        file.writelines(summary_lines)

    scan_file = output_dir / f"scan_objective_{country}.txt"
    pd.DataFrame(
        scan_results,
        columns=["S", "W", "objective_sum"],
    ).to_csv(scan_file, sep=TXT_SEP, index=False, float_format="%.6f")

    write_series(output_dir / f"DAILY_MIX_NUTS0_{country}.txt", best["mix"])
    write_series(output_dir / f"DAILY_NORMMEAN_MIX_NUTS0_{country}.txt", best["mix_norm"])
    write_ranking(output_dir / f"RANKING_MIX_NUTS0_{country}.txt", ranking_out)

    print(
        f"Optimal mix: S={best['S']:.4f}, W={best['W']:.4f}, "
        f"objective={best['objective']:.6f}"
    )
    print(f"Output directory: {output_dir}")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    ``--input-root``, ``--output-root``, ``--countries``, ``--dmin`` and
    ``--dmax`` are required; every other option has a default. The list
    options (``--countries``, ``--severities``) are returned as the raw
    comma-separated strings.
    """
    cli = argparse.ArgumentParser(
        description="Optimize solar–wind shares using an IDF-based drought objective."
    )

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated help lists the options in the same sequence.
    option_table = [
        (
            "--input-root",
            {
                "required": True,
                "type": Path,
                "help": "Root directory containing SPV-NUTS0-<period_tag> and WTO-NUTS0-<period_tag> subdirectories.",
            },
        ),
        (
            "--output-root",
            {
                "required": True,
                "type": Path,
                "help": "Root directory where optimization outputs will be written.",
            },
        ),
        (
            "--countries",
            {
                "required": True,
                "help": "Comma-separated list of country codes, e.g. ES,FR,IT.",
            },
        ),
        (
            "--severities",
            {
                "default": "1,2,3,4,5,6,7,8,9,10",
                "help": "Comma-separated list of target severity ranks.",
            },
        ),
        (
            "--period-tag",
            {
                "default": "1950-2025",
                "help": "Period tag used in input subdirectory names.",
            },
        ),
        ("--s-min", {"type": float, "default": 0.0, "help": "Minimum solar share."}),
        ("--s-max", {"type": float, "default": 1.0, "help": "Maximum solar share."}),
        ("--s-step", {"type": float, "default": 0.1, "help": "Solar-share grid step."}),
        (
            "--dmin",
            {
                "type": int,
                "required": True,
                "help": "Minimum duration used in the optimization objective.",
            },
        ),
        (
            "--dmax",
            {
                "type": int,
                "required": True,
                "help": "Maximum duration used in the optimization objective.",
            },
        ),
        (
            "--full-dmin",
            {
                "type": int,
                "default": 1,
                "help": "Minimum duration written to the final full IDF ranking.",
            },
        ),
        (
            "--full-dmax",
            {
                "type": int,
                "default": 90,
                "help": "Maximum duration written to the final full IDF ranking.",
            },
        ),
        (
            "--n-output-events",
            {
                "type": int,
                "default": 10,
                "help": "Number of ranked drought events retained for each duration in the output.",
            },
        ),
    ]

    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def main() -> None:
    """Validate the command-line options and run every country/severity pair.

    All options are checked before any work starts, so an invalid request
    fails immediately instead of after part of the scan has already run.

    Raises:
        ValueError: if a duration bound, the solar-share grid or the
            country/severity lists are invalid.
    """
    args = parse_args()

    if args.dmin > args.dmax:
        raise ValueError("dmin must be lower than or equal to dmax.")
    if args.dmin < 1:
        raise ValueError("dmin must be >= 1.")
    if args.full_dmin > args.full_dmax:
        raise ValueError("full-dmin must be lower than or equal to full-dmax.")
    if args.full_dmin < 1:
        # A rolling window needs at least one day, as for dmin.
        raise ValueError("full-dmin must be >= 1.")
    if args.s_step <= 0:
        raise ValueError("s-step must be > 0.")
    if args.s_min > args.s_max:
        # An inverted range would produce an empty solar-share grid.
        raise ValueError("s-min must be lower than or equal to s-max.")
    if not (0.0 <= args.s_min and args.s_max <= 1.0):
        # The wind share is 1 - S, so shares outside [0, 1] give negative weights.
        raise ValueError("s-min and s-max must lie within [0, 1].")

    countries = parse_str_list(args.countries)
    severities = parse_int_list(args.severities)

    if not countries:
        raise ValueError("countries must contain at least one country code.")
    if not severities:
        raise ValueError("severities must contain at least one rank.")
    if any(severity < 1 for severity in severities):
        # Ranks start at 1; a lower rank can never be found in the ranking.
        raise ValueError("severities must all be >= 1.")

    for country in countries:
        for severity in severities:
            optimize_country_severity(
                input_root=args.input_root,
                output_root=args.output_root,
                country=country,
                target_severity=severity,
                period_tag=args.period_tag,
                solar_min=args.s_min,
                solar_max=args.s_max,
                solar_step=args.s_step,
                dmin=args.dmin,
                dmax=args.dmax,
                full_dmin=args.full_dmin,
                full_dmax=args.full_dmax,
                n_output_events=args.n_output_events,
            )


# Run the command-line workflow only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

