# Source code for workflow.scripts.plotting.plot_food_consumption

# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Plot global food consumption by food group per person per day."""

import logging
from pathlib import Path
from typing import Dict

import matplotlib
import numpy as np
import pandas as pd
import pypsa

matplotlib.use("pdf")
import matplotlib.pyplot as plt

try:  # Prefer package import when available (e.g., during documentation builds)
    from workflow.scripts.color_utils import categorical_colors
except ImportError:  # Fallback to Snakemake's script-directory loader
    from color_utils import categorical_colors  # type: ignore


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Unit conversions used by main() to turn network totals into
# per-person-per-day figures.
# 1 Mt = 1e6 t = 1e12 g.
GRAMS_PER_MEGATONNE = 1e12
# NOTE(review): a metric megacalorie is 1e3 kcal, so the value 1e6 implies the
# network's calorie unit is presumably "million kcal" -- confirm against the
# model's unit conventions before relying on the name.
KCAL_PER_MCAL = 1e6
# main() divides totals by population * DAYS_PER_YEAR, i.e. it treats the
# network totals as annual quantities (leap days are ignored).
DAYS_PER_YEAR = 365


def _select_snapshot(network: pypsa.Network) -> pd.Index | str:
    if "now" in network.snapshots:
        return "now"
    if len(network.snapshots) == 1:
        return network.snapshots[0]
    raise ValueError("Expected snapshot 'now' or single snapshot in solved network")


def _group_from_bus(bus: str) -> str:
    remainder = bus[len("group_") :]
    if "_" in remainder:
        return remainder.rsplit("_", 1)[0]
    return remainder


def _bus_column_to_leg(column: str) -> int | None:
    if not column.startswith("bus"):
        return None
    suffix = column[len("bus") :]
    if not suffix:
        return 0
    if suffix.isdigit():
        return int(suffix)
    return None


def _link_dispatch_at_snapshot(
    network: pypsa.Network, snapshot
) -> dict[int, pd.Series]:
    dispatch: dict[int, pd.Series] = {}
    for attr in dir(network.links_t):
        if not attr.startswith("p"):
            continue
        suffix = attr[1:]
        if not suffix.isdigit():
            continue
        series = getattr(network.links_t, attr)
        if snapshot not in series.index:
            continue
        dispatch[int(suffix)] = series.loc[snapshot]
    return dispatch


def _aggregate_group_mass(network: pypsa.Network, snapshot) -> pd.Series:
    """Total the dispatch of ``consume_`` links per food group at ``snapshot``.

    For each link whose name starts with ``consume_``, every ``bus*`` column
    holding a bus name that starts with ``group_`` contributes the absolute
    dispatch on that column's leg to the group derived from the bus name.
    Zero and non-finite dispatch values are ignored.

    Returns:
        Series indexed by group name holding the summed absolute dispatch in
        the network's own units. It is empty when there are no links, no
        ``consume_`` links, no bus columns, or no dispatch at ``snapshot``.
    """
    all_links = network.links
    if all_links.empty:
        return pd.Series(dtype=float)

    consume_mask = all_links.index.str.startswith("consume_")
    consumption = all_links[consume_mask]
    if consumption.empty:
        return pd.Series(dtype=float)

    bus_columns = [name for name in consumption.columns if name.startswith("bus")]
    if not bus_columns:
        return pd.Series(dtype=float)

    flows_by_leg = _link_dispatch_at_snapshot(network, snapshot)
    if not flows_by_leg:
        return pd.Series(dtype=float)

    group_totals: Dict[str, float] = {}
    for link_name, buses in consumption[bus_columns].iterrows():
        for column, bus_name in buses.items():
            # Unused legs hold non-string placeholders (e.g. NaN); skip them
            # together with buses that are not food-group buses.
            is_group_bus = isinstance(bus_name, str) and bus_name.startswith(
                "group_"
            )
            if not is_group_bus:
                continue

            leg = _bus_column_to_leg(column)
            flows = None if leg is None else flows_by_leg.get(leg)
            if flows is None:
                continue

            amount = float(flows.get(link_name, 0.0))
            if amount != 0.0 and np.isfinite(amount):
                group = _group_from_bus(bus_name)
                # abs(): the flow counts regardless of its sign.
                group_totals[group] = group_totals.get(group, 0.0) + abs(amount)

    return pd.Series(group_totals, dtype=float)


def _available_legs(links: pd.DataFrame) -> list[int]:
    legs: set[int] = set()
    for column in links.columns:
        if not column.startswith("bus"):
            continue
        if column == "bus0":
            continue
        suffix = column[3:]
        if not suffix:
            continue
        try:
            legs.add(int(suffix))
        except ValueError:
            continue
    return sorted(legs)


