# Source code for workflow.scripts.plotting.plot_food_consumption

# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Plot global food consumption by food group per person per day."""

import logging
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import pypsa

matplotlib.use("pdf")
import matplotlib.pyplot as plt

from workflow.scripts.constants import DAYS_PER_YEAR, GRAMS_PER_MEGATONNE, PJ_TO_KCAL
from workflow.scripts.plotting.color_utils import categorical_colors

# Alias for backwards compatibility with modules that import from here
KCAL_PER_PJ = PJ_TO_KCAL

logger = logging.getLogger(__name__)


def _select_snapshot(network: pypsa.Network) -> pd.Index | str:
    """Pick the snapshot label used for all aggregations.

    The label ``"now"`` wins whenever the network has it; otherwise the only
    snapshot of a single-snapshot network is returned.

    Raises:
        ValueError: If ``"now"`` is absent and the network does not hold
            exactly one snapshot.
    """
    snapshots = network.snapshots
    has_now = "now" in snapshots
    if not has_now and len(snapshots) != 1:
        raise ValueError("Expected snapshot 'now' or single snapshot in solved network")
    return "now" if has_now else snapshots[0]


def _bus_column_to_leg(column: str) -> int | None:
    if not column.startswith("bus"):
        return None
    suffix = column[len("bus") :]
    if not suffix:
        return 0
    if suffix.isdigit():
        return int(suffix)
    return None


def _link_dispatch_at_snapshot(
    network: pypsa.Network, snapshot
) -> dict[int, pd.Series]:
    """Collect the row at ``snapshot`` from every ``p<digits>`` table of ``network.links_t``.

    Returns a mapping from the integer after the ``p`` to that row. Tables
    whose index does not contain ``snapshot`` are left out, as are attributes
    such as ``p_set`` whose suffix is not purely digits.
    """
    links_t = network.links_t
    rows_by_leg: dict[int, pd.Series] = {}
    for name in dir(links_t):
        leg_text = name[1:]
        if not (name.startswith("p") and leg_text.isdigit()):
            continue
        table = getattr(links_t, name)
        if snapshot in table.index:
            rows_by_leg[int(leg_text)] = table.loc[snapshot]
    return rows_by_leg


def _aggregate_group_mass(
    network: pypsa.Network, snapshot, food_to_group: dict[str, str]
) -> pd.Series:
    """Sum, per food group, the flow that ``consume_*`` links send to a ``group_*`` bus.

    For every link whose name starts with ``consume_`` the ``food`` column is
    mapped through ``food_to_group``; links with an unmapped food are skipped.
    The first leg with a non-zero, finite flow whose bus name starts with
    ``group_`` contributes the absolute value of that flow to the group total.

    Returns an empty float Series when there are no consume links or no
    dispatch data at ``snapshot``.
    """
    all_links = network.links
    is_consume = all_links.index.str.startswith("consume_")
    consume_links = all_links[is_consume]
    if consume_links.empty:
        return pd.Series(dtype=float)

    flows_by_leg = _link_dispatch_at_snapshot(network, snapshot)
    if not flows_by_leg:
        return pd.Series(dtype=float)

    totals: dict[str, float] = {}
    for link_name in consume_links.index:
        group = food_to_group.get(str(consume_links.at[link_name, "food"]))
        if group is None:
            continue

        for leg, flows in flows_by_leg.items():
            flow = float(flows.get(link_name, 0.0))
            if flow == 0.0 or not np.isfinite(flow):
                continue
            # Only the leg that feeds a group bus counts towards the total.
            target_bus = consume_links.at[link_name, f"bus{leg}"]
            if isinstance(target_bus, str) and target_bus.startswith("group_"):
                totals[group] = totals.get(group, 0.0) + abs(flow)
                break

    return pd.Series(totals, dtype=float)


def _available_legs(links: pd.DataFrame) -> list[int]:
    """Return the sorted leg numbers of all ``bus<N>`` columns other than ``bus0``.

    Columns named exactly ``"bus"`` or whose suffix cannot be parsed by
    ``int`` are ignored.
    """
    found: set[int] = set()
    for name in links.columns:
        if not name.startswith("bus") or name == "bus0":
            continue
        tail = name[len("bus") :]
        if tail == "":
            continue
        try:
            leg_number = int(tail)
        except ValueError:
            continue
        found.add(leg_number)
    return sorted(found)


