Source code for workflow.scripts.plotting.plot_food_consumption

# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Plot global food consumption by food group per person per day."""

from pathlib import Path

import matplotlib

matplotlib.use("pdf")
import matplotlib.pyplot as plt
import pandas as pd

from workflow.scripts.constants import DAYS_PER_YEAR, GRAMS_PER_MEGATONNE, PJ_TO_KCAL
from workflow.scripts.logging_config import setup_script_logging
from workflow.scripts.plotting.color_utils import categorical_colors

# Alias for backwards compatibility with modules that import from here
KCAL_PER_PJ = PJ_TO_KCAL


def _load_global_consumption(
    food_group_consumption_path: str, population_path: str
) -> tuple[pd.Series, pd.Series]:
    """Load food group consumption and compute global per-capita values.

    Returns
    -------
    tuple[pd.Series, pd.Series]
        (mass_g_per_person_day, calories_kcal_per_person_day) indexed by food_group
    """
    df = pd.read_csv(food_group_consumption_path)
    pop_df = pd.read_csv(population_path)

    if df.empty:
        return pd.Series(dtype=float), pd.Series(dtype=float)

    # Sum absolute values across all countries
    global_totals = df.groupby("food_group")[["consumption_mt", "cal_pj"]].sum()

    # Get total population
    total_population = pop_df["population"].sum()

    # Convert to per-capita
    mass_per_capita = (
        global_totals["consumption_mt"]
        * GRAMS_PER_MEGATONNE
        / (total_population * DAYS_PER_YEAR)
    )
    calories_per_capita = (
        global_totals["cal_pj"] * PJ_TO_KCAL / (total_population * DAYS_PER_YEAR)
    )

    return mass_per_capita, calories_per_capita


def _assign_colors(
    groups: list[str], overrides: dict[str, str] | None = None
) -> dict[str, str]:
    return categorical_colors(groups, overrides)


def _plot(
    mass_g_per_person_day: pd.Series,
    calories_kcal_per_person_day: pd.Series,
    output_pdf: Path,
) -> None:
    mass_g_per_person_day = mass_g_per_person_day[mass_g_per_person_day > 0]
    calories_kcal_per_person_day = calories_kcal_per_person_day[
        calories_kcal_per_person_day > 0
    ]

    ordered_groups: list[str] = []
    ordered_groups.extend(
        mass_g_per_person_day.sort_values(ascending=False).index.tolist()
    )
    for group in calories_kcal_per_person_day.sort_values(ascending=False).index:
        if group not in ordered_groups:
            ordered_groups.append(group)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    if not ordered_groups:
        for ax in axes:
            ax.text(
                0.5, 0.5, "No food group consumption data", ha="center", va="center"
            )
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
        plt.close(fig)
        return

    group_colors = getattr(snakemake.params, "group_colors", {}) or {}
    colors = _assign_colors(ordered_groups, group_colors)

    # Mass subplot
    ax_mass = axes[0]
    bottom = 0.0
    for group in ordered_groups:
        value = float(mass_g_per_person_day.get(group, 0.0))
        if value <= 0.0:
            continue
        ax_mass.bar(0, value, bottom=bottom, color=colors[group], label=group)
        bottom += value

    ax_mass.set_xticks([0])
    ax_mass.set_xticklabels(["Mass"])
    ax_mass.set_ylabel("g/person/day")
    ax_mass.set_title("Global Food Consumption (Mass)")
    ax_mass.grid(axis="y", alpha=0.3)

    # Calories subplot
    ax_cal = axes[1]
    bottom = 0.0
    for group in ordered_groups:
        value = float(calories_kcal_per_person_day.get(group, 0.0))
        if value <= 0.0:
            continue
        ax_cal.bar(0, value, bottom=bottom, color=colors[group])
        bottom += value

    ax_cal.set_xticks([0])
    ax_cal.set_xticklabels(["Calories"])
    ax_cal.set_ylabel("kcal/person/day")
    ax_cal.set_title("Global Food Consumption (Calories)")
    ax_cal.grid(axis="y", alpha=0.3)

    handles, labels = ax_mass.get_legend_handles_labels()
    if handles:
        fig.legend(
            handles[::-1],
            labels[::-1],
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
        )
        fig.tight_layout(rect=(0, 0, 0.85, 1))
    else:
        fig.tight_layout()

    fig.savefig(output_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)


[docs] def main() -> None: try: snakemake # type: ignore[name-defined] except NameError as exc: # pragma: no cover - Snakemake injects the variable raise RuntimeError("This script must be run from Snakemake") from exc logger = setup_script_logging(snakemake.log[0]) food_group_consumption_path = snakemake.input.food_group_consumption population_path = snakemake.input.population output_pdf = Path(snakemake.output.pdf) output_pdf.parent.mkdir(parents=True, exist_ok=True) logger.info("Loading food group consumption from %s", food_group_consumption_path) mass_per_capita, calories_per_capita = _load_global_consumption( food_group_consumption_path, population_path ) logger.info( "Found %d food groups with mass data and %d with calorie data", mass_per_capita[mass_per_capita > 0].shape[0], calories_per_capita[calories_per_capita > 0].shape[0], ) _plot(mass_per_capita, calories_per_capita, output_pdf) logger.info("Food consumption plot saved to %s", output_pdf)
if __name__ == "__main__": main()