def _aggregate_group_calories(network: pypsa.Network, snapshot) -> pd.Series:
    """Sum the calorie-leg dispatch of ``consume_`` links per food group.

    For every link whose name starts with ``consume_``, the legs other than
    ``bus0`` are scanned for a bus starting with ``group_`` (which names the
    food group) and a bus starting with ``kcal_`` (whose leg carries the
    value that is summed). Links lacking either bus are skipped. If a link
    has several matching buses, the one on the highest leg wins.

    Args:
        network: Solved network providing ``links`` and ``links_t``.
        snapshot: Index label selecting the row of each ``links_t.p<leg>``
            table.

    Returns:
        Series indexed by group name holding the summed absolute dispatch on
        the ``kcal_`` leg, in the network's own units. It is empty when there
        are no links, no usable legs, or no dispatch at ``snapshot``.
    """
    links = network.links
    if links.empty:
        return pd.Series(dtype=float)

    legs = _available_legs(links)
    if not legs:
        return pd.Series(dtype=float)

    # Dispatch row at the chosen snapshot, keyed by leg. Legs without a
    # time-series table, or whose table lacks the snapshot, are left out.
    dispatch_by_leg: dict[int, pd.Series] = {}
    for leg in legs:
        table = getattr(network.links_t, f"p{leg}", None)
        if table is None or snapshot not in table.index:
            continue
        dispatch_by_leg[leg] = table.loc[snapshot]

    if not dispatch_by_leg:
        return pd.Series(dtype=float)

    # Resolve leg -> column once instead of once per link. The membership
    # test matters for names like "bus01", which parse to leg 1 even though
    # no "bus1" column exists.
    leg_columns = [
        (leg, f"bus{leg}") for leg in legs if f"bus{leg}" in links.columns
    ]

    totals: dict[str, float] = {}
    for link_name in links.index:
        if not link_name.startswith("consume_"):
            continue

        group_name: str | None = None
        kcal_leg: int | None = None
        for leg, column in leg_columns:
            bus_value = links.at[link_name, column]
            if pd.isna(bus_value):
                continue
            bus_str = str(bus_value)
            if bus_str.startswith("group_"):
                group_name = _group_from_bus(bus_str)
            elif bus_str.startswith("kcal_"):
                kcal_leg = leg

        if group_name is None or kcal_leg is None:
            continue

        dispatch = dispatch_by_leg.get(kcal_leg)
        if dispatch is None:
            continue

        value = abs(float(dispatch.get(link_name, 0.0)))
        # Skip NaN/inf as well as zero: NaN passes a plain `value <= 0.0`
        # test and would turn the whole group total into NaN. This matches
        # the filtering done in _aggregate_group_mass.
        if not np.isfinite(value) or value <= 0.0:
            continue

        totals[group_name] = totals.get(group_name, 0.0) + value

    return pd.Series(totals, dtype=float)


def _assign_colors(
    groups: list[str], overrides: Dict[str, str] | None = None
) -> dict[str, str]:
    """Return a color per food group by delegating to ``categorical_colors``.

    Both arguments are passed through unchanged.
    """
    # NOTE(review): how overrides are merged is decided by
    # color_utils.categorical_colors; presumably a group -> color mapping
    # that takes precedence over the generated palette -- confirm there.
    palette = categorical_colors(groups, overrides)
    return palette


def _draw_stacked_bar(
    ax,
    values: pd.Series,
    ordered_groups: list[str],
    colors: dict[str, str],
    *,
    tick_label: str,
    ylabel: str,
    title: str,
) -> dict:
    """Draw one stacked bar on ``ax`` and return the drawn segments by group."""
    segments = {}
    bottom = 0.0
    for group in ordered_groups:
        value = float(values.get(group, 0.0))
        if value <= 0.0:
            continue
        segments[group] = ax.bar(
            0, value, bottom=bottom, color=colors[group], label=group
        )
        bottom += value

    ax.set_xticks([0])
    ax.set_xticklabels([tick_label])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis="y", alpha=0.3)
    return segments


