Source code for workflow.scripts.compute_resource_classes

"""
SPDX-FileCopyrightText: 2025 Koen van Greevenbroek

SPDX-License-Identifier: GPL-3.0-or-later
"""

from pathlib import Path
import numpy as np
import geopandas as gpd
import rasterio
import rasterio.features as rfeatures
import xarray as xr


def read_raster_float(path: str):
    """Read the first band of a raster as floats, with nodata mapped to NaN.

    Parameters
    ----------
    path
        Location of the raster, passed straight to ``rasterio.open``.

    Returns
    -------
    tuple
        ``(values, dataset)``: band 1 cast to ``float`` and the dataset
        object returned by ``rasterio.open``. The dataset is *not* closed
        here, so the caller can still read its metadata and is responsible
        for closing it.
    """
    dataset = rasterio.open(path)
    band = dataset.read(1).astype(float)

    nodata = dataset.nodata
    if nodata is None:
        # No nodata value declared: nothing to mask.
        return band, dataset

    # Replace cells equal to the declared nodata value with NaN.
    masked = np.where(band == nodata, np.nan, band)
    return masked, dataset
if __name__ == "__main__":
    # Snakemake entry point: derive per-region resource classes from yield
    # rasters and write them, together with the rasterized region ids, to a
    # netCDF file.

    # Inputs provided by Snakemake
    regions_path: str = snakemake.input.regions  # type: ignore[name-defined]
    # Yield rasters as a list of paths
    yield_paths: list[str] = list(snakemake.input.yields)  # type: ignore[attr-defined]
    quantiles: list[float] = (
        [0.0] + list(snakemake.params.resource_class_quantiles) + [1.0]
    )  # type: ignore[name-defined]

    if not yield_paths:
        raise ValueError("At least one yield raster is required")

    # Read regions and use first raster as reference for grid/CRS
    regions_gdf = gpd.read_file(regions_path)

    # Use the first yield raster as reference. Only the grid metadata is
    # needed from the dataset, so close it as soon as that has been read
    # instead of leaking the file handle.
    y0, src0 = read_raster_float(yield_paths[0])
    try:
        transform = src0.transform
        crs = src0.crs
    finally:
        src0.close()
    height, width = y0.shape

    # Reproject regions to raster CRS if needed
    if regions_gdf.crs and crs and regions_gdf.crs != crs:
        regions_gdf = regions_gdf.to_crs(crs)

    # Running maximum of yields in t/ha across all provided rasters
    y_max = y0 / 1000.0  # kg/ha -> t/ha
    for path in yield_paths[1:]:
        arr, src = read_raster_float(path)
        # The array is already in memory; release the dataset right away.
        src.close()
        if arr.shape != y_max.shape:
            # Without this check np.fmax would either broadcast silently or
            # fail with an error that does not name the offending raster.
            raise ValueError(
                f"Yield raster {path} has shape {arr.shape}, "
                f"expected {y_max.shape} (shape of {yield_paths[0]})"
            )
        # np.fmax ignores NaN when the other operand is a number.
        y_max = np.fmax(y_max, arr / 1000.0)

    # Rasterize regions to integer ids (0..N-1), -1 outside
    region_shapes = [(geom, idx) for idx, geom in enumerate(regions_gdf.geometry)]
    region_raster = rfeatures.rasterize(
        region_shapes,
        out_shape=(height, width),
        transform=transform,
        fill=-1,
        dtype=np.int32,
    )

    # Build xarray DataArrays
    y_da = xr.DataArray(y_max, dims=("y", "x"))
    reg_da = xr.DataArray(region_raster, dims=("y", "x"))

    # Vectorized per-region quantiles and class assignment
    # Ignore cells with zero/negative potential yield so desert pixels
    # do not collapse the quantile bins.
    positive_y = xr.where((y_da > 0) & np.isfinite(y_da), y_da, np.nan)
    reg_quantiles = positive_y.groupby(reg_da).quantile(quantiles)
    # Broadcast each region's thresholds back onto the pixel grid.
    thresholds = reg_quantiles.sel(group=reg_da).reset_coords(drop=True)

    n_classes = len(quantiles) - 1
    class_da = xr.full_like(y_da, np.nan, dtype=float)
    for ci in range(n_classes):
        lo = thresholds.isel(quantile=ci)
        hi = thresholds.isel(quantile=ci + 1)
        in_class = (reg_da >= 0) & np.isfinite(y_da) & (y_da >= lo)
        if ci != n_classes - 1:
            # Bins are half-open [lo, hi); only the top class is closed
            # above so the regional maximum is still classified.
            in_class = in_class & (y_da < hi)
        class_da = xr.where(in_class, float(ci), class_da)

    ds = xr.Dataset(
        {
            "region_id": reg_da.astype(np.int32),
            # Unclassified cells (outside regions, NaN or non-positive
            # yield) are stored as -1.
            "resource_class": class_da.fillna(-1).astype(np.int8),
        }
    )

    # Store transform/CRS/bounds as attrs for downstream use.
    # netCDF attributes cannot hold None, so a missing CRS is written as an
    # empty string rather than making the final write fail.
    ds.attrs.update(
        {
            "transform": tuple(transform),
            "crs_wkt": crs.to_wkt() if crs else "",
            "height": int(height),
            "width": int(width),
            "quantiles": tuple(quantiles),
        }
    )

    out_path = Path(snakemake.output[0])  # type: ignore[name-defined]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ds.to_netcdf(out_path)