Source code for workflow.scripts.plotting.plot_resource_classes_map

#! /usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Plot a Plate Carrée map of resource classes by grid cell."""

import logging
from pathlib import Path
from typing import Iterable, Sequence

import cartopy.crs as ccrs
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
import matplotlib

matplotlib.use("pdf")
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import xarray as xr
import geopandas as gpd
from rasterio.transform import Affine, array_bounds

logger = logging.getLogger(__name__)


def _load_resource_classes(
    path: Path,
) -> tuple[np.ndarray, Affine, Sequence[float] | None]:
    """Read the resource-class grid and its georeferencing metadata.

    Parameters
    ----------
    path
        NetCDF file holding a ``resource_class`` variable plus the global
        attributes ``transform`` and ``crs_wkt`` (``quantiles`` is optional).

    Returns
    -------
    tuple
        ``(grid, affine, quantiles)``: the class values cast to float, the
        affine transform built from the ``transform`` attribute, and the
        ``quantiles`` attribute as a tuple of floats (``None`` when absent).

    Raises
    ------
    ValueError
        If the variable or a required attribute is missing, or if the
        transform has non-zero rotation/shear terms.
    """
    with xr.open_dataset(path, mode="r") as dataset:
        if "resource_class" not in dataset:
            raise ValueError("NetCDF must contain a 'resource_class' variable")
        grid = dataset["resource_class"].values.astype(float)

        attrs = dataset.attrs
        raw_transform = attrs.get("transform")
        if raw_transform is None:
            raise ValueError("NetCDF missing 'transform' attribute")
        affine = Affine(*raw_transform)

        # The CRS is only checked for presence here; it is not returned.
        if not attrs.get("crs_wkt"):
            raise ValueError("NetCDF missing 'crs_wkt' attribute for CRS")

        # b and d are the off-diagonal (rotation/shear) coefficients.
        if any(abs(term) > 1e-9 for term in (affine.b, affine.d)):
            raise ValueError("Resource class grid must not contain rotation terms")

        raw_quantiles = attrs.get("quantiles")
        quantiles = (
            None
            if raw_quantiles is None
            # atleast_1d lets a scalar attribute be iterated like a sequence.
            else tuple(float(q) for q in np.atleast_1d(raw_quantiles))
        )
    return grid, affine, quantiles


def _subdued_colors(count: int) -> Iterable[str]:
    if count <= 0:
        return []
    cmap = plt.colormaps["YlGnBu"]
    span = max(count - 1, 1)
    colors = [
        cmap(0.3 + 0.5 * (i / span)) if count > 1 else cmap(0.45) for i in range(count)
    ]
    return [mcolors.to_hex(c) for c in reversed(colors)]


def _quantile_labels(quantiles: Sequence[float] | None, count: int) -> list[str]:
    if not quantiles or len(quantiles) < count + 1:
        return [f"Class {i + 1}" for i in range(count)]
    labels: list[str] = []
    for i in range(count):
        lo = float(quantiles[i]) * 100.0
        hi = float(quantiles[i + 1]) * 100.0
        lo_str = f"{lo:.0f}" if abs(lo - round(lo)) < 1e-6 else f"{lo:.1f}"
        hi_str = f"{hi:.0f}" if abs(hi - round(hi)) < 1e-6 else f"{hi:.1f}"
        if i == count - 1 or hi >= 100.0:
            labels.append(f"Class {i + 1} ({lo_str}+%)")
        else:
            labels.append(f"Class {i + 1} ({lo_str}-{hi_str}%)")
    return labels


