#! /usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Plot a Plate Carrée map of resource classes by grid cell."""
import logging
from pathlib import Path
from typing import Iterable, Sequence
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
import matplotlib
matplotlib.use("pdf")
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import xarray as xr
import geopandas as gpd
from rasterio.transform import Affine, array_bounds
logger = logging.getLogger(__name__)
def _load_resource_classes(
path: Path,
) -> tuple[np.ndarray, Affine, Sequence[float] | None]:
with xr.open_dataset(path, mode="r") as ds:
if "resource_class" not in ds:
raise ValueError("NetCDF must contain a 'resource_class' variable")
data = ds["resource_class"].values.astype(float)
transform_vals = ds.attrs.get("transform")
if transform_vals is None:
raise ValueError("NetCDF missing 'transform' attribute")
transform = Affine(*transform_vals)
crs_wkt = ds.attrs.get("crs_wkt")
if not crs_wkt:
raise ValueError("NetCDF missing 'crs_wkt' attribute for CRS")
if abs(transform.b) > 1e-9 or abs(transform.d) > 1e-9:
raise ValueError("Resource class grid must not contain rotation terms")
quantiles_attr = ds.attrs.get("quantiles")
if quantiles_attr is None:
quantiles = None
else:
quantiles = tuple(float(x) for x in np.atleast_1d(quantiles_attr))
return data, transform, quantiles
def _subdued_colors(count: int) -> Iterable[str]:
if count <= 0:
return []
cmap = plt.colormaps["YlGnBu"]
span = max(count - 1, 1)
colors = [
cmap(0.3 + 0.5 * (i / span)) if count > 1 else cmap(0.45) for i in range(count)
]
return [mcolors.to_hex(c) for c in reversed(colors)]
def _quantile_labels(quantiles: Sequence[float] | None, count: int) -> list[str]:
if not quantiles or len(quantiles) < count + 1:
return [f"Class {i + 1}" for i in range(count)]
labels: list[str] = []
for i in range(count):
lo = float(quantiles[i]) * 100.0
hi = float(quantiles[i + 1]) * 100.0
lo_str = f"{lo:.0f}" if abs(lo - round(lo)) < 1e-6 else f"{lo:.1f}"
hi_str = f"{hi:.0f}" if abs(hi - round(hi)) < 1e-6 else f"{hi:.1f}"
if i == count - 1 or hi >= 100.0:
labels.append(f"Class {i + 1} ({lo_str}+%)")
else:
labels.append(f"Class {i + 1} ({lo_str}-{hi_str}%)")
return labels
[docs]
def plot_resource_classes_map(
classes_path: Path, regions_path: Path, output_path: Path
) -> None:
data, transform, quantiles = _load_resource_classes(classes_path)
valid = data[data >= 0]
if valid.size == 0:
raise ValueError("Resource class grid does not contain any classified cells")
class_count = int(valid.max()) + 1
masked = np.ma.masked_less(data, 0)
colors = list(_subdued_colors(class_count))
cmap = mcolors.ListedColormap(colors)
bounds = np.arange(-0.5, class_count + 0.5, 1.0)
norm = mcolors.BoundaryNorm(bounds, cmap.N)
height, width = masked.shape
lon_min, lat_min, lon_max, lat_max = array_bounds(height, width, transform)
extent = (lon_min, lon_max, lat_min, lat_max)
gdf = gpd.read_file(regions_path)
if gdf.crs is None:
logger.warning("Regions input CRS missing; assuming EPSG:4326 (WGS84)")
gdf = gdf.set_crs(4326, allow_override=True)
gdf = gdf.to_crs(4326)
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(
figsize=(12, 6),
dpi=150,
subplot_kw={"projection": ccrs.PlateCarree()},
)
ax.set_facecolor("#f7f9fb")
ax.imshow(
masked,
origin="upper",
extent=extent,
transform=ccrs.PlateCarree(),
cmap=cmap,
norm=norm,
interpolation="nearest",
zorder=1,
)
ax.set_global()
ax.add_geometries(
gdf.geometry,
crs=ccrs.PlateCarree(),
facecolor="none",
edgecolor="#444444",
linewidth=0.4,
zorder=2,
)
for name, spine in ax.spines.items():
if name == "geo":
spine.set_visible(True)
spine.set_linewidth(0.5)
spine.set_edgecolor("#555555")
spine.set_alpha(0.7)
else:
spine.set_visible(False)
gl = ax.gridlines(
draw_labels=True,
crs=ccrs.PlateCarree(),
linewidth=0.35,
color="#888888",
alpha=0.45,
linestyle="--",
)
gl.xlocator = mticker.FixedLocator(np.arange(-180, 181, 30))
gl.ylocator = mticker.FixedLocator(np.arange(-60, 61, 15))
gl.xformatter = LongitudeFormatter(number_format=".0f")
gl.yformatter = LatitudeFormatter(number_format=".0f")
gl.xlabel_style = {"size": 8, "color": "#555555"}
gl.ylabel_style = {"size": 8, "color": "#555555"}
gl.top_labels = False
gl.right_labels = False
ax.set_xlabel("Longitude", fontsize=8, color="#555555")
ax.set_ylabel("Latitude", fontsize=8, color="#555555")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
cbar = fig.colorbar(
sm, ax=ax, fraction=0.032, pad=0.02, ticks=np.arange(class_count)
)
cbar.ax.set_yticklabels(_quantile_labels(quantiles, class_count))
cbar.set_label("Resource class")
ax.set_title("Resource Classes by Grid Cell")
plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight", dpi=300)
plt.close(fig)
logger.info("Saved resource class map to %s", output_path)
if __name__ == "__main__":
plot_resource_classes_map(
Path(snakemake.input.classes), # type: ignore[name-defined]
Path(snakemake.input.regions), # type: ignore[name-defined]
Path(snakemake.output.pdf), # type: ignore[name-defined]
)