# Source code for workflow.scripts.plotting.plot_crop_production_map

#! /usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
from pathlib import Path
from typing import Dict, Iterable, Mapping, Tuple

import cartopy.crs as ccrs
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
import geopandas as gpd
import matplotlib

matplotlib.use("pdf")
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import pypsa

logger = logging.getLogger(__name__)


def _region_from_bus0(bus0: str) -> str:
    parts = bus0.split("_")
    return parts[1] if len(parts) >= 2 else "unknown"


def _dict_to_df(data: Dict[Tuple[str, str], float]) -> pd.DataFrame:
    if not data:
        return pd.DataFrame()
    series = pd.Series(data).sort_index()
    df = series.unstack(fill_value=0.0).sort_index(axis=0).sort_index(axis=1)
    df.index.name = "region"
    return df


def _aggregate_production_by_region(n: pypsa.Network, snapshot: str) -> pd.DataFrame:
    """Sum crop and grassland output per region at one snapshot.

    Output is read from the ``p1`` flow at ``snapshot`` of two kinds of links:

    * links whose names start with ``produce_``, with the crop taken from the
      second ``_``-separated token of the link name;
    * links whose names start with ``graze_``, booked as ``"grassland"``.

    The region comes from each link's ``bus0`` via :func:`_region_from_bus0`.
    Flows are taken in absolute value, and non-finite or non-positive amounts
    are skipped.

    Returns:
        Region × crop table built by :func:`_dict_to_df`, always containing
        a ``"grassland"`` column.
    """
    totals: Dict[Tuple[str, str], float] = {}

    def collect(prefix, crop_for) -> None:
        # Accumulate |p1| of every link whose name starts with ``prefix``.
        links = [link for link in n.links.index if str(link).startswith(prefix)]
        if not links:
            return
        flows = n.links_t.p1.loc[snapshot, links]
        buses = n.links.loc[links, "bus0"]
        for link, flow in flows.items():
            amount = abs(float(flow))
            if amount <= 0 or not np.isfinite(amount):
                continue
            key = (_region_from_bus0(str(buses.loc[link])), crop_for(str(link)))
            totals[key] = totals.get(key, 0.0) + amount

    collect("produce_", lambda link: link.split("_")[1] if "_" in link else "unknown")
    collect("graze_", lambda link: "grassland")

    table = _dict_to_df(totals)
    if "grassland" not in table.columns:
        table["grassland"] = 0.0
    return table


def _aggregate_land_use_by_region(n: pypsa.Network, snapshot: str) -> pd.DataFrame:
    """Aggregate land use by region and crop.

    Returns land area in hectares.

    Land use is read from the ``p0`` flow at ``snapshot`` of two kinds of
    links:

    * links whose names start with ``produce_``, with the crop taken from the
      second ``_``-separated token of the link name;
    * links whose names start with ``graze_``, booked as ``"grassland"``.

    The region comes from each link's ``bus0`` via :func:`_region_from_bus0`.
    Negative flows are clipped to zero and the result is multiplied by 1e6
    (presumably converting Mha to ha — confirm against the model's units).
    Non-finite or non-positive areas are skipped. The returned region × crop
    table always contains a ``"grassland"`` column.
    """
    areas: Dict[Tuple[str, str], float] = {}
    scale = 1e6

    def collect(prefix, crop_for) -> None:
        # Accumulate scaled, clipped p0 of every link whose name starts with ``prefix``.
        links = [link for link in n.links.index if str(link).startswith(prefix)]
        if not links:
            return
        flows = n.links_t.p0.loc[snapshot, links]
        buses = n.links.loc[links, "bus0"]
        for link, flow in flows.items():
            area = max(float(flow), 0.0) * scale
            if area <= 0 or not np.isfinite(area):
                continue
            key = (_region_from_bus0(str(buses.loc[link])), crop_for(str(link)))
            areas[key] = areas.get(key, 0.0) + area

    collect("produce_", lambda link: link.split("_")[1] if "_" in link else "unknown")
    collect("graze_", lambda link: "grassland")

    table = _dict_to_df(areas)
    if "grassland" not in table.columns:
        table["grassland"] = 0.0
    return table