def _aggregate_group_calories(
    network: pypsa.Network, snapshot, food_to_group: dict[str, str]
) -> pd.Series:
    """Sum, per food group, the flow that ``consume_*`` links send to a ``cal_*`` bus.

    Only legs reported by ``_available_legs`` (i.e. not ``bus0``) are
    considered. For each consume link the first leg whose bus name starts
    with ``cal_`` is used; the absolute flow on that leg at ``snapshot`` is
    added to the total of the link's food group when it is positive.

    Returns an empty float Series when there are no extra legs or no
    matching time series at ``snapshot``.
    """

    def _is_calorie_bus(bus_value) -> bool:
        return pd.notna(bus_value) and str(bus_value).startswith("cal_")

    links = network.links
    legs = _available_legs(links)
    if not legs:
        return pd.Series(dtype=float)

    flows_by_leg: dict[int, pd.Series] = {}
    for leg in legs:
        table = getattr(network.links_t, f"p{leg}", None)
        if table is not None and snapshot in table.index:
            flows_by_leg[leg] = table.loc[snapshot]

    if not flows_by_leg:
        return pd.Series(dtype=float)

    totals: dict[str, float] = {}
    for link_name in links.index:
        if not link_name.startswith("consume_"):
            continue

        group_name = food_to_group.get(str(links.at[link_name, "food"]))
        if group_name is None:
            continue

        # First leg (in ascending order) that points at a calorie bus.
        kcal_leg = next(
            (
                leg
                for leg in legs
                if _is_calorie_bus(links.at[link_name, f"bus{leg}"])
            ),
            None,
        )
        if kcal_leg is None or kcal_leg not in flows_by_leg:
            continue

        amount = abs(float(flows_by_leg[kcal_leg].get(link_name, 0.0)))
        if amount > 0.0:
            totals[group_name] = totals.get(group_name, 0.0) + amount

    return pd.Series(totals, dtype=float)


def _assign_colors(
    groups: list[str], overrides: dict[str, str] | None = None
) -> dict[str, str]:
    """Return a color per food group by delegating to ``categorical_colors``.

    ``groups`` and ``overrides`` are forwarded unchanged as the two
    positional arguments.
    """
    palette = categorical_colors(groups, overrides)
    return palette


def _stack_group_bars(
    ax,
    values: pd.Series,
    ordered_groups: list[str],
    colors: dict[str, str],
    legend_handles: dict,
) -> None:
    """Draw one stacked bar at x=0 on ``ax`` and record a legend handle per group."""
    bottom = 0.0
    for group in ordered_groups:
        value = float(values.get(group, 0.0))
        if value <= 0.0:
            continue
        container = ax.bar(0, value, bottom=bottom, color=colors[group], label=group)
        # The first artist drawn for a group becomes its legend handle, so a
        # group that appears in only one subplot still gets a legend entry.
        legend_handles.setdefault(group, container)
        bottom += value


def _plot(
    mass_g_per_person_day: pd.Series,
    calories_kcal_per_person_day: pd.Series,
    output_pdf: Path,
    group_colors: dict[str, str] | None = None,
) -> None:
    """Write a two-panel stacked-bar figure (mass and calories) to ``output_pdf``.

    Only strictly positive entries of either series are drawn. Groups are
    stacked in descending order of mass, followed by groups that only have
    calorie data (in descending order of calories). When neither series has a
    positive entry, a placeholder figure with a "no data" message is saved.

    Args:
        mass_g_per_person_day: Values per food group for the left panel.
        calories_kcal_per_person_day: Values per food group for the right panel.
        output_pdf: Destination path handed to ``Figure.savefig``.
        group_colors: Optional color overrides per group. When ``None``, the
            ``group_colors`` entry of ``snakemake.params`` is used if the
            ``snakemake`` global exists; otherwise no overrides are applied.
    """
    mass_g_per_person_day = mass_g_per_person_day[mass_g_per_person_day > 0]
    calories_kcal_per_person_day = calories_kcal_per_person_day[
        calories_kcal_per_person_day > 0
    ]

    ordered_groups: list[str] = (
        mass_g_per_person_day.sort_values(ascending=False).index.tolist()
    )
    for group in calories_kcal_per_person_day.sort_values(ascending=False).index:
        if group not in ordered_groups:
            ordered_groups.append(group)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    if not ordered_groups:
        for ax in axes:
            ax.text(
                0.5, 0.5, "No food group consumption data", ha="center", va="center"
            )
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
        plt.close(fig)
        return

    if group_colors is None:
        try:
            group_colors = getattr(snakemake.params, "group_colors", None)  # type: ignore[name-defined]
        except NameError:
            # Not running under Snakemake: use the default palette only.
            group_colors = None
    colors = _assign_colors(ordered_groups, group_colors or {})

    # Handles are gathered from BOTH panels; relying on the mass panel alone
    # would drop groups that only have calorie data from the legend.
    legend_handles: dict = {}

    # Mass subplot
    ax_mass = axes[0]
    _stack_group_bars(
        ax_mass, mass_g_per_person_day, ordered_groups, colors, legend_handles
    )
    ax_mass.set_xticks([0])
    ax_mass.set_xticklabels(["Mass"])
    ax_mass.set_ylabel("g/person/day")
    ax_mass.set_title("Global Food Consumption (Mass)")
    ax_mass.grid(axis="y", alpha=0.3)

    # Calories subplot
    ax_cal = axes[1]
    _stack_group_bars(
        ax_cal, calories_kcal_per_person_day, ordered_groups, colors, legend_handles
    )
    ax_cal.set_xticks([0])
    ax_cal.set_xticklabels(["Calories"])
    ax_cal.set_ylabel("kcal/person/day")
    ax_cal.set_title("Global Food Consumption (Calories)")
    ax_cal.grid(axis="y", alpha=0.3)

    # Reverse the stacking order so the legend reads top-to-bottom like the bar.
    legend_groups = [g for g in reversed(ordered_groups) if g in legend_handles]
    if legend_groups:
        fig.legend(
            [legend_handles[g] for g in legend_groups],
            legend_groups,
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
        )
        fig.tight_layout(rect=(0, 0, 0.85, 1))
    else:
        fig.tight_layout()

    fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)