def _plot(
    mass_g_per_person_day: pd.Series,
    calories_kcal_per_person_day: pd.Series,
    output_pdf: Path,
) -> None:
    """Save a two-panel stacked bar chart of consumption by food group.

    The left panel stacks mass and the right panel stacks calories. Only
    strictly positive entries are drawn. Groups are ordered by descending
    mass, followed by groups that only have calorie data (by descending
    calories). When no group has positive data, a placeholder figure with a
    "no data" message is written instead.

    Color overrides are read from the Snakemake-injected global
    ``snakemake.params.group_colors``, so drawing real data requires running
    under Snakemake.

    Args:
        mass_g_per_person_day: Mass per food group, indexed by group name.
        calories_kcal_per_person_day: Calories per food group, indexed by
            group name.
        output_pdf: Destination path passed to ``Figure.savefig``.
    """
    mass_g_per_person_day = mass_g_per_person_day[mass_g_per_person_day > 0]
    calories_kcal_per_person_day = calories_kcal_per_person_day[
        calories_kcal_per_person_day > 0
    ]

    ordered_groups: list[str] = mass_g_per_person_day.sort_values(
        ascending=False
    ).index.tolist()
    for group in calories_kcal_per_person_day.sort_values(ascending=False).index:
        if group not in ordered_groups:
            ordered_groups.append(group)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    if not ordered_groups:
        for ax in axes:
            ax.text(
                0.5, 0.5, "No food group consumption data", ha="center", va="center"
            )
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
        plt.close(fig)
        return

    group_colors = getattr(snakemake.params, "group_colors", {}) or {}  # type: ignore[name-defined]
    colors = _assign_colors(ordered_groups, group_colors)

    mass_segments = _draw_stacked_bar(
        axes[0],
        mass_g_per_person_day,
        ordered_groups,
        colors,
        tick_label="Mass",
        ylabel="g/person/day",
        title="Global Food Consumption (Mass)",
    )
    calorie_segments = _draw_stacked_bar(
        axes[1],
        calories_kcal_per_person_day,
        ordered_groups,
        colors,
        tick_label="Calories",
        ylabel="kcal/person/day",
        title="Global Food Consumption (Calories)",
    )

    # Build the legend from BOTH panels. Taking handles from the mass axis
    # alone leaves groups that only appear in the calorie panel unlabelled
    # (and drops the legend entirely when there is no mass data). A group has
    # the same color in both panels, so either segment works as its handle.
    legend_groups = [
        group
        for group in ordered_groups
        if group in mass_segments or group in calorie_segments
    ]
    if legend_groups:
        handles = [
            mass_segments[group] if group in mass_segments else calorie_segments[group]
            for group in legend_groups
        ]
        # Reverse so the legend reads top-to-bottom like the stacked bars.
        fig.legend(
            handles[::-1],
            legend_groups[::-1],
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
        )
        fig.tight_layout(rect=(0, 0, 0.85, 1))
    else:
        fig.tight_layout()

    fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)


def main() -> None:
    """Build the food consumption plot from the Snakemake inputs.

    Loads the solved network from ``snakemake.input.network``, aggregates
    mass and calories per food group at the selected snapshot, converts both
    to per-person-per-day values using the summed ``population`` column of
    ``snakemake.input.population``, and writes the figure to
    ``snakemake.output.pdf`` (creating the parent directory if needed).

    Raises:
        RuntimeError: If the ``snakemake`` global is not defined.
        ValueError: If the population file has no ``population`` column or
            the total population is not positive.
    """
    try:
        snakemake  # type: ignore[name-defined]
    except NameError as exc:  # pragma: no cover - Snakemake injects the variable
        raise RuntimeError("This script must be run from Snakemake") from exc

    network_path = snakemake.input.network  # type: ignore[attr-defined]
    population_path = snakemake.input.population  # type: ignore[attr-defined]
    output_pdf = Path(snakemake.output.pdf)  # type: ignore[attr-defined]
    output_pdf.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading solved network from %s", network_path)
    network = pypsa.Network(network_path)

    snapshot = _select_snapshot(network)
    logger.info("Using snapshot '%s' for aggregation", snapshot)

    mass = _aggregate_group_mass(network, snapshot)
    calories = _aggregate_group_calories(network, snapshot)

    population_df = pd.read_csv(population_path)
    if "population" not in population_df.columns:
        raise ValueError("Population file must contain a 'population' column")

    total_population = float(population_df["population"].sum())
    if total_population <= 0.0:
        raise ValueError(
            "Total population must be positive for per-capita conversion"
        )

    # Network totals are treated as annual; dividing by person-days gives
    # per-person-per-day values.
    mass_per_capita = mass * GRAMS_PER_MEGATONNE / (total_population * DAYS_PER_YEAR)
    calories_per_capita = (
        calories * KCAL_PER_MCAL / (total_population * DAYS_PER_YEAR)
    )

    logger.info(
        "Found %d food groups with mass data and %d with calorie data",
        mass_per_capita[mass_per_capita > 0].shape[0],
        calories_per_capita[calories_per_capita > 0].shape[0],
    )

    _plot(mass_per_capita, calories_per_capita, output_pdf)
    logger.info("Food consumption plot saved to %s", output_pdf)
# Run the plot when executed as a script (the way Snakemake invokes it).
if __name__ == "__main__":
    main()