def _setup_regions(regions_path: str) -> Tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """Load region polygons in geographic and Equal Earth coordinates.

    The file at ``regions_path`` is read with geopandas. It is reprojected to
    EPSG:4326, or tagged as EPSG:4326 with a warning when it has no CRS.

    Returns:
        ``(regions, projected)``: the EPSG:4326 frame, and a copy projected to
        ``"+proj=eqearth"``. Both are indexed by ``region`` while keeping that
        column.

    Raises:
        ValueError: if the file has no ``region`` column.
    """
    regions = gpd.read_file(regions_path)
    if regions.crs is not None:
        regions = regions.to_crs(4326)
    else:
        logger.warning("Regions GeoDataFrame missing CRS; assuming EPSG:4326")
        regions = regions.set_crs(4326, allow_override=True)

    if "region" not in regions.columns:
        raise ValueError("Regions GeoDataFrame must contain a 'region' column")

    regions = regions.set_index("region", drop=False)
    projected = regions.to_crs("+proj=eqearth").set_index("region", drop=False)
    return regions, projected


def _plot_pie_map(
    by_region: pd.DataFrame,
    gdf: gpd.GeoDataFrame,
    gdf_eq: gpd.GeoDataFrame,
    colors: Dict[str, str],
    output_path: str,
    title: str,
    legend_title: str,
    pie_scale_title: str,
    pie_unit: str,
    min_total: float,
) -> None:
    """Render a world map with one pie chart per region and save it.

    Each pie shows one region's composition in ``by_region`` (rows = regions,
    columns = crops). Pie area is proportional to the region's total over the
    crops kept after the ``min_total`` filter. Regions of ``gdf`` with no row
    in ``by_region`` are drawn hatched and labelled "No activity".

    Args:
        by_region: Region × crop values. Rows whose region is not in
            ``gdf.index`` are ignored when drawing.
        gdf: Region polygons in longitude/latitude (drawn with a PlateCarree
            CRS), indexed by region. Used for the base map and the hatching.
        gdf_eq: The same regions projected to Equal Earth, indexed by region.
            Used for pie placement and pie sizing.
        colors: Colour per crop. Must contain every crop that survives the
            ``min_total`` filter, otherwise a ``KeyError`` is raised.
        output_path: Figure destination; parent directories are created.
        title: Axes title.
        legend_title: Title of the crop colour legend.
        pie_scale_title: Title of the pie-size legend.
        pie_unit: Unit appended to the pie-size legend labels.
        min_total: Crops whose total over all drawn regions is below this
            value are dropped from the pies and the legend.

    Side effects:
        Writes the figure to ``output_path``. When ``by_region`` is not empty,
        also writes the unfiltered table to ``<stem>_by_region.csv`` next to it.
    """
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(
        figsize=(13, 6.5),
        dpi=150,
        subplot_kw={"projection": ccrs.EqualEarth()},
    )
    ax.set_facecolor("#f7f9fb")
    ax.set_global()
    plate = ccrs.PlateCarree()

    # Base layer: every region in a light fill.
    ax.add_geometries(
        gdf.geometry,
        crs=plate,
        facecolor="#e6eef2",
        edgecolor="#666666",
        linewidth=0.3,
        zorder=1,
    )

    # Regions without a row in by_region are hatched. This is decided before
    # crop filtering, so a region whose remaining values are all zero gets
    # neither a pie nor hatching.
    model_regions = gdf.index
    present_regions = by_region.index if not by_region.empty else pd.Index([])
    missing_regions = model_regions.difference(present_regions)

    if len(missing_regions) > 0:
        ax.add_geometries(
            gdf.loc[missing_regions].geometry,
            crs=plate,
            facecolor="#f0f0f0",
            edgecolor="#666666",
            linewidth=0.3,
            hatch="..",
            zorder=1.5,
        )

    # Keep only regions known to the map. Then drop crops whose overall total
    # is below min_total. Columns end up ordered by descending total, which
    # also sets the wedge and legend order.
    filtered = by_region.copy()
    if not filtered.empty:
        filtered = filtered.reindex(model_regions.intersection(filtered.index)).fillna(
            0.0
        )
        crop_totals = filtered.sum(axis=0).sort_values(ascending=False)
        crop_totals = crop_totals[crop_totals >= min_total]
        if crop_totals.empty:
            filtered = filtered.iloc[:, 0:0]
        else:
            filtered = filtered.loc[:, crop_totals.index]

    if not filtered.empty:
        crops = list(filtered.columns)
        color_list = [colors[c] for c in crops]
        totals = filtered.sum(axis=1)
        if totals.max() > 0:
            # The largest pie radius is 2.4% of the larger side of the projected
            # region bounds. Radius grows with sqrt(total), so pie area grows
            # linearly with total.
            xmin, ymin, xmax, ymax = gdf_eq.total_bounds
            width = xmax - xmin
            height = ymax - ymin
            r_max = 0.024 * max(width, height)
            radii = (np.sqrt(totals / totals.max()) * r_max).fillna(0.0)

            # NOTE(review): pie centres are in gdf_eq's "+proj=eqearth"
            # coordinates and are drawn directly in the EqualEarth axes' data
            # coordinates. This assumes both projections coincide — confirm.
            centroids = gdf_eq.geometry.representative_point()
            for region in filtered.index:
                point = centroids.loc[region]
                x, y = point.x, point.y
                values = filtered.loc[region].values.tolist()
                _draw_pie(ax, x, y, values, color_list, float(radii.get(region, 0.0)))

            handles = [mpatches.Patch(facecolor=colors[c], label=c) for c in crops]
            legend1 = ax.legend(
                handles=handles,
                title=legend_title,
                loc="lower left",
                bbox_to_anchor=(0.15, 0.03),
                fontsize=8,
                title_fontsize=9,
                frameon=True,
                borderpad=0.8,
                labelspacing=0.6,
                handletextpad=0.6,
            )
            # _legend_box is a private matplotlib attribute; used to left-align contents.
            legend1._legend_box.align = "left"
            # Re-add as an artist so that a later ax.legend() call does not replace it.
            ax.add_artist(legend1)

            # Reference markers at 25%, 50% and 100% of the largest total.
            # Marker area (points^2) scales linearly with value, like pie area.
            # The absolute marker size is fixed by handle_scale and is not tied
            # to the on-map pie radii.
            ref_fracs = np.array([0.25, 0.5, 1.0])
            ref_vals = np.unique(totals.max() * ref_fracs)
            handle_scale = 900.0
            size_handles = [
                ax.scatter(
                    [],
                    [],
                    s=float((val / totals.max()) * handle_scale),
                    facecolors="#d0d7de",
                    edgecolors="#555555",
                    linewidths=0.6,
                    alpha=0.7,
                )
                for val in ref_vals
            ]
            size_labels = [f"{val:,.0f} {pie_unit}" for val in ref_vals]
            legend2 = ax.legend(
                size_handles,
                size_labels,
                title=pie_scale_title,
                loc="lower left",
                bbox_to_anchor=(0.6, 0.03),
                fontsize=8,
                title_fontsize=9,
                frameon=True,
                scatterpoints=1,
                handlelength=1.5,
                borderpad=0.8,
                labelspacing=2,
            )
            legend2._legend_box.align = "left"
            ax.add_artist(legend2)

    if len(missing_regions) > 0:
        hatch_handle = mpatches.Patch(
            facecolor="#f0f0f0",
            edgecolor="#666666",
            hatch="..",
            label="No activity",
        )
        # This becomes the axes' primary legend. The two legends above survive
        # because they were attached as separate artists.
        ax.legend(
            handles=[hatch_handle],
            loc="lower right",
            bbox_to_anchor=(0.99, 0.02),
            fontsize=8,
            frameon=True,
        )

    # Show only the map outline ("geo" spine), styled lightly.
    for name, spine in ax.spines.items():
        if name == "geo":
            spine.set_visible(True)
            spine.set_linewidth(0.5)
            spine.set_edgecolor("#555555")
            spine.set_alpha(0.7)
        else:
            spine.set_visible(False)

    # Dashed graticule every 30° longitude and 15° latitude (±60°), labelled
    # on the bottom and left only.
    gl = ax.gridlines(
        draw_labels=True,
        crs=plate,
        linewidth=0.35,
        color="#888888",
        alpha=0.45,
        linestyle="--",
    )
    gl.xlocator = mticker.FixedLocator(np.arange(-180, 181, 30))
    gl.ylocator = mticker.FixedLocator(np.arange(-60, 61, 15))
    gl.xformatter = LongitudeFormatter(number_format=".0f")
    gl.yformatter = LatitudeFormatter(number_format=".0f")
    gl.xlabel_style = {"size": 8, "color": "#555555"}
    gl.ylabel_style = {"size": 8, "color": "#555555"}
    gl.top_labels = False
    gl.right_labels = False

    ax.set_xlabel("Longitude", fontsize=8, color="#555555")
    ax.set_ylabel("Latitude", fontsize=8, color="#555555")
    ax.set_title(title)
    plt.tight_layout()
    # The saved file uses 300 dpi, independent of the 150 dpi the figure was created with.
    fig.savefig(out, bbox_inches="tight", dpi=300)
    plt.close(fig)

    # Companion CSV holds the full (unfiltered) table, with columns sorted by name.
    if not by_region.empty:
        csv_path = out.with_suffix("")
        csv_out = csv_path.parent / f"{csv_path.name}_by_region.csv"
        by_region.sort_index(axis=1).to_csv(csv_out, index=True)
        logger.info("Saved regional totals to %s", csv_out)

    logger.info("Saved map to %s", out)