def main() -> None:
    """Entry point: build the food consumption plot for the current Snakemake job.

    Reads the solved network, the population table and the food-to-group
    mapping from ``snakemake.input``, aggregates consumption per food group,
    converts the totals to per-person-per-day values and writes the figure to
    ``snakemake.output.pdf``.

    Raises:
        RuntimeError: If the ``snakemake`` global is not defined.
        ValueError: If the food-group file lacks a ``food`` or ``group``
            column, the population file lacks a ``population`` column, or
            the summed population is not positive.
    """
    try:
        snakemake  # type: ignore[name-defined]
    except NameError as exc:  # pragma: no cover - Snakemake injects the variable
        raise RuntimeError("This script must be run from Snakemake") from exc

    network_path = snakemake.input.network  # type: ignore[attr-defined]
    population_path = snakemake.input.population  # type: ignore[attr-defined]
    food_groups_path = snakemake.input.food_groups  # type: ignore[attr-defined]
    output_pdf = Path(snakemake.output.pdf)  # type: ignore[attr-defined]
    output_pdf.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading solved network from %s", network_path)
    network = pypsa.Network(network_path)

    # Load food->group mapping; validate the columns up front so a malformed
    # file gives a clear message instead of a bare KeyError.
    food_groups_df = pd.read_csv(food_groups_path)
    missing_columns = {"food", "group"} - set(food_groups_df.columns)
    if missing_columns:
        raise ValueError(
            "Food groups file must contain 'food' and 'group' columns; "
            f"missing: {sorted(missing_columns)}"
        )
    food_to_group = food_groups_df.set_index("food")["group"].to_dict()

    snapshot = _select_snapshot(network)
    logger.info("Using snapshot '%s' for aggregation", snapshot)

    mass = _aggregate_group_mass(network, snapshot, food_to_group)
    calories_pj = _aggregate_group_calories(network, snapshot, food_to_group)

    population_df = pd.read_csv(population_path)
    if "population" not in population_df.columns:
        raise ValueError("Population file must contain a 'population' column")

    total_population = float(population_df["population"].sum())
    if total_population <= 0.0:
        raise ValueError("Total population must be positive for per-capita conversion")

    # NOTE(review): the constant names suggest the aggregated flows are annual
    # totals in megatonnes and petajoules - confirm against the model units.
    person_days = total_population * DAYS_PER_YEAR
    mass_per_capita = mass * GRAMS_PER_MEGATONNE / person_days
    calories_per_capita = calories_pj * KCAL_PER_PJ / person_days

    logger.info(
        "Found %d food groups with mass data and %d with calorie data",
        mass_per_capita[mass_per_capita > 0].shape[0],
        calories_per_capita[calories_per_capita > 0].shape[0],
    )

    _plot(mass_per_capita, calories_per_capita, output_pdf)
    logger.info("Food consumption plot saved to %s", output_pdf)


if __name__ == "__main__":
    main()