Source code for workflow.scripts.solve_model

# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import math
from collections import defaultdict

import numpy as np
import pandas as pd
import pypsa
import xarray as xr

try:  # Used for type annotations / documentation; fallback when unavailable
    import linopy  # type: ignore
except Exception:  # pragma: no cover - documentation build without linopy

    class linopy:  # type: ignore
        class Model:  # Minimal stub to satisfy type checkers and autodoc
            pass


from pypsa._options import options
from pypsa.optimization.optimize import _optimize_guard


# Names of auxiliary (non-component) variables added per linopy model, keyed
# by id(model); these are hidden from pypsa while the solution is assigned.
HEALTH_AUX_MAP: dict[int, set[str]] = {}


def _register_health_variable(m: "linopy.Model", name: str) -> None:
    aux = HEALTH_AUX_MAP.setdefault(id(m), set())
    aux.add(name)


_SOS2_COUNTER = [0]  # Use list for mutable counter


def _add_sos2_with_fallback(m, variable, sos_dim: str, solver_name: str) -> list[str]:
    """Add SOS2 or binary fallback depending on solver support."""

    if solver_name.lower() != "highs":
        m.add_sos_constraints(variable, sos_type=2, sos_dim=sos_dim)
        return []

    coords = variable.coords[sos_dim]
    n_points = len(coords)
    if n_points <= 1:
        return []

    other_dims = [dim for dim in variable.dims if dim != sos_dim]

    interval_dim = f"{sos_dim}_interval"
    suffix = 1
    while interval_dim in variable.dims:
        interval_dim = f"{sos_dim}_interval{suffix}"
        suffix += 1

    interval_index = pd.Index(range(n_points - 1), name=interval_dim)
    binary_coords = [variable.coords[d] for d in other_dims] + [interval_index]

    # Use counter instead of checking existing variables
    _SOS2_COUNTER[0] += 1
    base_name = f"{variable.name}_segment" if variable.name else "health_segment"
    binary_name = f"{base_name}_{_SOS2_COUNTER[0]}"

    binaries = m.add_variables(coords=binary_coords, binary=True, name=binary_name)

    m.add_constraints(binaries.sum(interval_dim) == 1)

    # Vectorize SOS2 constraints: variable[i] <= binary[i-1] + binary[i]
    if n_points >= 2:
        # Build adjacency matrix using numpy indexing
        adjacency_data = np.zeros((n_points, n_points - 1))

        # Fill adjacency matrix using vectorized operations
        # binary[i] affects variable[i] and variable[i+1]
        indices = np.arange(n_points - 1)
        adjacency_data[indices, indices] = 1  # binary[i] -> variable[i]
        adjacency_data[indices + 1, indices] = 1  # binary[i] -> variable[i+1]

        # Convert to DataArray with proper coordinates
        adjacency = xr.DataArray(
            adjacency_data,
            coords={sos_dim: coords, interval_dim: range(n_points - 1)},
            dims=[sos_dim, interval_dim],
        )

        # Create constraint: variables <= adjacency @ binaries
        rhs = (adjacency * binaries).sum(interval_dim)
        m.add_constraints(variable <= rhs)

    return [binary_name]
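
# Illustration (not part of the model): with a hypothetical axis of
# n_points = 4 breakpoints, the fallback creates three interval binaries
# b0..b2 with b0 + b1 + b2 == 1 and couples them to the interpolation
# weights as
#
#     lambda_0 <= b0
#     lambda_1 <= b0 + b1
#     lambda_2 <= b1 + b2
#     lambda_3 <= b2
#
# so at most two *adjacent* lambdas can be nonzero, which is exactly the
# SOS2 condition. The corresponding adjacency matrix built above would be
#
#     np.array([[1, 0, 0],
#               [1, 1, 0],
#               [0, 1, 1],
#               [0, 0, 1]])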


OBJECTIVE_COEFF_TARGET = 1e8

logger = logging.getLogger(__name__)


def rescale_objective(
    m: "linopy.Model", target_max_coeff: float = OBJECTIVE_COEFF_TARGET
) -> float:
    """Scale objective to keep coefficients within numerical comfort zone."""
    dataset = m.objective.data
    if "coeffs" not in dataset:
        return 1.0
    coeffs = dataset["coeffs"].values
    if coeffs.size == 0:
        return 1.0
    max_coeff = float(np.nanmax(np.abs(coeffs)))
    if not math.isfinite(max_coeff) or max_coeff == 0.0:
        return 1.0
    if max_coeff <= target_max_coeff:
        return 1.0
    exponent = math.ceil(math.log10(max_coeff / target_max_coeff))
    if exponent <= 0:
        return 1.0
    scale_factor = 10.0**exponent
    m.objective = m.objective / scale_factor
    logger.warning(
        "Scaled objective by 1/%s to reduce max coefficient from %.3g to %.3g",
        scale_factor,
        max_coeff,
        max_coeff / scale_factor,
    )
    return scale_factor
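
# Worked example (hypothetical numbers): with max_coeff = 3.7e12 and the
# default target of 1e8, exponent = ceil(log10(3.7e12 / 1e8)) = 5, so the
# objective is divided by 1e5 and the largest coefficient becomes 3.7e7.
# Restricting the scale to powers of ten keeps it easy to log and to undo
# when reporting objective values.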
def sanitize_identifier(value: str) -> str:
    return (
        value.replace(" ", "_")
        .replace("(", "")
        .replace(")", "")
        .replace("/", "_")
        .replace("-", "_")
    )
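
# For example, sanitize_identifier("olive oil (extra-virgin)/raw") returns
# "olive_oil_extra_virgin_raw", which is safe to embed in the variable and
# constraint names constructed below.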
def sanitize_food_name(food: str) -> str:
    return sanitize_identifier(food)
def add_health_objective(
    n: pypsa.Network,
    risk_breakpoints_path: str,
    cluster_cause_path: str,
    cause_log_path: str,
    cluster_summary_path: str,
    clusters_path: str,
    population_totals_path: str,
    risk_factors: list[str],
    solver_name: str,
    value_per_yll: float,
) -> None:
    """Add SOS2-based health costs with log-linear aggregation."""
    m = n.model

    risk_breakpoints = pd.read_csv(risk_breakpoints_path)
    cluster_cause = pd.read_csv(cluster_cause_path)
    cause_log_breakpoints = pd.read_csv(cause_log_path)
    cluster_summary = pd.read_csv(cluster_summary_path)
    if "health_cluster" in cluster_summary.columns:
        cluster_summary["health_cluster"] = cluster_summary["health_cluster"].astype(
            int
        )
    cluster_map = pd.read_csv(clusters_path)
    population_totals = pd.read_csv(population_totals_path)

    # Load food→risk factor mapping (only GBD risk factors); note that this
    # reads snakemake.input.food_groups directly rather than a path parameter.
    food_groups_df = pd.read_csv(snakemake.input.food_groups)
    food_map = food_groups_df[food_groups_df["group"].isin(risk_factors)].copy()
    food_map = food_map.rename(columns={"group": "risk_factor"})
    food_map["share"] = 1.0
    food_map["sanitized"] = food_map["food"].apply(sanitize_food_name)
    food_map = food_map.set_index("sanitized")[["risk_factor", "share"]]

    cluster_lookup = cluster_map.set_index("country_iso3")["health_cluster"].to_dict()
    cluster_population_baseline = cluster_summary.set_index("health_cluster")[
        "population_persons"
    ].to_dict()
    cluster_cause_metadata = cluster_cause.set_index(["health_cluster", "cause"])

    population_totals = population_totals.dropna(subset=["iso3", "population"]).copy()
    population_totals["iso3"] = population_totals["iso3"].astype(str).str.upper()
    population_map = population_totals.set_index("iso3")["population"].to_dict()

    cluster_population_planning: dict[int, float] = defaultdict(float)
    for iso3, cluster in cluster_lookup.items():
        value = float(population_map.get(iso3, 0.0))
        if value <= 0:
            continue
        cluster_population_planning[int(cluster)] += value

    cluster_population: dict[int, float] = {}
    for cluster, baseline_pop in cluster_population_baseline.items():
        planning_pop = cluster_population_planning.get(int(cluster), 0.0)
        if planning_pop > 0:
            cluster_population[int(cluster)] = planning_pop
        else:
            cluster_population[int(cluster)] = float(baseline_pop)

    risk_breakpoints = risk_breakpoints.sort_values(
        ["risk_factor", "intake_g_per_day", "cause"]
    )
    cause_log_breakpoints = cause_log_breakpoints.sort_values(["cause", "log_rr_total"])

    p = m.variables["Link-p"].sel(snapshot="now")

    terms_by_key: dict[tuple[int, str], list[tuple[float, object]]] = defaultdict(list)

    for link_name in p.coords["name"].values:
        if not isinstance(link_name, str) or not link_name.startswith("consume_"):
            continue
        base, _, country = link_name.rpartition("_")
        if not base.startswith("consume_") or len(country) != 3:
            continue
        sanitized_food = base[len("consume_") :]
        if sanitized_food not in food_map.index:
            continue
        risk_factor = food_map.at[sanitized_food, "risk_factor"]
        share = float(food_map.at[sanitized_food, "share"])
        if share <= 0:
            continue
        cluster = cluster_lookup.get(country)
        if cluster is None:
            continue
        population = float(cluster_population.get(cluster, 0.0))
        if population <= 0:
            continue
        coeff = share * 1_000_000.0 / (365.0 * population)
        var = p.sel(name=link_name)
        terms_by_key[(int(cluster), str(risk_factor))].append((coeff, var))

    if not terms_by_key:
        logger.info("No consumption links map to health risk factors; skipping")
        return

    risk_data = {}
    for risk, grp in risk_breakpoints.groupby("risk_factor"):
        intakes = pd.Index(sorted(grp["intake_g_per_day"].unique()), name="intake")
        if intakes.empty:
            continue
        pivot = (
            grp.pivot_table(
                index="intake_g_per_day",
                columns="cause",
                values="log_rr",
                aggfunc="first",
            )
            .reindex(intakes, axis=0)
            .sort_index()
        )
        risk_data[risk] = {"intakes": intakes, "log_rr": pivot}

    cause_breakpoint_data = {
        cause: df.sort_values("log_rr_total")
        for cause, df in cause_log_breakpoints.groupby("cause")
    }

    log_rr_totals: dict[tuple[int, str], object] = {}

    # Group (cluster, risk) pairs by their intake coordinate patterns
    intake_groups: dict[tuple[float, ...], list[tuple[int, str]]] = defaultdict(list)
    for (cluster, risk), terms in terms_by_key.items():
        risk_table = risk_data.get(risk)
        if risk_table is None or not terms:
            continue
        coords = risk_table["intakes"]
        if len(coords) == 0:
            continue
        coords_key = tuple(coords.values)
        intake_groups[coords_key].append((cluster, risk))

    # Process each group with vectorized operations
    for coords_key, cluster_risk_pairs in intake_groups.items():
        coords = pd.Index(coords_key, name="intake")

        # Create flattened index for this group
        cluster_risk_labels = [
            f"c{cluster}_r{sanitize_identifier(risk)}"
            for cluster, risk in cluster_risk_pairs
        ]
        cluster_risk_index = pd.Index(cluster_risk_labels, name="cluster_risk")

        # Single vectorized variable creation
        lambdas_group = m.add_variables(
            lower=0,
            upper=1,
            coords=[cluster_risk_index, coords],
            name=f"health_lambda_group_{hash(coords_key)}",
        )

        # Register all variables
        _register_health_variable(m, lambdas_group.name)

        # Single SOS2 constraint call for entire group
        aux_names = _add_sos2_with_fallback(
            m, lambdas_group, sos_dim="intake", solver_name=solver_name
        )
        for aux_name in aux_names:
            _register_health_variable(m, aux_name)

        # Vectorized convexity constraints
        m.add_constraints(lambdas_group.sum("intake") == 1)

        # Process each (cluster, risk) for balance constraints and log_rr
        # contributions
        coeff_intake = xr.DataArray(
            coords.values, coords={"intake": coords.values}, dims=["intake"]
        )
        for (cluster, risk), label in zip(cluster_risk_pairs, cluster_risk_labels):
            terms = terms_by_key[(cluster, risk)]
            lambdas = lambdas_group.sel(cluster_risk=label)

            intake_expr = m.linexpr((coeff_intake, lambdas)).sum("intake")
            flow_expr = m.linexpr(*terms)
            m.add_constraints(
                flow_expr == intake_expr,
                name=f"health_intake_balance_c{cluster}_r{sanitize_identifier(risk)}",
            )

            risk_table = risk_data[risk]
            log_rr_matrix = risk_table["log_rr"]
            for cause in log_rr_matrix.columns:
                values = log_rr_matrix[cause].to_numpy()
                coeff_log = xr.DataArray(
                    values, coords={"intake": coords.values}, dims=["intake"]
                )
                contrib = m.linexpr((coeff_log, lambdas)).sum("intake")
                key = (cluster, cause)
                if key in log_rr_totals:
                    log_rr_totals[key] = log_rr_totals[key] + contrib
                else:
                    log_rr_totals[key] = contrib

    constant_objective = 0.0
    objective_expr = None

    # Group (cluster, cause) pairs by their log_total coordinate patterns
    log_total_groups: dict[tuple[float, ...], list[tuple[int, str]]] = defaultdict(
        list
    )
    cluster_cause_data: dict[tuple[int, str], dict] = {}

    for (cluster, cause), row in cluster_cause_metadata.iterrows():
        cluster = int(cluster)
        cause = str(cause)
        yll_base = float(row.get("yll_base", 0.0))
        if not math.isfinite(value_per_yll) or value_per_yll <= 0:
            continue
        if yll_base == 0 or not math.isfinite(yll_base):
            continue
        cause_bp = cause_breakpoint_data[cause]
        coords_key = tuple(cause_bp["log_rr_total"].values)
        if len(coords_key) < 2:
            raise ValueError(
                "Need at least two breakpoints for piecewise linear approximation"
            )
        log_total_groups[coords_key].append((cluster, cause))

        # Store data for later use
        log_rr_total_ref = float(row.get("log_rr_total_ref", 0.0))
        cluster_cause_data[(cluster, cause)] = {
            "value_per_yll": value_per_yll,
            "yll_base": yll_base,
            "log_rr_total_ref": log_rr_total_ref,
            "rr_ref": math.exp(log_rr_total_ref),
            "cause_bp": cause_bp,
        }

    # Process each group with vectorized operations
    for coords_key, cluster_cause_pairs in log_total_groups.items():
        coords = pd.Index(coords_key, name="log_total")

        # Create flattened index for this group
        cluster_cause_labels = [
            f"c{cluster}_cause{sanitize_identifier(cause)}"
            for cluster, cause in cluster_cause_pairs
        ]
        cluster_cause_index = pd.Index(cluster_cause_labels, name="cluster_cause")

        # Single vectorized variable creation
        lambda_total_group = m.add_variables(
            lower=0,
            upper=1,
            coords=[cluster_cause_index, coords],
            name=f"health_lambda_total_group_{hash(coords_key)}",
        )

        # Register all variables
        _register_health_variable(m, lambda_total_group.name)

        # Single SOS2 constraint call for entire group
        aux_names = _add_sos2_with_fallback(
            m, lambda_total_group, sos_dim="log_total", solver_name=solver_name
        )
        for aux_name in aux_names:
            _register_health_variable(m, aux_name)

        # Vectorized convexity constraints
        m.add_constraints(lambda_total_group.sum("log_total") == 1)

        # Process each (cluster, cause) for balance constraints and objective
        coeff_log_total = xr.DataArray(
            coords.values,
            coords={"log_total": coords.values},
            dims=["log_total"],
        )
        for (cluster, cause), label in zip(cluster_cause_pairs, cluster_cause_labels):
            data = cluster_cause_data[(cluster, cause)]
            lambda_total = lambda_total_group.sel(cluster_cause=label)

            total_expr = log_rr_totals.get((cluster, cause), m.linexpr(0.0))
            cause_bp = data["cause_bp"]

            log_interp = m.linexpr((coeff_log_total, lambda_total)).sum("log_total")
            coeff_rr = xr.DataArray(
                cause_bp["rr_total"].values,
                coords={"log_total": coords.values},
                dims=["log_total"],
            )
            rr_interp = m.linexpr((coeff_rr, lambda_total)).sum("log_total")

            sanitized_cause = sanitize_identifier(cause)
            m.add_constraints(
                log_interp == total_expr,
                name=f"health_total_balance_c{cluster}_cause{sanitized_cause}",
            )

            coeff = data["value_per_yll"] * data["yll_base"]
            scaled_expr = rr_interp * (coeff / data["rr_ref"])
            objective_expr = (
                scaled_expr
                if objective_expr is None
                else objective_expr + scaled_expr
            )
            constant_objective -= coeff

    if objective_expr is not None:
        m.objective = m.objective + objective_expr
        logger.info("Added health cost objective")

    if constant_objective != 0.0:
        adjustments = n.meta.setdefault("objective_constant_terms", {})
        adjustments["health"] = adjustments.get("health", 0.0) + constant_objective
        logger.debug(
            "Recorded health objective constant %.3g in network metadata",
            constant_objective,
        )

    if objective_expr is None and constant_objective == 0.0:
        logger.info("No health objective terms added (missing overlaps)")
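
# Recap of the formulation above: for each (cluster, risk factor) an SOS2
# interpolation maps the modelled population-average intake x onto the
# intake_g_per_day breakpoints x_k,
#
#     x = sum_k lambda_k * x_k,
#     log RR_cause = sum_k lambda_k * log_rr_k,
#
# the per-cause log relative risks are summed over risk factors, and a
# second SOS2 interpolation recovers RR = exp(total log RR). Each
# (cluster, cause) pair then contributes
#
#     value_per_yll * yll_base * (RR / RR_ref - 1)
#
# to the objective, with the constant "-1" part recorded in n.meta rather
# than added to the linear objective.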
if __name__ == "__main__":
    n = pypsa.Network(snakemake.input.network)

    # Create the linopy model
    logger.info("Creating linopy model...")
    n.optimize.create_model()
    logger.info("Linopy model created.")

    solver_name = snakemake.params.solver
    solver_options = snakemake.params.solver_options or {}

    # Add health impacts to the objective if data is available
    add_health_objective(
        n,
        snakemake.input.health_risk_breakpoints,
        snakemake.input.health_cluster_cause,
        snakemake.input.health_cause_log,
        snakemake.input.health_cluster_summary,
        snakemake.input.health_clusters,
        snakemake.input.population,
        snakemake.params.health_risk_factors,
        solver_name,
        float(snakemake.params.health_value_per_yll),
    )

    scaling_factor = rescale_objective(n.model)
    if scaling_factor != 1.0:
        previous = n.meta.get("objective_scaling_factor", 1.0)
        n.meta["objective_scaling_factor"] = previous * scaling_factor
        n.meta["objective_scaling_target_coeff"] = OBJECTIVE_COEFF_TARGET

    status, condition = n.model.solve(
        solver_name=solver_name,
        **solver_options,
    )
    result = (status, condition)

    if status == "ok":
        # Temporarily hide the auxiliary health variables, which do not
        # correspond to network components, so that pypsa's solution
        # assignment does not trip over them; restore them afterwards.
        aux_names = HEALTH_AUX_MAP.pop(id(n.model), set())
        variables_container = n.model.variables
        removed = {}
        for name in aux_names:
            if name in variables_container.data:
                removed[name] = variables_container.data.pop(name)
        try:
            n.optimize.assign_solution()
            n.optimize.assign_duals(False)
            n.optimize.post_processing()
        finally:
            if removed:
                variables_container.data.update(removed)
        if options.debug.runtime_verification:
            _optimize_guard(n)
        n.export_to_netcdf(snakemake.output.network)
    elif condition in {"infeasible", "infeasible_or_unbounded"}:
        logger.error("Model is infeasible or unbounded!")
        if solver_name == "gurobi":
            try:
                logger.error("Infeasible constraints:")
                n.model.print_infeasibilities()
            except Exception as exc:
                logger.error("Could not compute infeasibilities: %s", exc)
        else:
            logger.error("Infeasibility diagnosis only available with Gurobi solver")
    else:
        logger.error("Optimization unsuccessful: %s", result)
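
# Post-solve interpretation (a sketch, assuming the metadata keys written
# above): the solver sees a scaled objective that excludes the constant
# health term, so the full objective value of a solved network can be
# reconstructed as
#
#     scale = n.meta.get("objective_scaling_factor", 1.0)
#     constants = sum(n.meta.get("objective_constant_terms", {}).values())
#     full_objective = n.objective * scale + constants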