def _draw_pie(
    ax: plt.Axes,
    x: float,
    y: float,
    sizes: Iterable[float],
    colors: Iterable[str],
    radius: float,
) -> None:
    """Draw a pie chart as wedge patches centred at ``(x, y)`` in data coordinates.

    Wedge ``i`` spans a share ``sizes[i] / sum(sizes)`` of 360° and is filled
    with ``colors[i]``. Non-positive sizes get no wedge but still count toward
    the angle sequence. A thin outline circle is drawn on top of the wedges.
    Nothing is drawn when the total or the radius is non-positive or
    non-finite.

    Args:
        ax: Axes to draw on.
        x, y: Pie centre in the axes' data coordinates.
        sizes: Wedge magnitudes; any iterable, including one-shot iterators.
        colors: One colour per entry of ``sizes``.
        radius: Pie radius in data units.

    Raises:
        IndexError: if ``colors`` is shorter than the number of positive
            leading sizes that need a colour.
    """
    # Materialise first: summing a one-shot iterator before calling list() on it
    # would leave nothing to draw.
    sizes = list(sizes)
    colors = list(colors)
    total = float(sum(sizes))
    if not (np.isfinite(total) and np.isfinite(radius)) or total <= 0 or radius <= 0:
        return

    angles = np.cumsum([0.0] + [size / total * 360.0 for size in sizes])
    for i, size in enumerate(sizes):
        if size <= 0:
            continue
        ax.add_patch(
            mpatches.Wedge(
                center=(x, y),
                r=radius,
                theta1=angles[i],
                theta2=angles[i + 1],
                facecolor=colors[i],
                edgecolor="white",
                linewidth=0.4,
                alpha=0.85,
                zorder=10,
            )
        )
    # Outline drawn above the wedges (zorder 11 > 10).
    ax.add_patch(
        mpatches.Circle(
            (x, y),
            radius=radius,
            facecolor="none",
            edgecolor="#444444",
            linewidth=0.3,
            alpha=0.7,
            zorder=11,
        )
    )