def plot_resource_classes_map(
    classes_path: Path, regions_path: Path, output_path: Path
) -> None:
    """Draw the resource-class grid with region outlines and save the figure.

    Parameters
    ----------
    classes_path
        NetCDF file passed to ``_load_resource_classes``.
    regions_path
        Vector file read with ``geopandas.read_file``; reprojected to
        EPSG:4326 (assumed to already be EPSG:4326 when it carries no CRS).
    output_path
        Destination of the saved figure; missing parent directories are
        created.

    Raises
    ------
    ValueError
        If the grid contains no cell with a value >= 0, or if loading the
        NetCDF file fails validation.
    """
    data, transform, quantiles = _load_resource_classes(classes_path)

    # Negative cells are treated as unclassified: they are ignored when
    # counting classes and masked out of the image.
    valid = data[data >= 0]
    if valid.size == 0:
        raise ValueError("Resource class grid does not contain any classified cells")
    class_count = int(valid.max()) + 1
    masked = np.ma.masked_less(data, 0)

    # One discrete color per class, with bin edges halfway between the
    # integer class values (-0.5, 0.5, ..., class_count - 0.5).
    colors = list(_subdued_colors(class_count))
    cmap = mcolors.ListedColormap(colors)
    bounds = np.arange(-0.5, class_count + 0.5, 1.0)
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # array_bounds returns (west, south, east, north); imshow expects
    # (left, right, bottom, top).
    height, width = masked.shape
    lon_min, lat_min, lon_max, lat_max = array_bounds(height, width, transform)
    extent = (lon_min, lon_max, lat_min, lat_max)

    gdf = gpd.read_file(regions_path)
    if gdf.crs is None:
        logger.warning("Regions input CRS missing; assuming EPSG:4326 (WGS84)")
        gdf = gdf.set_crs(4326, allow_override=True)
    gdf = gdf.to_crs(4326)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(
        figsize=(12, 6),
        dpi=150,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    try:
        ax.set_facecolor("#f7f9fb")

        ax.imshow(
            masked,
            origin="upper",
            extent=extent,
            transform=ccrs.PlateCarree(),
            cmap=cmap,
            norm=norm,
            interpolation="nearest",
            zorder=1,
        )
        # Always show the whole globe, regardless of the raster's extent.
        ax.set_global()

        ax.add_geometries(
            gdf.geometry,
            crs=ccrs.PlateCarree(),
            facecolor="none",
            edgecolor="#444444",
            linewidth=0.4,
            zorder=2,
        )

        # Keep only the map outline ("geo" spine) and hide any other spine.
        for name, spine in ax.spines.items():
            if name == "geo":
                spine.set_visible(True)
                spine.set_linewidth(0.5)
                spine.set_edgecolor("#555555")
                spine.set_alpha(0.7)
            else:
                spine.set_visible(False)

        gl = ax.gridlines(
            draw_labels=True,
            crs=ccrs.PlateCarree(),
            linewidth=0.35,
            color="#888888",
            alpha=0.45,
            linestyle="--",
        )
        gl.xlocator = mticker.FixedLocator(np.arange(-180, 181, 30))
        gl.ylocator = mticker.FixedLocator(np.arange(-60, 61, 15))
        gl.xformatter = LongitudeFormatter(number_format=".0f")
        gl.yformatter = LatitudeFormatter(number_format=".0f")
        gl.xlabel_style = {"size": 8, "color": "#555555"}
        gl.ylabel_style = {"size": 8, "color": "#555555"}
        gl.top_labels = False
        gl.right_labels = False

        ax.set_xlabel("Longitude", fontsize=8, color="#555555")
        ax.set_ylabel("Latitude", fontsize=8, color="#555555")

        # A standalone mappable drives the colorbar, with one tick centred
        # on each class index.
        sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array([])
        cbar = fig.colorbar(
            sm, ax=ax, fraction=0.032, pad=0.02, ticks=np.arange(class_count)
        )
        cbar.ax.set_yticklabels(_quantile_labels(quantiles, class_count))
        cbar.set_label("Resource class")

        ax.set_title("Resource Classes by Grid Cell")

        plt.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", dpi=300)
    finally:
        # Release the figure even when drawing or saving fails, so a failed
        # call does not leave it registered with pyplot.
        plt.close(fig)

    logger.info("Saved resource class map to %s", output_path)
if __name__ == "__main__":
    # `snakemake` is not defined in this module; it is expected to be
    # provided by Snakemake when this file is run as a workflow script.
    plot_resource_classes_map(
        classes_path=Path(snakemake.input.classes),  # type: ignore[name-defined]
        regions_path=Path(snakemake.input.regions),  # type: ignore[name-defined]
        output_path=Path(snakemake.output.pdf),  # type: ignore[name-defined]
    )