# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

from collections import defaultdict
import itertools
import math

from linopy.constraints import print_single_constraint
import numpy as np
import pandas as pd
import pypsa
import xarray as xr

from workflow.scripts import constants
from workflow.scripts.build_model.nutrition import (
    _build_food_group_equals_from_baseline,
)
from workflow.scripts.build_model.utils import _per_capita_mass_to_mt_per_year
from workflow.scripts.logging_config import setup_script_logging
from workflow.scripts.snakemake_utils import apply_scenario_config

try:  # Used for type annotations / documentation; fallback when unavailable
    import linopy  # type: ignore
except Exception:  # pragma: no cover - documentation build without linopy

    class linopy:  # type: ignore
        class Model:  # Minimal stub to satisfy type checkers and autodoc
            pass


# Enable new PyPSA components API
pypsa.options.api.new_components_api = True

# Helpers and state for health objective construction
HEALTH_AUX_MAP: dict[int, set[str]] = {}
_SOS2_COUNTER = [0]  # Use list for mutable counter
_LAMBDA_GROUP_COUNTER = itertools.count()
_TOTAL_GROUP_COUNTER = itertools.count()


def _register_health_variable(m: "linopy.Model", name: str) -> None:
    aux = HEALTH_AUX_MAP.setdefault(id(m), set())
    aux.add(name)


def add_macronutrient_constraints(
    n: pypsa.Network,
    macronutrient_cfg: dict | None,
    population: dict[str, float],
) -> None:
    """Add per-country macronutrient bounds directly to the linopy model.

    The bounds are expressed on the storage level of each macronutrient
    store. RHS values are converted from per-person-per-day units using
    stored population and nutrient unit metadata.
    """
    if not macronutrient_cfg:
        return

    m = n.model
    store_e = m.variables["Store-e"].sel(snapshot="now")
    stores_df = n.stores.static

    for nutrient, bounds in macronutrient_cfg.items():
        if not bounds:
            continue

        carrier_unit = n.carriers.static.at[nutrient, "unit"]
        nutrient_stores = stores_df[stores_df["carrier"] == nutrient]
        countries = nutrient_stores["country"].astype(str)
        lhs = store_e.sel(name=nutrient_stores.index)

        def rhs_from(
            value: float,
            carrier_unit=carrier_unit,
            countries=countries,
            nutrient_stores=nutrient_stores,
        ) -> xr.DataArray:
            # Carrier unit encodes the nutrient type: "Mt" for mass, "PJ" for energy (kcal)
            if carrier_unit == "Mt":
                rhs_vals = [
                    _per_capita_mass_to_mt_per_year(
                        float(value), float(population[country])
                    )
                    for country in countries
                ]
            else:
                rhs_vals = [
                    float(value)
                    * float(population[country])
                    * constants.DAYS_PER_YEAR
                    * constants.KCAL_TO_PJ
                    for country in countries
                ]
            return xr.DataArray(
                rhs_vals, coords={"name": nutrient_stores.index}, dims="name"
            )

        for key, operator, label in (
            ("equal", "==", "equal"),
            ("min", ">=", "min"),
            ("max", "<=", "max"),
        ):
            if bounds.get(key) is None:
                continue
            rhs = rhs_from(bounds[key])
            constr_name = f"macronutrient_{label}_{nutrient}"
            if operator == "==":
                m.add_constraints(lhs == rhs, name=f"GlobalConstraint-{constr_name}")
                n.global_constraints.add(
                    f"{constr_name}_" + nutrient_stores.index,
                    sense="==",
                    constant=rhs.values,
                    type="nutrition",
                )
                break
            if operator == ">=":
                m.add_constraints(lhs >= rhs, name=f"GlobalConstraint-{constr_name}")
            else:
                m.add_constraints(lhs <= rhs, name=f"GlobalConstraint-{constr_name}")
            n.global_constraints.add(
                f"{constr_name}_" + nutrient_stores.index,
                sense=operator,
                constant=rhs.values,
                type="nutrition",
            )
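
# Worked example of the energy-branch RHS conversion above (a sketch; the
# exact values live in workflow.scripts.constants, and KCAL_TO_PJ is assumed
# here to be ~4.184e-12 PJ/kcal). A bound of 2500 kcal/person/day for a
# country of 10 million people becomes an annual store level of
#
#   2500 kcal/day * 1e7 persons * 365 days/yr * 4.184e-12 PJ/kcal ≈ 38.2 PJ/yr
#
# which is what the constraint compares the "Store-e" level against.
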
def add_food_group_constraints(
    n: pypsa.Network,
    food_group_cfg: dict | None,
    population: dict[str, float],
    per_country_equal: dict[str, dict[str, float]] | None = None,
) -> None:
    """Add per-country food group bounds on store levels."""
    if not food_group_cfg and not per_country_equal:
        return

    food_group_cfg = food_group_cfg or {}
    per_country_equal = per_country_equal or {}

    m = n.model
    store_e = m.variables["Store-e"].sel(snapshot="now")
    stores_df = n.stores.static

    groups = set(food_group_cfg) | set(per_country_equal)
    for group in groups:
        bounds = food_group_cfg.get(group, {})
        if not bounds and group not in per_country_equal:
            continue

        group_stores = stores_df[stores_df["carrier"] == f"group_{group}"]
        countries = group_stores["country"].astype(str)
        lhs = store_e.sel(name=group_stores.index)

        def rhs_from(
            value: float, countries=countries, group_stores=group_stores
        ) -> xr.DataArray:
            rhs_vals = [
                _per_capita_mass_to_mt_per_year(
                    float(value), float(population[country])
                )
                for country in countries
            ]
            return xr.DataArray(
                rhs_vals, coords={"name": group_stores.index}, dims="name"
            )

        def rhs_from_equal(
            group=group, countries=countries, group_stores=group_stores, bounds=bounds
        ) -> xr.DataArray | None:
            overrides = per_country_equal.get(group)
            if overrides:
                rhs_vals = [
                    _per_capita_mass_to_mt_per_year(
                        float(overrides[country]), float(population[country])
                    )
                    for country in countries
                ]
                return xr.DataArray(
                    rhs_vals, coords={"name": group_stores.index}, dims="name"
                )
            if bounds.get("equal") is None:
                return None
            return rhs_from(bounds["equal"])

        # Apply at most one equality; otherwise allow independent min/max bounds
        for key, operator, label in (
            ("equal", "==", "equal"),
            ("min", ">=", "min"),
            ("max", "<=", "max"),
        ):
            if key == "equal":
                rhs = rhs_from_equal()
                if rhs is None:
                    continue
            else:
                if bounds.get(key) is None:
                    continue
                rhs = rhs_from(bounds[key])
            constr_name = f"food_group_{label}_{group}"
            if operator == "==":
                m.add_constraints(lhs == rhs, name=f"GlobalConstraint-{constr_name}")
                n.global_constraints.add(
                    f"{constr_name}_" + group_stores.index,
                    sense="==",
                    constant=rhs.values,
                    type="nutrition",
                )
                break
            if operator == ">=":
                m.add_constraints(lhs >= rhs, name=f"GlobalConstraint-{constr_name}")
            else:
                m.add_constraints(lhs <= rhs, name=f"GlobalConstraint-{constr_name}")
            n.global_constraints.add(
                f"{constr_name}_" + group_stores.index,
                sense=operator,
                constant=rhs.values,
                type="nutrition",
            )
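
# Illustrative input shapes for add_food_group_constraints (hypothetical
# values, not shipped config). Bounds are g/person/day; a per-country
# equality override takes precedence over the scalar "equal" bound:
#
#   food_group_cfg = {"fruits": {"min": 200.0}, "red_meat": {"max": 30.0}}
#   per_country_equal = {"fruits": {"USA": 240.0, "IND": 150.0}}
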
def _add_sos2_with_fallback(m, variable, sos_dim: str, solver_name: str) -> list[str]:
    """Add SOS2 or binary fallback depending on solver support."""
    if solver_name.lower() != "highs":
        m.add_sos_constraints(variable, sos_type=2, sos_dim=sos_dim)
        return []

    coords = variable.coords[sos_dim]
    n_points = len(coords)
    if n_points <= 1:
        return []

    other_dims = [dim for dim in variable.dims if dim != sos_dim]
    interval_dim = f"{sos_dim}_interval"
    suffix = 1
    while interval_dim in variable.dims:
        interval_dim = f"{sos_dim}_interval{suffix}"
        suffix += 1
    interval_index = pd.Index(range(n_points - 1), name=interval_dim)
    binary_coords = [variable.coords[d] for d in other_dims] + [interval_index]

    # Use counter instead of checking existing variables
    _SOS2_COUNTER[0] += 1
    base_name = f"{variable.name}_segment" if variable.name else "health_segment"
    binary_name = f"{base_name}_{_SOS2_COUNTER[0]}"

    binaries = m.add_variables(coords=binary_coords, binary=True, name=binary_name)
    m.add_constraints(binaries.sum(interval_dim) == 1)

    # Vectorize SOS2 constraints: variable[i] <= binary[i-1] + binary[i]
    if n_points >= 2:
        adjacency_data = np.zeros((n_points, n_points - 1))
        indices = np.arange(n_points - 1)
        adjacency_data[indices, indices] = 1
        adjacency_data[indices + 1, indices] = 1
        adjacency = xr.DataArray(
            adjacency_data,
            coords={sos_dim: coords, interval_dim: range(n_points - 1)},
            dims=[sos_dim, interval_dim],
        )
        rhs = (adjacency * binaries).sum(interval_dim)
        m.add_constraints(variable <= rhs)

    return [binary_name]


def _apply_solver_threads_option(
    solver_options: dict, solver_name: str, threads: int
) -> dict:
    """Ensure the solver options include a threads override when configured."""
    solver_key = solver_name.lower()
    if solver_key == "gurobi":
        solver_options["Threads"] = threads
    elif solver_key == "highs":
        solver_options["threads"] = threads
    return solver_options
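
# Shape of the binary-fallback adjacency built above, for n_points = 4
# (3 intervals). Row i marks the intervals adjacent to breakpoint i, so
# variable[i] <= sum(adjacent interval binaries) together with
# sum(binaries) == 1 allows at most two *consecutive* breakpoint weights
# to be nonzero — exactly the SOS2 condition:
#
#   adjacency = [[1, 0, 0],
#                [1, 1, 0],
#                [0, 1, 1],
#                [0, 0, 1]]
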
def add_ghg_pricing_to_objective(n: pypsa.Network, ghg_price_usd_per_t: float) -> None:
    """Add GHG emissions pricing to the objective function.

    Adds the cost of GHG emissions (stored in the 'ghg' store) to the
    objective function at solve time.

    Parameters
    ----------
    n : pypsa.Network
        The network containing the model.
    ghg_price_usd_per_t : float
        Price per tonne of CO2-equivalent in USD (config currency_year).
    """
    # Convert USD/tCO2 to bnUSD/MtCO2 (matching model units)
    ghg_price_bnusd_per_mt = (
        ghg_price_usd_per_t / constants.TONNE_TO_MEGATONNE * constants.USD_TO_BNUSD
    )

    # Add marginal storage cost to store
    n.stores.static.at["ghg", "marginal_cost_storage"] = ghg_price_bnusd_per_mt
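
# Unit check for the conversion above (hedged: assumes the constants are the
# pure unit scalers TONNE_TO_MEGATONNE = 1e-6 and USD_TO_BNUSD = 1e-9):
#
#   100 USD/tCO2e  ->  100 / 1e-6 = 1e8 USD/MtCO2e
#                  ->  1e8 * 1e-9 = 0.1 bnUSD/MtCO2e
#
# so a 100 USD/t price values each Mt held in the "ghg" store at 0.1 bnUSD.
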
def add_food_group_incentives_to_objective(
    n: pypsa.Network, incentives_paths: list[str]
) -> None:
    """Add food-group incentives/penalties to the objective function.

    Incentives are applied as adjustments to marginal storage costs of
    food group stores. Positive values penalize consumption; negative
    values subsidize consumption.

    Parameters
    ----------
    n : pypsa.Network
        The network containing the model.
    incentives_paths : list[str]
        Paths to CSVs with columns: group, country, adjustment_bnusd_per_mt.
    """
    if not incentives_paths:
        raise ValueError("food_group_incentives enabled but no sources are configured")

    combined = []
    for path in incentives_paths:
        incentives_df = pd.read_csv(path)
        required = {"group", "country", "adjustment_bnusd_per_mt"}
        missing = required - set(incentives_df.columns)
        if missing:
            missing_text = ", ".join(sorted(missing))
            raise ValueError(
                f"Missing required columns in incentives file {path}: {missing_text}"
            )
        incentives_df["country"] = incentives_df["country"].astype(str).str.upper()
        incentives_df["store_name"] = (
            "store_" + incentives_df["group"] + "_" + incentives_df["country"]
        )
        combined.append(incentives_df[["store_name", "adjustment_bnusd_per_mt"]].copy())

    all_incentives = pd.concat(combined, ignore_index=True)
    summed = (
        all_incentives.groupby("store_name")["adjustment_bnusd_per_mt"]
        .sum()
        .reset_index()
    )

    if "marginal_cost_storage" not in n.stores.static.columns:
        n.stores.static["marginal_cost_storage"] = 0.0

    store_index = n.stores.static.index
    missing_stores = summed[~summed["store_name"].isin(store_index)]
    if not missing_stores.empty:
        sample = ", ".join(missing_stores["store_name"].head(5))
        logger.warning(
            "Missing %d food group stores for incentives (examples: %s)",
            len(missing_stores),
            sample,
        )

    applicable = summed[summed["store_name"].isin(store_index)]
    if applicable.empty:
        logger.info(
            "No applicable food group incentives found in %d sources",
            len(incentives_paths),
        )
        return

    n.stores.static.loc[applicable["store_name"], "marginal_cost_storage"] = (
        n.stores.static.loc[applicable["store_name"], "marginal_cost_storage"].astype(
            float
        )
        + applicable["adjustment_bnusd_per_mt"].astype(float).values
    )
    logger.info(
        "Applied food-group incentives to %d stores from %d sources",
        len(applicable),
        len(incentives_paths),
    )
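
# Minimal example of an incentives CSV accepted above (hypothetical file
# contents). Rows targeting the same store are summed before being added to
# marginal_cost_storage, and country codes are upper-cased, so "usa" maps to
# store_fruits_USA:
#
#   group,country,adjustment_bnusd_per_mt
#   fruits,usa,-0.05
#   red_meat,USA,0.10
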
def build_residue_feed_fraction_by_country(
    config: dict, m49_path: str
) -> dict[str, float]:
    """Build per-country residue feed fraction overrides from config."""
    overrides = config["residues"]["max_feed_fraction_by_region"]
    if not overrides:
        return {}

    countries = [str(country).upper() for country in config["countries"]]
    m49_df = pd.read_csv(m49_path, sep=";", encoding="utf-8-sig", comment="#")
    m49_df = m49_df[m49_df["ISO-alpha3 Code"].notna()]
    m49_df["iso3"] = m49_df["ISO-alpha3 Code"].astype(str).str.upper()
    m49_df = m49_df[m49_df["iso3"].isin(countries)]

    region_to_countries = m49_df.groupby("Region Name")["iso3"].apply(list).to_dict()
    subregion_to_countries = (
        m49_df.groupby("Sub-region Name")["iso3"].apply(list).to_dict()
    )

    region_overrides = {
        key: overrides[key] for key in overrides if key in region_to_countries
    }
    subregion_overrides = {
        key: overrides[key] for key in overrides if key in subregion_to_countries
    }
    country_overrides = {key: overrides[key] for key in overrides if key in countries}

    unknown = (
        set(overrides)
        - set(region_overrides)
        - set(subregion_overrides)
        - set(country_overrides)
    )
    if unknown:
        unknown_text = ", ".join(sorted(unknown))
        raise ValueError(
            f"Unknown residues.max_feed_fraction_by_region keys: {unknown_text}"
        )

    # Precedence: region < sub-region < country (later assignments win)
    per_country: dict[str, float] = {}
    for region, value in region_overrides.items():
        for country in region_to_countries[region]:
            per_country[country] = float(value)
    for subregion, value in subregion_overrides.items():
        for country in subregion_to_countries[subregion]:
            per_country[country] = float(value)
    for country, value in country_overrides.items():
        per_country[country] = float(value)
    return per_country
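
# Resolution order illustrated with a hypothetical config snippet: regions
# are applied first, then sub-regions, then countries, so the most specific
# key wins for each country:
#
#   residues:
#     max_feed_fraction_by_region:
#       "Asia": 0.4            # every Asian model country
#       "Southern Asia": 0.5   # overrides "Asia" for that sub-region
#       "IND": 0.6             # overrides both for India
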
def _residue_bus_country(residue_bus: str) -> str:
    return residue_bus.rsplit("_", 1)[-1].upper()
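
# Example (hypothetical bus name): the country code is the trailing
# "_"-separated token, upper-cased:
#
#   >>> _residue_bus_country("residue_maize_bra")
#   'BRA'
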
def add_residue_feed_constraints(
    n: pypsa.Network,
    max_feed_fraction: float,
    max_feed_fraction_by_country: dict[str, float],
) -> None:
    """Add constraints limiting residue removal for animal feed.

    Constrains the fraction of residues that can be removed for feed vs.
    incorporated into soil. The constraint is formulated as::

        feed_use <= (max_feed_fraction / (1 - max_feed_fraction)) * incorporation

    This ensures that if a total amount R of residue is generated::

        R = feed_use + incorporation
        feed_use <= max_feed_fraction * R

    Parameters
    ----------
    n : pypsa.Network
        The network containing the model.
    max_feed_fraction : float
        Maximum fraction of residues that can be used for feed
        (e.g., 0.30 for 30%).
    max_feed_fraction_by_country : dict[str, float]
        Overrides keyed by ISO3 country code.
    """
    m = n.model

    # Get link flow variables and link data
    link_p = m.variables["Link-p"].sel(snapshot="now")
    links_df = n.links.static

    # Find residue feed links (carrier="convert_to_feed", bus0 starts with "residue_")
    feed_mask = (links_df["carrier"] == "convert_to_feed") & (
        links_df["bus0"].str.startswith("residue_")
    )
    feed_links_df = links_df[feed_mask]

    # Find incorporation links (carrier="residue_incorporation")
    incorp_mask = links_df["carrier"] == "residue_incorporation"
    incorp_links_df = links_df[incorp_mask]

    if feed_links_df.empty or incorp_links_df.empty:
        logger.info(
            "No residue feed limit constraints added (missing feed or incorporation links)"
        )
        return

    # Identify common residue buses
    feed_buses = set(feed_links_df["bus0"].unique())
    incorp_buses = set(incorp_links_df["bus0"].unique())
    common_buses = sorted(feed_buses.intersection(incorp_buses))

    if not common_buses:
        logger.info(
            "No residue feed limit constraints added (no matching residue flows found)"
        )
        return

    # Filter DataFrames to common buses
    feed_links_df = feed_links_df[feed_links_df["bus0"].isin(common_buses)]
    incorp_links_df = incorp_links_df[incorp_links_df["bus0"].isin(common_buses)]

    # Prepare mapping DataArrays for groupby
    # Map feed link names to their residue bus
    feed_bus_map = xr.DataArray(
        feed_links_df["bus0"],
        coords={"name": feed_links_df.index},
        dims="name",
        name="residue_bus",
    )
    # Map incorp link names to their residue bus
    incorp_bus_map = xr.DataArray(
        incorp_links_df["bus0"],
        coords={"name": incorp_links_df.index},
        dims="name",
        name="residue_bus",
    )

    # Get variables
    feed_vars = link_p.sel(name=feed_links_df.index)
    incorp_vars = link_p.sel(name=incorp_links_df.index)

    # Group feed vars by residue bus and sum
    feed_sum = feed_vars.groupby(feed_bus_map).sum()
    # Group incorp vars by residue bus and sum (handles alignment)
    incorp_flow = incorp_vars.groupby(incorp_bus_map).sum()

    ratios = []
    for bus in common_buses:
        country = _residue_bus_country(bus)
        max_fraction = max_feed_fraction_by_country.get(country, max_feed_fraction)
        ratios.append(max_fraction / (1.0 - max_fraction))
    ratio = xr.DataArray(
        ratios, coords={"residue_bus": common_buses}, dims="residue_bus"
    )

    # Add constraints
    constr_name = "residue_feed_limit"
    m.add_constraints(
        feed_sum <= ratio * incorp_flow,
        name=f"GlobalConstraint-{constr_name}",
    )

    # Add GlobalConstraints for shadow price tracking
    gc_names = [f"{constr_name}_{bus}" for bus in common_buses]
    n.global_constraints.add(
        gc_names,
        sense="<=",
        constant=0.0,  # RHS is dynamic (depends on incorp_flow), use 0 as placeholder
        type="residue_feed",
    )

    if max_feed_fraction_by_country:
        logger.info(
            "Applied residue feed fraction overrides for %d countries",
            len(max_feed_fraction_by_country),
        )
    logger.info(
        "Added %d residue feed limit constraints (max %.0f%% for feed)",
        len(common_buses),
        max_feed_fraction * 100,
    )
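
# Derivation of the ratio reformulation used above: with
# R = feed_use + incorporation, requiring feed_use <= f * R is equivalent to
#
#   feed_use <= f * (feed_use + incorporation)
#   (1 - f) * feed_use <= f * incorporation
#   feed_use <= (f / (1 - f)) * incorporation
#
# e.g. f = 0.30 gives feed_use <= ~0.4286 * incorporation. This avoids
# referencing the total R, which is not a single model variable.
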
def add_animal_production_constraints(
    n: pypsa.Network,
    fao_production: pd.DataFrame,
    food_to_group: dict[str, str],
    loss_waste: pd.DataFrame,
) -> None:
    """Add constraints to fix animal production at FAO levels per country.

    For each (country, product) combination in the FAO data, adds a
    constraint that total production from all feed categories equals the
    FAO target, adjusted for food loss and waste. Since the model applies
    FLW to the feed→product efficiency, the constraint target must also be
    adjusted to net production (gross FAO production * (1-loss) * (1-waste)).

    Parameters
    ----------
    n : pypsa.Network
        The network containing the model.
    fao_production : pd.DataFrame
        FAO production data with columns: country, product, production_mt.
    food_to_group : dict[str, str]
        Mapping from product names to food group names for FLW lookup.
    loss_waste : pd.DataFrame
        Food loss and waste fractions with columns: country, food_group,
        loss_fraction, waste_fraction.
    """
    if fao_production.empty:
        logger.warning(
            "No FAO animal production data available; skipping production constraints"
        )
        return

    # Build FLW lookup: (country, food_group) -> (1-loss)*(1-waste)
    flw_multipliers: dict[tuple[str, str], float] = {}
    for _, row in loss_waste.iterrows():
        key = (str(row["country"]), str(row["food_group"]))
        loss_frac = float(row["loss_fraction"])
        waste_frac = float(row["waste_fraction"])
        flw_multipliers[key] = (1.0 - loss_frac) * (1.0 - waste_frac)

    m = n.model
    link_p = m.variables["Link-p"].sel(snapshot="now")
    links_df = n.links.static

    # Filter to animal production links using carrier
    # Animal production links have carriers starting with "produce_"
    prod_mask = links_df["carrier"].str.startswith("produce_")
    prod_links = links_df[prod_mask]

    if prod_links.empty:
        logger.info("No animal production links found.")
        return

    products = prod_links["product"].astype(str)
    countries = prod_links["country"].astype(str)

    # Prepare DataArrays aligned with the filtered links
    link_names = prod_links.index

    # Efficiencies
    efficiencies = xr.DataArray(
        prod_links["efficiency"].values, coords={"name": link_names}, dims="name"
    )

    # Production = p * efficiency
    # Group by (product, country) and sum
    production_vars = link_p.sel(name=link_names)
    grouper = pd.MultiIndex.from_arrays(
        [products.values, countries.values], names=["product", "country"]
    )
    da_grouper = xr.DataArray(grouper, coords={"name": link_names}, dims="name")
    total_production = (production_vars * efficiencies).groupby(da_grouper).sum()

    target_series = fao_production.set_index(["product", "country"])[
        "production_mt"
    ].astype(float)

    # Adjust targets by FLW: net_target = gross_target * (1-loss) * (1-waste)
    adjusted_targets = []
    for product, country in target_series.index:
        gross_value = target_series.loc[(product, country)]
        group = food_to_group[product]
        multiplier = flw_multipliers[(country, group)]
        adjusted_targets.append(gross_value * multiplier)
    target_series = pd.Series(adjusted_targets, index=target_series.index)

    model_index = pd.Index(total_production.coords["group"].values, name="group")
    common_index = model_index.intersection(target_series.index)

    if common_index.empty:
        logger.warning(
            "No matching animal production targets found for model structure."
        )
        return

    lhs = total_production.sel(group=common_index)
    rhs = xr.DataArray(
        target_series.loc[common_index].values,
        coords={"group": common_index},
        dims="group",
    )

    constr_name = "animal_production_target"
    m.add_constraints(lhs == rhs, name=f"GlobalConstraint-{constr_name}")

    # Add GlobalConstraints for shadow price tracking
    gc_names = [f"{constr_name}_{prod}_{country}" for prod, country in common_index]
    n.global_constraints.add(
        gc_names,
        sense="==",
        constant=rhs.values,
        type="production_target",
    )

    logger.info(
        "Added %d country-level animal production constraints (FLW-adjusted)",
        len(common_index),
    )
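
# Worked example of the FLW adjustment above (hypothetical numbers): a gross
# FAO target of 10 Mt with loss_fraction = 0.10 and waste_fraction = 0.05
# becomes
#
#   10 * (1 - 0.10) * (1 - 0.05) = 8.55 Mt
#
# of net production, matching the model's post-loss/post-waste accounting.
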
def add_production_stability_constraints(
    n: pypsa.Network,
    crop_baseline: pd.DataFrame | None,
    crop_to_fao_item: dict[str, str],
    animal_baseline: pd.DataFrame | None,
    stability_cfg: dict,
    food_to_group: dict[str, str],
    loss_waste: pd.DataFrame,
) -> None:
    """Add constraints limiting production deviation from baseline levels.

    For crops and animal products, applies per-(product, country) bounds:
    ``(1 - delta) * baseline <= production <= (1 + delta) * baseline``

    Products with zero baseline are constrained to zero production.

    Note: Multi-cropping is disabled when production stability is enabled.

    Parameters
    ----------
    n : pypsa.Network
        The network containing the model.
    crop_baseline : pd.DataFrame | None
        FAO crop production with columns: country, crop, production_tonnes.
    crop_to_fao_item : dict[str, str]
        Mapping from crop names to FAO item names; used to aggregate crops
        that share an FAO item (e.g., dryland-rice and wetland-rice both
        map to "Rice").
    animal_baseline : pd.DataFrame | None
        FAO animal production with columns: country, product, production_mt.
    stability_cfg : dict
        Configuration with enabled, crops.max_relative_deviation, etc.
    food_to_group : dict[str, str]
        Mapping from product names to food group names for FLW lookup.
    loss_waste : pd.DataFrame
        Food loss and waste fractions with columns: country, food_group,
        loss_fraction, waste_fraction.
    """
    if not stability_cfg["enabled"]:
        return

    m = n.model
    link_p = m.variables["Link-p"].sel(snapshot="now")
    links_df = n.links.static

    # --- CROP PRODUCTION BOUNDS ---
    crops_cfg = stability_cfg["crops"]
    if crops_cfg["enabled"] and crop_baseline is not None:
        _add_crop_stability_constraints(
            n, link_p, links_df, crop_baseline, crop_to_fao_item, crops_cfg
        )

    # --- ANIMAL PRODUCTION BOUNDS ---
    animals_cfg = stability_cfg["animals"]
    if animals_cfg["enabled"] and animal_baseline is not None:
        _add_animal_stability_constraints(
            n, link_p, links_df, animal_baseline, animals_cfg, food_to_group, loss_waste
        )
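
# Bounds illustrated for delta = 0.2 (hypothetical): a baseline of 5 Mt
# yields 4 Mt <= production <= 6 Mt, while a baseline of 0 Mt collapses to
# the equality production == 0 handled by the zero-baseline branches in the
# helpers below.
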
def _add_crop_stability_constraints(
    n: pypsa.Network,
    link_p,
    links_df: pd.DataFrame,
    crop_baseline: pd.DataFrame,
    crop_to_fao_item: dict[str, str],
    crops_cfg: dict,
) -> None:
    """Add crop production stability bounds.

    Crops that share a FAO item (e.g., dryland-rice and wetland-rice both
    map to "Rice") are aggregated together for the constraint.
    """
    m = n.model
    delta = crops_cfg["max_relative_deviation"]

    # Filter to crop production links using the crop column
    # Note: some links have empty string instead of NaN, so check for both
    crop_mask = links_df["crop"].notna() & (links_df["crop"] != "")
    crop_links = links_df[crop_mask].copy()

    if crop_links.empty:
        logger.info(
            "No crop production links found; skipping crop stability constraints"
        )
        return

    crops = crop_links["crop"].astype(str)
    countries = crop_links["country"].astype(str)
    link_names = crop_links.index

    # Map crops to FAO items; use crop name as fallback for unmapped crops
    fao_items = crops.map(lambda c: crop_to_fao_item.get(c, c))

    # Filter out crops with empty/nan FAO item (e.g., alfalfa, biomass-sorghum)
    valid_mask = (
        fao_items.notna() & (fao_items != "") & (fao_items.str.lower() != "nan")
    )
    if not valid_mask.any():
        logger.info(
            "No crops with FAO item mappings; skipping crop stability constraints"
        )
        return

    fao_items = fao_items[valid_mask]
    countries_filtered = countries[valid_mask]
    link_names_filtered = link_names[valid_mask]
    efficiencies_filtered = crop_links.loc[valid_mask, "efficiency"].values

    # Efficiencies (yield: Mt/Mha)
    efficiencies = xr.DataArray(
        efficiencies_filtered, coords={"name": link_names_filtered}, dims="name"
    )

    # Production = p * efficiency (p is land in Mha)
    production_vars = link_p.sel(name=link_names_filtered)

    # Group by (fao_item, country) to aggregate related crops
    grouper = pd.MultiIndex.from_arrays(
        [fao_items.values, countries_filtered.values], names=["fao_item", "country"]
    )
    da_grouper = xr.DataArray(
        grouper, coords={"name": link_names_filtered}, dims="name"
    )
    total_production = (production_vars * efficiencies).groupby(da_grouper).sum()

    # Convert baseline to Mt and aggregate by FAO item
    baseline_df = crop_baseline.copy()
    baseline_df["production_mt"] = baseline_df["production_tonnes"] * 1e-6

    # Map baseline crops to FAO items
    baseline_df["fao_item"] = baseline_df["crop"].map(
        lambda c: crop_to_fao_item.get(c, c)
    )

    # Aggregate baseline by (fao_item, country) - this sums the split values back
    baseline_agg = (
        baseline_df.groupby(["fao_item", "country"])["production_mt"]
        .sum()
        .reset_index()
    )
    target_series = baseline_agg.set_index(["fao_item", "country"])["production_mt"]

    # Match to model index
    model_index = pd.Index(total_production.coords["group"].values, name="group")
    common_index = model_index.intersection(target_series.index)

    if common_index.empty:
        logger.warning("No matching crop production targets for stability bounds")
        return

    # Build RHS bounds
    baselines = target_series.loc[common_index].values
    lower_bounds = np.maximum(0.0, (1.0 - delta) * baselines)
    upper_bounds = (1.0 + delta) * baselines

    rhs_lower = xr.DataArray(lower_bounds, coords={"group": common_index}, dims="group")
    rhs_upper = xr.DataArray(upper_bounds, coords={"group": common_index}, dims="group")

    # Handle zero baselines: force production to zero
    zero_mask = baselines == 0
    nonzero_mask = ~zero_mask

    if zero_mask.any():
        zero_index = common_index[zero_mask]
        lhs_zero = total_production.sel(group=zero_index)
        constr_name = "crop_production_zero"
        m.add_constraints(lhs_zero == 0, name=f"GlobalConstraint-{constr_name}")
        gc_names = [
            f"{constr_name}_{fao_item}_{country}" for fao_item, country in zero_index
        ]
        n.global_constraints.add(
            gc_names,
            sense="==",
            constant=0.0,
            type="production_stability",
        )
        logger.info(
            "Added %d crop production constraints for zero-baseline (fao_item, country) pairs",
            int(zero_mask.sum()),
        )

    if nonzero_mask.any():
        nonzero_index = common_index[nonzero_mask]
        lhs_nonzero = total_production.sel(group=nonzero_index)
        lower_nonzero = rhs_lower.sel(group=nonzero_index)
        upper_nonzero = rhs_upper.sel(group=nonzero_index)

        constr_name_min = "crop_production_min"
        constr_name_max = "crop_production_max"
        m.add_constraints(
            lhs_nonzero >= lower_nonzero, name=f"GlobalConstraint-{constr_name_min}"
        )
        m.add_constraints(
            lhs_nonzero <= upper_nonzero, name=f"GlobalConstraint-{constr_name_max}"
        )

        gc_names_min = [
            f"{constr_name_min}_{fao_item}_{country}"
            for fao_item, country in nonzero_index
        ]
        gc_names_max = [
            f"{constr_name_max}_{fao_item}_{country}"
            for fao_item, country in nonzero_index
        ]
        n.global_constraints.add(
            gc_names_min,
            sense=">=",
            constant=lower_nonzero.values,
            type="production_stability",
        )
        n.global_constraints.add(
            gc_names_max,
            sense="<=",
            constant=upper_nonzero.values,
            type="production_stability",
        )
        logger.info(
            "Added %d crop production stability constraints (delta=%.0f%%)",
            2 * int(nonzero_mask.sum()),
            delta * 100,
        )

    # Log missing baselines (at FAO item level)
    missing = model_index.difference(target_series.index)
    if len(missing) > 0:
        examples = [f"{item}/{country}" for item, country in list(missing)[:5]]
        logger.warning(
            "Missing crop baseline data for %d (fao_item, country) pairs; examples: %s",
            len(missing),
            ", ".join(examples),
        )


def _add_animal_stability_constraints(
    n: pypsa.Network,
    link_p,
    links_df: pd.DataFrame,
    animal_baseline: pd.DataFrame,
    animals_cfg: dict,
    food_to_group: dict[str, str],
    loss_waste: pd.DataFrame,
) -> None:
    """Add animal production stability bounds.

    Reuses the aggregation logic from add_animal_production_constraints()
    but applies inequality bounds instead of equality.
    """
    m = n.model
    delta = animals_cfg["max_relative_deviation"]

    # Build FLW lookup (same as add_animal_production_constraints)
    flw_multipliers: dict[tuple[str, str], float] = {}
    for _, row in loss_waste.iterrows():
        key = (str(row["country"]), str(row["food_group"]))
        loss_frac = float(row["loss_fraction"])
        waste_frac = float(row["waste_fraction"])
        flw_multipliers[key] = (1.0 - loss_frac) * (1.0 - waste_frac)

    # Filter to animal production links using product column
    # Note: some links have empty string instead of NaN, so check for both
    prod_mask = links_df["product"].notna() & (links_df["product"] != "")
    prod_links = links_df[prod_mask]

    if prod_links.empty:
        logger.info(
            "No animal production links found; skipping animal stability constraints"
        )
        return

    products = prod_links["product"].astype(str)
    countries = prod_links["country"].astype(str)
    link_names = prod_links.index

    efficiencies = xr.DataArray(
        prod_links["efficiency"].values, coords={"name": link_names}, dims="name"
    )
    production_vars = link_p.sel(name=link_names)
    grouper = pd.MultiIndex.from_arrays(
        [products.values, countries.values], names=["product", "country"]
    )
    da_grouper = xr.DataArray(grouper, coords={"name": link_names}, dims="name")
    total_production = (production_vars * efficiencies).groupby(da_grouper).sum()

    # Build FLW-adjusted targets (same logic as add_animal_production_constraints)
    target_series = animal_baseline.set_index(["product", "country"])[
        "production_mt"
    ].astype(float)

    adjusted_targets = []
    for product, country in target_series.index:
        gross_value = target_series.loc[(product, country)]
        group = food_to_group.get(product, product)
        multiplier = flw_multipliers.get((country, group), 1.0)
        adjusted_targets.append(gross_value * multiplier)
    target_series = pd.Series(adjusted_targets, index=target_series.index)

    model_index = pd.Index(total_production.coords["group"].values, name="group")
    common_index = model_index.intersection(target_series.index)

    if common_index.empty:
        logger.warning("No matching animal production targets for stability bounds")
        return

    # Build bounds
    baselines = target_series.loc[common_index].values
    lower_bounds = np.maximum(0.0, (1.0 - delta) * baselines)
    upper_bounds = (1.0 + delta) * baselines

    rhs_lower = xr.DataArray(lower_bounds, coords={"group": common_index}, dims="group")
    rhs_upper = xr.DataArray(upper_bounds, coords={"group": common_index}, dims="group")

    # Handle zero baselines: force production to zero
    zero_mask = baselines == 0
    nonzero_mask = ~zero_mask

    if zero_mask.any():
        zero_index = common_index[zero_mask]
        lhs_zero = total_production.sel(group=zero_index)
        constr_name = "animal_production_zero"
        m.add_constraints(lhs_zero == 0, name=f"GlobalConstraint-{constr_name}")
        gc_names = [f"{constr_name}_{prod}_{country}" for prod, country in zero_index]
        n.global_constraints.add(
            gc_names,
            sense="==",
            constant=0.0,
            type="production_stability",
        )
        logger.info(
            "Added %d animal production constraints for zero-baseline (product, country) pairs",
            int(zero_mask.sum()),
        )

    if nonzero_mask.any():
        nonzero_index = common_index[nonzero_mask]
        lhs_nonzero = total_production.sel(group=nonzero_index)
        lower_nonzero = rhs_lower.sel(group=nonzero_index)
        upper_nonzero = rhs_upper.sel(group=nonzero_index)

        constr_name_min = "animal_production_min"
        constr_name_max = "animal_production_max"
        m.add_constraints(
            lhs_nonzero >= lower_nonzero, name=f"GlobalConstraint-{constr_name_min}"
        )
        m.add_constraints(
            lhs_nonzero <= upper_nonzero, name=f"GlobalConstraint-{constr_name_max}"
        )

        gc_names_min = [
            f"{constr_name_min}_{prod}_{country}"
            for prod, country in nonzero_index
        ]
        gc_names_max = [
            f"{constr_name_max}_{prod}_{country}"
            for prod, country in nonzero_index
        ]
        n.global_constraints.add(
            gc_names_min,
            sense=">=",
            constant=lower_nonzero.values,
            type="production_stability",
        )
        n.global_constraints.add(
            gc_names_max,
            sense="<=",
            constant=upper_nonzero.values,
            type="production_stability",
        )
        logger.info(
            "Added %d animal production stability constraints (delta=%.0f%%)",
            2 * int(nonzero_mask.sum()),
            delta * 100,
        )

    # Log missing baselines
    missing = model_index.difference(target_series.index)
    if len(missing) > 0:
        examples = [f"{p}/{c}" for p, c in list(missing)[:5]]
        logger.warning(
            "Missing animal baseline data for %d (product, country) pairs; examples: %s",
            len(missing),
            ", ".join(examples),
        )


def _get_consumption_link_map(
    p_names: pd.Index,
    links_df: pd.DataFrame,
    food_map: pd.DataFrame,
    cluster_lookup: dict[str, int],
    cluster_population: dict[int, float],
) -> pd.DataFrame:
    """Map consumption links to health clusters and risk factors."""
    # Filter for consumption links
    consume_mask = links_df.index.isin(p_names) & links_df.index.str.startswith(
        "consume_"
    )
    consume_links = links_df[consume_mask]

    if consume_links.empty:
        return pd.DataFrame()

    # Extract food and country from link attributes (set during model building)
    df = pd.DataFrame(
        {
            "link_name": consume_links.index,
            "food": consume_links["food"],
            "country": consume_links["country"],
        }
    )

    # Merge with food_map
    df = df.merge(food_map, left_on="food", right_index=True)

    # Map to cluster
    df["cluster"] = df["country"].map(cluster_lookup)
    df = df.dropna(subset=["cluster"])
    df["cluster"] = df["cluster"].astype(int)

    # Map to population
    df["population"] = df["cluster"].map(cluster_population)
    df = df[df["population"] > 0]

    # Calculate coefficient: share * grams per megatonne / (365 * population)
    df["coeff"] = (
        df["share"] * constants.GRAMS_PER_MEGATONNE / (365.0 * df["population"])
    )
    return df
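
# Coefficient check for _get_consumption_link_map (GRAMS_PER_MEGATONNE is
# 1e12 g/Mt): a link flow of 1 Mt/yr consumed by a cluster of 10 million
# people contributes
#
#   1 * 1e12 g / (365 days * 1e7 persons) ≈ 274 g/person/day
#
# to that cluster's intake of the mapped risk factor.
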
def add_health_objective(
    n: pypsa.Network,
    risk_breakpoints_path: str,
    cluster_cause_path: str,
    cause_log_path: str,
    cluster_summary_path: str,
    clusters_path: str,
    population_totals_path: str,
    risk_factors: list[str],
    risk_cause_map: dict[str, list[str]],
    solver_name: str,
    enforce_baseline: bool,
) -> None:
    """Link SOS2-based health costs to explicit PyPSA stores.

    The function builds the same intake→relative-risk logic as before, but
    instead of writing directly to the linopy objective it constrains the
    level of per-cluster, per-cause YLL stores. The monetary contribution
    is then handled by the store ``marginal_cost_storage`` configured
    during network construction.

    Health costs are formulated relative to TMREL (Theoretical Minimum
    Risk Exposure Level) intake, such that cost is zero when RR = RR_ref.
    Store energy levels represent (RR - RR_ref) * (YLL_base / RR_ref) and
    are non-negative since TMREL is the theoretical minimum risk level.
    """
    m = n.model

    risk_breakpoints = pd.read_csv(risk_breakpoints_path)
    cluster_cause = pd.read_csv(cluster_cause_path)
    cause_log_breakpoints = pd.read_csv(cause_log_path)
    cluster_summary = pd.read_csv(cluster_summary_path)
    if "health_cluster" in cluster_summary.columns:
        cluster_summary["health_cluster"] = cluster_summary["health_cluster"].astype(
            int
        )
    cluster_map = pd.read_csv(clusters_path)
    population_totals = pd.read_csv(population_totals_path)

    # Load food→risk factor mapping from food_groups.csv (only GBD risk factors)
    food_groups_df = pd.read_csv(snakemake.input.food_groups)
    food_map = food_groups_df[food_groups_df["group"].isin(risk_factors)].copy()
    food_map = food_map.rename(columns={"group": "risk_factor"})
    food_map["share"] = 1.0
    food_map = food_map.set_index("food")[["risk_factor", "share"]]

    cluster_lookup = cluster_map.set_index("country_iso3")["health_cluster"].to_dict()
    cluster_population_baseline = cluster_summary.set_index("health_cluster")[
        "population_persons"
    ].to_dict()
    cluster_cause_metadata = cluster_cause.set_index(["health_cluster", "cause"])

    population_totals = population_totals.dropna(subset=["iso3", "population"]).copy()
    population_totals["iso3"] = population_totals["iso3"].astype(str).str.upper()
    population_map = population_totals.set_index("iso3")["population"].to_dict()

    cluster_population_planning: dict[int, float] = defaultdict(float)
    for iso3, cluster in cluster_lookup.items():
        value = float(population_map.get(iso3, 0.0))
        if value <= 0:
            continue
        cluster_population_planning[int(cluster)] += value

    # Prefer planning-year population; fall back to baseline cluster population
    cluster_population: dict[int, float] = {}
    for cluster, baseline_pop in cluster_population_baseline.items():
        planning_pop = cluster_population_planning.get(int(cluster), 0.0)
        if planning_pop > 0:
            cluster_population[int(cluster)] = planning_pop
        else:
            cluster_population[int(cluster)] = float(baseline_pop)

    risk_breakpoints = risk_breakpoints.sort_values(
        ["risk_factor", "intake_g_per_day", "cause"]
    )
    cause_log_breakpoints = cause_log_breakpoints.sort_values(["cause", "log_rr_total"])

    logger.info(
        "Health data: %d risk breakpoints across %d risks / %d causes; %d cause breakpoints",
        len(risk_breakpoints),
        risk_breakpoints["risk_factor"].nunique(),
        risk_breakpoints["cause"].nunique(),
        len(cause_log_breakpoints),
    )

    # Filter risk_cause_map to only include risk factors in the breakpoints
    available_risks = set(risk_breakpoints["risk_factor"].unique())
    risk_cause_map = {
        r: causes for r, causes in risk_cause_map.items() if r in available_risks
    }

    # Restrict to configured risk-cause pairs to avoid silent zeros
    allowed_pairs = {(r, c) for r, causes in risk_cause_map.items() for c in causes}
    rb_pairs = set(zip(risk_breakpoints["risk_factor"], risk_breakpoints["cause"]))
    missing_pairs = sorted(allowed_pairs - rb_pairs)
    if missing_pairs:
        text = ", ".join([f"{r}:{c}" for r, c in missing_pairs])
        raise ValueError(f"Risk breakpoints missing required pairs: {text}")

    p = m.variables["Link-p"].sel(snapshot="now")

    if enforce_baseline:
        # Fixed dietary intakes: skip interpolation and pin store levels to the
        # baseline RR reported in the precomputed health tables.
        store_e = m.variables["Store-e"].sel(snapshot="now")
        health_stores = (
            n.stores.static[
                n.stores.static["carrier"].notna()
                & n.stores.static["carrier"].str.startswith("yll_")
            ]
            .reset_index()
            .set_index(["health_cluster", "cause"])
        )
        constraints_added = 0
        if "log_rr_total_baseline" not in cluster_cause.columns:
            raise ValueError("cluster_cause must contain log_rr_total_baseline")
        for (cluster, cause), row in cluster_cause_metadata.iterrows():
            rr_ref = math.exp(float(row["log_rr_total_ref"]))
            rr_base = math.exp(float(row["log_rr_total_baseline"]))
            yll_base = float(row["yll_base"])
            delta = max(0.0, rr_base - rr_ref)
            store_level = delta * (yll_base / rr_ref) * constants.YLL_TO_MILLION_YLL
            store_name = health_stores.loc[(int(cluster), str(cause)), "name"]
            m.add_constraints(
                store_e.sel(name=store_name) == store_level,
                name=f"health_store_fixed_c{cluster}_cause{cause}",
            )
            constraints_added += 1
        logger.info(
            "Health stores fixed to baseline intake: added %d equality constraints",
            constraints_added,
        )
        return

    # --- Stage 1: Intake to Log-RR ---
    # Vectorized link mapping
    link_map = _get_consumption_link_map(
        pd.Index(p.coords["name"].values),
        n.links.static,
        food_map,
        cluster_lookup,
        cluster_population,
    )
    if link_map.empty:
        logger.info("No consumption links map to health risk factors; skipping")
        return

    logger.info(
        "Health intake mapping: %d links -> %d cluster-risk pairs across %d clusters",
        len(link_map),
        len(link_map[["cluster", "risk_factor"]].drop_duplicates()),
        link_map["cluster"].nunique(),
    )

    risk_data = {}
    for risk, grp in risk_breakpoints.groupby("risk_factor"):
        intakes = pd.Index(sorted(grp["intake_g_per_day"].unique()), name="intake")
        if intakes.empty:
            continue
        pivot = (
            grp.pivot_table(
                index="intake_g_per_day",
                columns="cause",
                values="log_rr",
                aggfunc="first",
            )
            .reindex(intakes, axis=0)
            .sort_index()
        )
        intake_steps = pd.Index(range(len(intakes)), name="intake_step")
        pivot.index = intake_steps
        risk_data[risk] = {
            "intake_steps": intake_steps,
            "intake_values": xr.DataArray(
                intakes.values, coords={"intake_step": intake_steps}, dims="intake_step"
            ),
            "log_rr": pivot,
        }

    cause_breakpoint_data = {
        cause: df.sort_values("log_rr_total")
        for cause, df in cause_log_breakpoints.groupby("cause")
    }

    # Group (cluster, risk) pairs by intake coordinate patterns
    # Identify unique (cluster, risk) pairs from the map
    unique_pairs = link_map[["cluster", "risk_factor"]].drop_duplicates()
    intake_groups: dict[tuple[float, ...], list[tuple[int, str]]] = defaultdict(list)
    for _, row in unique_pairs.iterrows():
        cluster = int(row["cluster"])
        risk = row["risk_factor"]
        risk_table = risk_data.get(risk)
        if risk_table is None:
            continue
        coords_key = tuple(risk_table["intake_values"].values)
        intake_groups[coords_key].append((cluster, risk))

    log_rr_totals_dict = {}

    for coords_key, group_pairs in intake_groups.items():
        # All pairs in this group share the same intake grid, so take the
        # step index and intake values from the first pair's risk table.
        risk_table = risk_data[group_pairs[0][1]]
        intake_steps = risk_table["intake_steps"]

        # Identify group labels
        cluster_risk_labels = [f"c{cluster}_r{risk}" for cluster, risk in group_pairs]
        cluster_risk_index = pd.Index(cluster_risk_labels, name="cluster_risk")

        # Create lambdas (vectorized)
        risk_label = str(group_pairs[0][1])
        lambdas_group = m.add_variables(
            lower=0,
            upper=1,
            coords=[cluster_risk_index, intake_steps],
            name=f"health_lambda_group_{next(_LAMBDA_GROUP_COUNTER)}_{risk_label}",
        )

        # Register all variables
        _register_health_variable(m, lambdas_group.name)

        # Single SOS2 constraint call for entire group
        aux_names = _add_sos2_with_fallback(
            m, lambdas_group, sos_dim="intake_step", solver_name=solver_name
        )
        for aux_name in aux_names:
            _register_health_variable(m, aux_name)

        # Vectorized convexity constraints
        m.add_constraints(lambdas_group.sum("intake_step") == 1)

        # --- Intake Balance ---
        # LHS: Sum of link flows * coeffs
        # Filter link_map for this group
        group_map_df = pd.DataFrame(group_pairs, columns=["cluster", "risk_factor"])
        group_map_df["cluster_risk"] = cluster_risk_labels

        # Join link_map with group_map_df
        merged_links = link_map.merge(
            group_map_df, on=["cluster", "risk_factor"], how="inner"
        )

        if not merged_links.empty:
            # Create sparse aggregation
            relevant_links = merged_links["link_name"].values

            # Construct DataArrays with "name" coordinate to align with p
            coeffs = xr.DataArray(
                merged_links["coeff"].values,
                coords={"name": relevant_links},
                dims="name",
            )
            grouper = xr.DataArray(
                merged_links["cluster_risk"].values,
                coords={"name": relevant_links},
                dims="name",
            )

            # Groupby sum on DataArray of LinearExpressions (p) works in linopy
            flow_expr = (p.sel(name=relevant_links) * coeffs).groupby(grouper).sum()

            # RHS: Intake interpolation
            coeff_intake = risk_table["intake_values"]
            intake_expr = (lambdas_group * coeff_intake).sum("intake_step")

            # Add constraints vectorized
            m.add_constraints(
                flow_expr == intake_expr,
                name=f"health_intake_balance_group_{hash(coords_key)}",
            )

        # --- Log RR Calculation ---
        # Collect log_rr matrices
        log_rr_frames = []
        for _cluster, risk in group_pairs:
            df = risk_data[risk]["log_rr"]  # index=intake, cols=causes
            log_rr_frames.append(df)

        if not log_rr_frames:
            continue

        # Concat along cluster_risk dimension
        combined_log_rr = pd.concat(
            log_rr_frames,
            keys=cluster_risk_index,
            names=["cluster_risk", "intake_step"],
        ).fillna(0.0)

        # Convert to DataArray: (cluster_risk, intake_step, cause)
        # Use stack() to flatten columns (cause) into index
        s_log = combined_log_rr.stack()
        s_log.index.names = ["cluster_risk", "intake_step", "cause"]
        da_log = xr.DataArray.from_series(s_log).fillna(0.0)

        # Calculate contribution: sum(lambda * log_rr) over intake
        # lambdas_group: (cluster_risk, intake_step)
        # da_log: (cluster_risk, intake_step, cause)
        # Result: (cause, cluster_risk) of LinearExpressions
        contrib = (lambdas_group * da_log).sum("intake_step")

        # Accumulate into totals by grouping cluster_risk -> cluster
        c_map = group_map_df.set_index("cluster_risk")["cluster"]

        # Ensure we only map coords present in contrib
        present_cr = contrib.coords["cluster_risk"].values
        cluster_grouper = xr.DataArray(
            c_map.loc[present_cr].values,
            coords={"cluster_risk": present_cr},
            dims="cluster_risk",
            name="cluster",
        )

        # Group sum over cluster_risk -> yields (cause, cluster)
        group_total = contrib.groupby(cluster_grouper).sum()

        # Accumulate into dictionary for Stage 2
        # group_total is a LinearExpression with dims (cause, cluster)
        # We iterate over coordinates to extract scalar expressions
        causes = group_total.coords["cause"].values
        clusters = group_total.coords["cluster"].values

        for c in clusters:
            for cause in causes:
                # Extract scalar expression for this (cluster, cause)
                expr = group_total.sel(cluster=c, cause=cause)
                key = (c, cause)
                if key in log_rr_totals_dict:
                    log_rr_totals_dict[key] = log_rr_totals_dict[key] + expr
                else:
                    log_rr_totals_dict[key] = expr

    # Group (cluster, cause) pairs by their log_total coordinate patterns
    # so we can reuse a single SOS2 variable for all pairs that share the same
    # breakpoint grid (one grid per cause).
    log_total_groups: dict[tuple[float, ...], list[tuple[int, str]]] = defaultdict(list)
    cluster_cause_data: dict[tuple[int, str], dict] = {}

    for (cluster, cause), row in cluster_cause_metadata.iterrows():
        cluster = int(cluster)
        cause = str(cause)
        yll_base = float(row["yll_base"])
        cause_bp = cause_breakpoint_data[cause]
        coords_key = tuple(cause_bp["log_rr_total"].values)
        if len(coords_key) == 1:
            raise ValueError(
                "Need at least two breakpoints for piecewise linear approximation"
            )
        log_total_groups[coords_key].append((cluster, cause))

        # Store data for later use
        log_rr_total_ref = float(row["log_rr_total_ref"])
        cluster_cause_data[(cluster, cause)] = {
            "yll_base": yll_base,
            "log_rr_total_ref": log_rr_total_ref,
            "rr_ref": math.exp(log_rr_total_ref),
            "cause_bp": cause_bp,
        }

    logger.info(
        "Health risk aggregation: %d (cluster, cause) pairs grouped into %d log-RR grids",
        len(cluster_cause_data),
        len(log_total_groups),
    )

    store_e = m.variables["Store-e"].sel(snapshot="now")
    health_stores = (
        n.stores.static[
            n.stores.static["carrier"].notna()
            & n.stores.static["carrier"].str.startswith("yll_")
        ]
        .reset_index()
        .set_index(["health_cluster", "cause"])
    )

    constraints_added = 0

    # Process each group with vectorized operations. The outer loop handles one
    # shared grid (coords_key) at a time; the inner loop attaches each
    # (cluster, cause) pair using that grid to its own balance and store.
    for coords_key, cluster_cause_pairs in log_total_groups.items():
        log_total_vals = np.asarray(coords_key, dtype=float)
        log_total_steps = pd.Index(range(len(log_total_vals)), name="log_total_step")

        # Create flattened index for this group
        cluster_cause_labels = [
            f"c{cluster}_cause{cause}" for cluster, cause in cluster_cause_pairs
        ]
        cluster_cause_index = pd.Index(cluster_cause_labels, name="cluster_cause")

        # Single vectorized variable creation
        cause_label = str(cluster_cause_pairs[0][1])
        lambda_total_group = m.add_variables(
            lower=0,
            upper=1,
            coords=[cluster_cause_index, log_total_steps],
            name=f"health_lambda_total_group_{next(_TOTAL_GROUP_COUNTER)}_{cause_label}",
        )

        # Register all variables
        _register_health_variable(m, lambda_total_group.name)

        # Single SOS2 constraint call for entire group
        aux_names = _add_sos2_with_fallback(
            m, lambda_total_group, sos_dim="log_total_step", solver_name=solver_name
        )
        for aux_name in aux_names:
            _register_health_variable(m, aux_name)

        # Vectorized convexity constraints
        m.add_constraints(lambda_total_group.sum("log_total_step") == 1)

        # Process each (cluster, cause) for balance constraints and objective
        coeff_log_total = xr.DataArray(
            log_total_vals,
            coords={"log_total_step": log_total_steps},
            dims=["log_total_step"],
        )

        for (cluster, cause), label in zip(cluster_cause_pairs, cluster_cause_labels):
            data = cluster_cause_data[(cluster, cause)]
            lambda_total = lambda_total_group.sel(cluster_cause=label)
            total_expr = log_rr_totals_dict[(cluster, cause)]
            cause_bp = data["cause_bp"]

            log_interp = m.linexpr((coeff_log_total, lambda_total)).sum(
                "log_total_step"
            )
            coeff_rr = xr.DataArray(
                cause_bp["rr_total"].values,
                coords={"log_total_step": log_total_steps},
                dims=["log_total_step"],
            )
            rr_interp = m.linexpr((coeff_rr, lambda_total)).sum("log_total_step")

            m.add_constraints(
                log_interp == total_expr,
                name=f"health_total_balance_c{cluster}_cause{cause}",
            )

            store_name = health_stores.loc[(cluster, cause), "name"]
            if data["yll_base"] <= 0:
                logger.warning(
                    "Health store has non-positive yll_base (cluster=%d, cause=%s); "
                    "constraint will be non-binding",
                    cluster,
                    cause,
                )

            # Health cost is zero at TMREL (where RR = RR_ref) and increases with
            # deviation from optimal intake. Since TMREL is the theoretical minimum
            # risk level, RR >= RR_ref always, so store levels are non-negative.
            yll_expr_myll = (
                (rr_interp - data["rr_ref"])
                * (data["yll_base"] / data["rr_ref"])
                * constants.YLL_TO_MILLION_YLL
            )
            m.add_constraints(
                store_e.sel(name=store_name) >= yll_expr_myll,
                name=f"health_store_level_c{cluster}_cause{cause}",
            )
            constraints_added += 1

    logger.info("Added %d health store level constraints", constraints_added)
if __name__ == "__main__":
    # Configure logging to write to Snakemake log file
    logger = setup_script_logging(log_file=snakemake.log[0] if snakemake.log else None)

    # Apply scenario config overrides based on wildcard
    apply_scenario_config(snakemake.config, snakemake.wildcards.scenario)

    n = pypsa.Network(snakemake.input.network)

    # Add GHG pricing to the objective if enabled
    if snakemake.config["emissions"]["ghg_pricing_enabled"]:
        ghg_price = float(snakemake.params.ghg_price)
        add_ghg_pricing_to_objective(n, ghg_price)

    # Add food-group incentives to the objective if enabled
    if snakemake.config["food_group_incentives"]["enabled"]:
        incentives_paths = list(snakemake.input.food_group_incentives)
        add_food_group_incentives_to_objective(n, incentives_paths)

    # Create the linopy model
    logger.info("Creating linopy model...")
    n.optimize.create_model()
    logger.info("Linopy model created.")

    solver_name = snakemake.params.solver
    solver_threads = snakemake.params.solver_threads
    solver_options = _apply_solver_threads_option(
        dict(snakemake.params.solver_options),
        solver_name,
        solver_threads,
    )
    io_api = snakemake.params.io_api
    netcdf_compression = snakemake.params.netcdf_compression

    # Configure Gurobi to write detailed logs to the same file
    if solver_name.lower() == "gurobi" and snakemake.log:
        if "LogFile" not in solver_options:
            solver_options["LogFile"] = snakemake.log[0]
        if "LogToConsole" not in solver_options:
            solver_options["LogToConsole"] = 1  # Also print to console

    # Add macronutrient intake bounds
    population_df = pd.read_csv(snakemake.input.population)
    population_df["iso3"] = population_df["iso3"].astype(str).str.upper()
    population_map = (
        population_df.set_index("iso3")["population"].astype(float).to_dict()
    )

    # Food group baseline equals (optional)
    per_country_equal: dict[str, dict[str, float]] | None = None
    equal_source = snakemake.config["food_groups"]["equal_by_country_source"]
    if bool(snakemake.params.enforce_baseline) and equal_source:
        raise ValueError(
            "Cannot combine enforce_gdd_baseline with food_groups.equal_by_country_source"
        )
    if bool(snakemake.params.enforce_baseline):
        baseline_df = pd.read_csv(snakemake.input.baseline_diet)
        per_country_equal = _build_food_group_equals_from_baseline(
            baseline_df,
            list(population_map.keys()),
            pd.read_csv(snakemake.input.food_groups)["group"].unique(),
            baseline_age=str(snakemake.params.diet["baseline_age"]),
            reference_year=int(snakemake.params.diet["baseline_reference_year"]),
        )
    elif equal_source:
        equal_df = pd.read_csv(snakemake.input.food_group_equal)
        required = {"group", "country", "consumption_g_per_day"}
        missing = required - set(equal_df.columns)
        if missing:
            missing_text = ", ".join(sorted(missing))
            raise ValueError(
                f"Missing required columns in food group equality file: {missing_text}"
            )
        equal_df["country"] = equal_df["country"].astype(str).str.upper()
        per_country_equal = {}
        all_countries = set(population_map.keys())
        for group, group_df in equal_df.groupby("group"):
            values = dict.fromkeys(all_countries, 0.0)
            for _, row in group_df.iterrows():
                country = str(row["country"]).upper()
                if country not in values:
                    logger.warning(
                        "Unknown country '%s' in food group equality file", country
                    )
                    continue
                values[country] = float(row["consumption_g_per_day"])
            missing_countries = sorted(all_countries - set(group_df["country"]))
            if missing_countries:
                preview = ", ".join(missing_countries[:5])
                logger.warning(
                    "Food group '%s' missing %d countries in equality file; "
                    "setting them to 0 (examples: %s)",
                    group,
                    len(missing_countries),
                    preview,
                )
            per_country_equal[str(group)] = values

    add_macronutrient_constraints(n, snakemake.params.macronutrients, population_map)
    add_food_group_constraints(
        n,
        snakemake.params.food_group_constraints,
        population_map,
        per_country_equal,
    )

    # Add residue feed limit constraints
    max_feed_fraction = float(snakemake.config["residues"]["max_feed_fraction"])
    max_feed_fraction_by_country = build_residue_feed_fraction_by_country(
        snakemake.config, snakemake.input.m49
    )
    add_residue_feed_constraints(n, max_feed_fraction, max_feed_fraction_by_country)

    # Add animal production constraints in validation mode
    use_actual_production = bool(
        snakemake.config["validation"]["use_actual_production"]
    )
    if use_actual_production:
        fao_animal_production = pd.read_csv(snakemake.input.animal_production)
        food_groups_df = pd.read_csv(snakemake.input.food_groups)
        food_to_group = food_groups_df.set_index("food")["group"].to_dict()
        food_loss_waste = pd.read_csv(snakemake.input.food_loss_waste)
        add_animal_production_constraints(
            n, fao_animal_production, food_to_group, food_loss_waste
        )

    # Add production stability constraints
    stability_cfg = snakemake.params.production_stability
    if stability_cfg["enabled"]:
        # Load food_to_group if not already loaded
        if "food_to_group" not in dir():
            food_groups_df = pd.read_csv(snakemake.input.food_groups)
            food_to_group = food_groups_df.set_index("food")["group"].to_dict()

        crop_baseline = None
        crop_to_fao_item: dict[str, str] = {}
        animal_baseline = None
        food_loss_waste_df = pd.DataFrame()

        if stability_cfg["crops"]["enabled"]:
            crop_baseline = pd.read_csv(snakemake.input.crop_production_baseline)
            # Load FAO item mapping to aggregate crops sharing an FAO item
            fao_map_df = pd.read_csv(snakemake.input.faostat_item_map)
            crop_to_fao_item = dict(
                zip(
                    fao_map_df["crop"].astype(str),
                    fao_map_df["faostat_item"].astype(str),
                )
            )
        if stability_cfg["animals"]["enabled"]:
            animal_baseline = pd.read_csv(snakemake.input.animal_production_baseline)
            food_loss_waste_df = pd.read_csv(snakemake.input.food_loss_waste)

        add_production_stability_constraints(
            n,
            crop_baseline,
            crop_to_fao_item,
            animal_baseline,
            stability_cfg,
            food_to_group,
            food_loss_waste_df,
        )

    # Add health impacts / store levels if enabled or baseline intake is enforced
    health_enabled = bool(snakemake.config["health"]["enabled"])
    enforce_baseline = bool(snakemake.params.enforce_baseline)
    if health_enabled and enforce_baseline:
        raise ValueError(
            "health.enabled and validation.enforce_gdd_baseline cannot both be true; "
            "disable one of them in the config."
        )
    if health_enabled or enforce_baseline:
        add_health_objective(
            n,
            snakemake.input.health_risk_breakpoints,
            snakemake.input.health_cluster_cause,
            snakemake.input.health_cause_log,
            snakemake.input.health_cluster_summary,
            snakemake.input.health_clusters,
            snakemake.input.population,
            snakemake.params.health_risk_factors,
            snakemake.params.health_risk_cause_map,
            solver_name,
            enforce_baseline,
        )

    status, condition = n.model.solve(
        solver_name=solver_name,
        io_api=io_api,
        calculate_fixed_duals=snakemake.params.calculate_fixed_duals,
        **solver_options,
    )
    result = (status, condition)

    # Temporary debug export of the raw solved linopy model
    # linopy_debug_path = Path(snakemake.output.network).with_name("linopy_model.nc")
    # linopy_debug_path.parent.mkdir(parents=True, exist_ok=True)
    # n.model.to_netcdf(linopy_debug_path)
    # logger.info("Wrote linopy model snapshot to %s", linopy_debug_path)

    if status == "ok":
        # Temporarily remove auxiliary health variables (e.g., SOS2 fallback
        # binaries) so PyPSA's solution assignment only sees known components,
        # then restore them afterwards.
        aux_names = HEALTH_AUX_MAP.pop(id(n.model), set())
        variables_container = n.model.variables
        removed = {}
        for name in aux_names:
            if name in variables_container.data:
                removed[name] = variables_container.data.pop(name)
        try:
            n.optimize.assign_solution()
            n.optimize.assign_duals(False)
            n.optimize.post_processing()
        finally:
            if removed:
                variables_container.data.update(removed)
        n.export_to_netcdf(
            snakemake.output.network,
            compression=netcdf_compression,
        )
    elif condition in {"infeasible", "infeasible_or_unbounded"}:
        logger.error("Model is infeasible or unbounded!")
        if solver_name.lower() == "gurobi":
            try:
                logger.error("Computing IIS (Irreducible Inconsistent Subsystem)...")
                # Get infeasible constraint labels
                infeasible_labels = n.model.compute_infeasibilities()
                if not infeasible_labels:
                    logger.error("No infeasible constraints found in IIS")
                else:
                    logger.error(
                        "Found %d infeasible constraints:", len(infeasible_labels)
                    )
                    constraint_details = []
                    for label in infeasible_labels:
                        try:
                            detail = print_single_constraint(n.model, label)
                            constraint_details.append(detail)
                        except Exception as e:
                            constraint_details.append(
                                f"Label {label}: <error formatting: {e}>"
                            )
                    # Log all infeasible constraints
                    iis_output = "\n".join(constraint_details)
                    logger.error("IIS constraints:\n%s", iis_output)
            except Exception as exc:
                logger.error("Could not compute infeasibilities: %s", exc)
        else:
            logger.error("Infeasibility diagnosis only available with Gurobi solver")
    else:
        logger.error("Optimization unsuccessful: %s", result)