def _build_color_mapping(
    crops: Iterable[str],
    overrides: Mapping[str, str],
    fallback_cmap_name: str,
) -> Dict[str, str]:
    """Return a colour for every crop, falling back to a matplotlib colormap.

    Overrides are applied first:

    * a non-empty string override is used verbatim;
    * any other valid matplotlib colour spec (e.g. an RGB tuple or list) is
      converted to hex.

    Every other crop gets a fallback colour from the colormap
    ``fallback_cmap_name``, or from ``"Set3"`` (with a warning) if matplotlib
    does not know that name:

    * small listed (qualitative) colormaps such as ``Set3`` or ``tab20`` are
      cycled through in order, repeating once exhausted;
    * every other colormap (continuous maps, and large listed ones such as
      ``viridis``) is sampled at evenly spaced points over [0, 1], so
      neighbouring crops get clearly distinct colours.

    Args:
        crops: Crop names, iterated once.
        overrides: Optional crop → colour mapping.
        fallback_cmap_name: Name of a matplotlib colormap.

    Returns:
        Crop → colour mapping in the iteration order of ``crops``. Fallback
        colours are hex strings.
    """
    try:
        cmap = plt.get_cmap(fallback_cmap_name)
    except ValueError:
        logger.warning(
            "Unknown colormap '%s'; defaulting to 'Set3' for crop colors",
            fallback_cmap_name,
        )
        cmap = plt.get_cmap("Set3")

    # Listed colormaps with at most this many entries are treated as
    # qualitative palettes and cycled in order. Perceptual maps such as
    # viridis are also listed, but with 256 near-continuous entries.
    # Cycling those would give consecutive crops almost the same colour.
    max_qualitative_size = 32
    listed_colors = list(getattr(cmap, "colors", []))
    cycle_listed = 0 < len(listed_colors) <= max_qualitative_size

    crop_list = list(crops)
    explicit: Dict[str, str] = {}
    for crop in crop_list:
        override = overrides.get(crop)
        if isinstance(override, str):
            if override:
                explicit[crop] = override
        elif override is not None and mcolors.is_color_like(override):
            explicit[crop] = mcolors.to_hex(override)

    n_fallback = sum(1 for crop in crop_list if crop not in explicit)
    colors: Dict[str, str] = {}
    fallback_idx = 0
    for crop in crop_list:
        if crop in explicit:
            colors[crop] = explicit[crop]
            continue
        if cycle_listed:
            color = listed_colors[fallback_idx % len(listed_colors)]
        else:
            # Spread samples across the whole colormap. Stepping by 1/(N-1)
            # with N≈256 would make neighbouring crops indistinguishable.
            color = cmap(fallback_idx / max(n_fallback - 1, 1))
        colors[crop] = mcolors.to_hex(color)
        fallback_idx += 1

    return colors


def _align_to_regions(
    table: pd.DataFrame, regions: pd.Index, columns: Iterable[str]
) -> pd.DataFrame:
    """Restrict ``table`` to known ``regions`` and give it exactly ``columns``.

    Rows whose region is not in ``regions`` are dropped. Missing cells and
    columns are filled with 0.0. An empty input yields an empty frame, with
    an index named ``"region"`` and the requested columns.
    """
    columns = list(columns)
    if table.empty:
        return pd.DataFrame(index=pd.Index([], name="region"), columns=columns)
    aligned = table.reindex(regions.intersection(table.index)).fillna(0.0)
    return aligned.reindex(columns=columns, fill_value=0.0)


def main() -> None:
    """Draw the crop-production and land-use pie maps for a solved network.

    Inputs and outputs come from the global ``snakemake`` object, which is not
    defined in this module:

    * ``input.network``: the PyPSA network;
    * ``input.regions``: the region polygons file;
    * ``output.production_pdf``: destination of the production map;
    * ``output.land_pdf``: destination of the land-use map.

    Optional ``params.crop_colors`` (crop → colour) and ``params.fallback_cmap``
    control the colours. The snapshot ``"now"`` is used when present,
    otherwise the first snapshot.
    """
    n = pypsa.Network(snakemake.input.network)  # type: ignore[name-defined]
    regions_path: str = snakemake.input.regions  # type: ignore[name-defined]
    prod_pdf: str = snakemake.output.production_pdf  # type: ignore[name-defined]
    land_pdf: str = snakemake.output.land_pdf  # type: ignore[name-defined]

    snapshot = "now" if "now" in n.snapshots else n.snapshots[0]
    gdf, gdf_eq = _setup_regions(regions_path)

    production = _aggregate_production_by_region(n, snapshot)
    land_use = _aggregate_land_use_by_region(n, snapshot)

    # Give both tables the same columns so one colour mapping serves both maps.
    all_columns = sorted(set(production.columns) | set(land_use.columns))
    production = _align_to_regions(production, gdf.index, all_columns)
    land_use = _align_to_regions(land_use, gdf.index, all_columns)

    params = getattr(snakemake, "params", None)  # type: ignore[name-defined]
    overrides: Mapping[str, str] = {}
    fallback_cmap_name = "Set3"
    if params is not None:
        raw_overrides = getattr(params, "crop_colors", None)
        if raw_overrides is not None:
            overrides = dict(raw_overrides)
        # An unset or empty value keeps the default. Passing None to
        # plt.get_cmap would silently select matplotlib's default colormap.
        fallback_cmap_name = getattr(params, "fallback_cmap", None) or fallback_cmap_name

    colors = _build_color_mapping(all_columns, overrides, fallback_cmap_name)

    _plot_pie_map(
        production,
        gdf,
        gdf_eq,
        colors,
        prod_pdf,
        title="Crop and Grassland Output by Region",
        legend_title="Crops / grassland",
        pie_scale_title="Pie size ∝ total production",
        pie_unit="t",
        min_total=10_000.0,
    )
    _plot_pie_map(
        land_use,
        gdf,
        gdf_eq,
        colors,
        land_pdf,
        title="Land Use by Crop and Grassland",
        legend_title="Crops / grassland",
        pie_scale_title="Pie size ∝ total land area",
        pie_unit="ha",
        min_total=1_000.0,
    )
if __name__ == "__main__":
    # Run only when executed as a script, not on import.
    main()