# SPDX-FileCopyrightText: 2025 Koen van Greevenbroek
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Pre-compute health data for SOS2 linearisation in the solver."""
from collections.abc import Iterable
import logging
import math
from pathlib import Path
import geopandas as gpd
import numpy as np
import numpy.typing as npt
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from workflow.scripts.logging_config import setup_script_logging
AGE_BUCKETS = [
"<1",
"1-4",
"5-9",
"10-14",
"15-19",
"20-24",
"25-29",
"30-34",
"35-39",
"40-44",
"45-49",
"50-54",
"55-59",
"60-64",
"65-69",
"70-74",
"75-79",
"80-84",
"85-89",
"90-94",
"95+",
]
# Age utilities
def _age_bucket_min(age: str) -> int:
"""Return the lower bound of an age bucket label like '25-29' or '95+'."""
age = str(age)
if age.startswith("<"):
return 0
if "-" in age:
return int(age.split("-")[0])
if age.endswith("+"):
return int(age.rstrip("+"))
return 0
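# Quick sanity sketch of the bucket parsing above (illustrative, not executed):
#     _age_bucket_min("<1")    -> 0
#     _age_bucket_min("25-29") -> 25
#     _age_bucket_min("95+")   -> 95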
# Logger will be configured in __main__ block
logger = logging.getLogger(__name__)
def _load_life_expectancy(path: str) -> pd.Series:
"""Load processed life expectancy data from prepare_life_table.py output."""
df = pd.read_csv(path)
if df.empty:
raise ValueError("Life table file is empty")
required_cols = {"age", "life_exp"}
if not required_cols.issubset(df.columns):
raise ValueError(f"Life table missing required columns: {required_cols}")
# Validate all expected age buckets are present
missing = [bucket for bucket in AGE_BUCKETS if bucket not in df["age"].values]
if missing:
raise ValueError(
"Life table missing life expectancy entries for age buckets: "
+ ", ".join(missing)
)
series = df.set_index("age")["life_exp"]
series.name = "life_exp"
return series
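# Expected life-table layout: a CSV with one row per AGE_BUCKETS entry
# (values below are illustrative only):
#     age,life_exp
#     <1,80.5
#     1-4,77.2
#     ...
#     95+,3.1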
def _build_country_clusters(
regions_path: str,
countries: Iterable[str],
n_clusters: int,
population: pd.DataFrame | None = None,
gdp_per_capita: pd.DataFrame | None = None,
weights: dict[str, float] | None = None,
) -> tuple[pd.Series, dict[int, list[str]]]:
"""
Cluster countries into health regions using multi-objective criteria.
Objectives (controlled by weights):
- geography: Geographic proximity (minimize spatial spread)
- gdp: GDP per capita similarity (group similar economies)
- population: Population balance (equalize total population across clusters)
Parameters
----------
regions_path : str
Path to GeoJSON file with country boundaries
countries : Iterable[str]
ISO3 country codes to include
n_clusters : int
Target number of clusters
population : pd.DataFrame, optional
Population data with columns: country, value (in thousands)
gdp_per_capita : pd.DataFrame, optional
GDP per capita data with columns: iso3, gdp_per_capita
weights : dict, optional
Weights for clustering objectives: geography, gdp, population
Returns
-------
cluster_series : pd.Series
Country ISO3 codes as index, cluster IDs as values
cluster_to_countries : dict
Mapping from cluster ID to list of country ISO3 codes
"""
if weights is None:
weights = {"geography": 1.0, "gdp": 0.0, "population": 0.0}
    regions = gpd.read_file(regions_path)
    # Restrict to the configured countries (the regions file may contain more)
    regions = regions[regions["country"].isin(set(countries))]
# Project to equal-area CRS and compute country centroids
regions_equal_area = regions.to_crs(6933)
dissolved = regions_equal_area.dissolve(by="country", as_index=True)
centroids = dissolved.geometry.centroid
country_order = list(dissolved.index)
# Build geographic coordinates
coords = np.column_stack([centroids.x.values, centroids.y.values])
k = max(1, min(int(n_clusters), len(coords)))
if k < int(n_clusters):
logger.info(
f"Requested {n_clusters} clusters but only {len(coords)} countries available; using {k}."
)
if len(coords) == 1:
labels = np.array([0])
else:
# Build multi-objective feature matrix
features = _build_clustering_features(
coords, country_order, gdp_per_capita, weights
)
km = KMeans(n_clusters=k, n_init=20, random_state=0)
labels = km.fit_predict(features)
# Apply population balance refinement if weight > 0
pop_weight = weights["population"]
if pop_weight > 0 and population is not None:
labels = _refine_population_balance(
labels, country_order, population, coords, pop_weight
)
dissolved["health_cluster"] = labels
cluster_series = dissolved["health_cluster"].astype(int)
grouped = cluster_series.groupby(cluster_series).groups
cluster_to_countries = {
int(cluster): sorted(indexes) for cluster, indexes in grouped.items()
}
# Log cluster statistics
_log_cluster_statistics(
cluster_series, cluster_to_countries, population, gdp_per_capita
)
return cluster_series, cluster_to_countries
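# Usage sketch (hypothetical path and countries; the real values come from the
# Snakemake config in main()):
#     series, mapping = _build_country_clusters(
#         "regions.geojson", ["DEU", "FRA", "POL"], n_clusters=2
#     )
#     # series maps ISO3 -> cluster id; mapping is e.g. {0: ["DEU", "FRA"], 1: ["POL"]}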
def _build_clustering_features(
coords: np.ndarray,
country_order: list[str],
gdp_per_capita: pd.DataFrame | None,
weights: dict[str, float],
) -> np.ndarray:
"""
Build weighted feature matrix for clustering.
Combines geographic coordinates and GDP per capita with configurable weights.
Features are standardized before weighting to ensure comparable scales.
GDP data is assumed complete (imputation handled in retrieve_gdp_per_capita.py).
"""
w_geo = weights["geography"]
w_gdp = weights["gdp"]
# Standardize geographic coordinates
scaler = StandardScaler()
coords_scaled = scaler.fit_transform(coords)
if w_gdp > 0 and gdp_per_capita is not None:
# Map GDP to countries in order
gdp_map = gdp_per_capita.set_index("iso3")["gdp_per_capita"]
gdp_values = np.array([gdp_map[c] for c in country_order])
# Log-transform to reduce skew (GDP is typically log-normal)
gdp_log = np.log1p(gdp_values).reshape(-1, 1)
gdp_scaled = scaler.fit_transform(gdp_log)
# Apply weights (sqrt because K-means minimizes squared distances)
# Geography has 2 dimensions, GDP has 1
total_weight = 2 * w_geo + w_gdp
geo_factor = np.sqrt(w_geo / total_weight)
gdp_factor = np.sqrt(w_gdp / total_weight)
features = np.column_stack(
[
coords_scaled * geo_factor,
gdp_scaled * gdp_factor,
]
)
else:
# Geography only (original behavior)
features = coords_scaled
return features
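# Worked example of the sqrt weighting (illustrative numbers): with w_geo = 1.0
# and w_gdp = 0.5, total_weight = 2 * 1.0 + 0.5 = 2.5, so geo_factor =
# sqrt(0.4) ~ 0.632 and gdp_factor = sqrt(0.2) ~ 0.447. Squared distances then
# contribute 2 * 0.4 = 0.8 (geography, two dimensions) versus 0.2 (GDP, one
# dimension), i.e. exactly in the intended 2*w_geo : w_gdp ratio.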
def _refine_population_balance(
labels: np.ndarray,
country_order: list[str],
population: pd.DataFrame,
coords: np.ndarray,
pop_weight: float,
max_iter: int = 100,
) -> np.ndarray:
"""
Iteratively refine cluster assignments to improve population balance.
Moves boundary countries from over-populated to under-populated clusters
until the population coefficient of variation (CV) is acceptable.
The target CV is determined by the population weight:
- Higher weight = stricter balance requirement (lower target CV)
"""
labels = labels.copy()
# Get total population per country (sum across years if multiple)
pop_by_country = (
population[population["age"] == "all-a"].groupby("country")["value"].sum()
)
country_pop = np.array([pop_by_country.get(c, 0.0) for c in country_order])
# Target CV based on population weight (higher weight = stricter balance)
    # Weight 0.3 -> target CV ~0.65, weight 1.0 -> target CV ~0.3
target_cv = max(0.2, 0.8 - 0.5 * pop_weight)
for iteration in range(max_iter):
# Compute cluster populations
cluster_ids = np.unique(labels)
cluster_pops = {cid: country_pop[labels == cid].sum() for cid in cluster_ids}
# Compute coefficient of variation
pop_values = np.array(list(cluster_pops.values()))
if pop_values.mean() == 0:
break
cv = pop_values.std() / pop_values.mean()
if cv <= target_cv:
logger.info(
f"Population balance achieved after {iteration} iterations "
f"(CV={cv:.3f}, target={target_cv:.3f})"
)
break
# Find most over-populated and under-populated clusters
max_cluster = max(cluster_pops, key=cluster_pops.get)
min_cluster = min(cluster_pops, key=cluster_pops.get)
if max_cluster == min_cluster:
break
# Find boundary country in over-populated cluster (furthest from centroid)
in_max = np.where(labels == max_cluster)[0]
if len(in_max) <= 1:
# Can't remove from a single-country cluster
break
cluster_coords = coords[in_max]
centroid = cluster_coords.mean(axis=0)
dists = np.linalg.norm(cluster_coords - centroid, axis=1)
boundary_local_idx = dists.argmax()
boundary_idx = in_max[boundary_local_idx]
# Reassign to under-populated cluster
labels[boundary_idx] = min_cluster
else:
logger.info(
f"Population balance refinement reached max iterations "
f"(CV={cv:.3f}, target={target_cv:.3f})"
)
return labels
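# Numeric sketch of the balance criterion (made-up populations, in thousands):
# cluster populations [10, 20, 30] give mean 20 and population std ~8.16, so
# CV ~ 0.41; with pop_weight = 0.3 the target is max(0.2, 0.8 - 0.15) = 0.65,
# and the refinement loop would stop immediately.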
def _log_cluster_statistics(
cluster_series: pd.Series,
cluster_to_countries: dict[int, list[str]],
population: pd.DataFrame | None,
gdp_per_capita: pd.DataFrame | None,
) -> None:
"""Log summary statistics about the clustering result."""
n_clusters = len(cluster_to_countries)
n_countries = len(cluster_series)
logger.info(f"Created {n_clusters} health clusters from {n_countries} countries")
if population is not None:
pop_by_country = (
population[population["age"] == "all-a"].groupby("country")["value"].sum()
)
cluster_pops = []
for members in cluster_to_countries.values():
cluster_pop = sum(pop_by_country.get(c, 0.0) for c in members)
cluster_pops.append(cluster_pop)
if cluster_pops:
pop_arr = np.array(cluster_pops) * 1000 # Convert to persons
cv = pop_arr.std() / pop_arr.mean() if pop_arr.mean() > 0 else 0
logger.info(
f"Cluster population stats: min={pop_arr.min() / 1e6:.1f}M, "
f"max={pop_arr.max() / 1e6:.1f}M, CV={cv:.3f}"
)
if gdp_per_capita is not None:
gdp_map = gdp_per_capita.set_index("iso3")["gdp_per_capita"]
        # Log GDP statistics for the first three clusters only, to keep logs short
        for cluster_id, members in list(cluster_to_countries.items())[:3]:
gdp_values = [gdp_map.get(c) for c in members if c in gdp_map.index]
if gdp_values:
gdp_arr = np.array(gdp_values)
logger.info(
f"Cluster {cluster_id}: {len(members)} countries, "
f"GDP/cap ${gdp_arr.mean():,.0f} (std=${gdp_arr.std():,.0f})"
)
class RelativeRiskTable(dict[tuple[str, str], dict[str, np.ndarray]]):
"""Container mapping (risk, cause) to exposure grids and log RR values."""
def _build_rr_tables(
rr_df: pd.DataFrame,
risk_factors: Iterable[str],
risk_cause_map: dict[str, list[str]],
) -> tuple[RelativeRiskTable, dict[str, float]]:
"""Build lookup tables for relative risk curves by (risk, cause) pairs.
Returns:
table: Dict mapping (risk, cause) to exposure arrays and log(RR) values
max_exposure_g_per_day: Dict mapping risk factor to maximum exposure level in data
"""
table: RelativeRiskTable = RelativeRiskTable()
max_exposure_g_per_day: dict[str, float] = dict.fromkeys(risk_factors, 0.0)
allowed = {(risk, cause) for risk in risk_factors for cause in risk_cause_map[risk]}
seen_pairs: set[tuple[str, str]] = set()
seen_risks: set[str] = set()
for (risk, cause), grp in rr_df.groupby(["risk_factor", "cause"], sort=True):
if (risk, cause) not in allowed:
continue
grp = grp.sort_values("exposure_g_per_day")
exposures = grp["exposure_g_per_day"].astype(float).values
if len(exposures) == 0:
continue
log_rr_mean = np.log(grp["rr_mean"].astype(float).values)
table[(risk, cause)] = {
"exposures": exposures,
"log_rr_mean": log_rr_mean,
}
max_exposure_g_per_day[risk] = max(
max_exposure_g_per_day[risk], float(exposures.max())
)
seen_risks.add(risk)
seen_pairs.add((risk, cause))
missing_pairs = sorted(allowed - seen_pairs)
if missing_pairs:
text = ", ".join([f"{r}:{c}" for r, c in missing_pairs])
raise ValueError(f"Relative risk table is missing risk-cause pairs: {text}")
return table, max_exposure_g_per_day
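# Shape sketch of the lookup (the cause name and numbers are hypothetical):
#     table[("red_meat", "ihd")] == {
#         "exposures": np.array([0.0, 50.0, 100.0]),
#         "log_rr_mean": np.array([0.0, 0.10, 0.22]),
#     }
#     max_exposure_g_per_day == {"red_meat": 100.0, ...}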
def _derive_tmrel_from_rr(
rr_lookup: RelativeRiskTable, risk_to_causes: dict[str, list[str]]
) -> dict[str, float]:
"""Derive TMREL intake per risk from empirical RR curves.
For each risk, find the intake that minimizes the sum of log(RR) across all
its causes (i.e., the product of RRs), evaluated on the union of exposure
knots in the RR tables.
"""
tmrel: dict[str, float] = {}
for risk, causes in risk_to_causes.items():
exposure_grid: list[float] = []
for cause in causes:
exposure_grid.extend(rr_lookup[(risk, cause)]["exposures"])
if not exposure_grid:
raise ValueError(f"No RR exposure data for risk factor: {risk}")
grid = sorted({float(x) for x in exposure_grid})
best_intake = grid[0]
best_log = math.inf
for intake in grid:
total_log = 0.0
for cause in causes:
total_log += math.log(_evaluate_rr(rr_lookup, risk, cause, intake))
if total_log < best_log:
best_log = total_log
best_intake = intake
tmrel[risk] = best_intake
return tmrel
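# Illustrative example: if a risk's summed log(RR) across its causes over the
# union grid [0, 50, 100] g/day evaluates to [0.30, 0.05, 0.12], the TMREL is
# 50 g/day (the grid point with the smallest total log relative risk).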
def _evaluate_rr(
table: RelativeRiskTable, risk: str, cause: str, intake: float
) -> float:
"""Interpolate relative risk for given intake using log-linear interpolation."""
data = table[(risk, cause)]
exposures: npt.NDArray[np.floating] = data["exposures"]
log_rr: npt.NDArray[np.floating] = data["log_rr_mean"]
if intake <= exposures[0]:
return float(math.exp(log_rr[0]))
if intake >= exposures[-1]:
return float(math.exp(log_rr[-1]))
log_val = float(np.interp(intake, exposures, log_rr))
return float(math.exp(log_val))
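# Interpolation sketch (hypothetical curve): with exposures [0, 50, 100] and
# log_rr_mean [0.0, 0.1, 0.3], an intake of 75 interpolates to log RR
# 0.1 + 0.5 * (0.3 - 0.1) = 0.2, i.e. RR = exp(0.2) ~ 1.221; intakes outside
# the knot range are clamped to the endpoint values.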
def _load_input_data(
snakemake,
cfg_countries: list[str],
reference_year: int,
) -> tuple:
"""Load and perform initial processing of all input datasets."""
# Load population data first (needed for clustering)
pop = pd.read_csv(snakemake.input["population"])
pop["value"] = pd.to_numeric(pop["value"], errors="coerce") / 1_000.0
# Load GDP per capita data
gdp_per_capita = pd.read_csv(snakemake.input["gdp"])
# Get clustering weights from config
health_cfg = snakemake.params["health"]
clustering_cfg = health_cfg["clustering"]
weights = clustering_cfg["weights"]
cluster_series, cluster_to_countries = _build_country_clusters(
snakemake.input["regions"],
cfg_countries,
int(health_cfg["region_clusters"]),
population=pop,
gdp_per_capita=gdp_per_capita,
weights=weights,
)
cluster_map = cluster_series.rename("health_cluster").reset_index()
cluster_map.columns = ["country_iso3", "health_cluster"]
cluster_map = cluster_map.sort_values("country_iso3")
diet = pd.read_csv(snakemake.input["diet"])
rr_df = pd.read_csv(snakemake.input["relative_risks"])
dr = pd.read_csv(
snakemake.input["dr"],
header=None,
names=["age", "cause", "country", "year", "value"],
)
life_exp = _load_life_expectancy(snakemake.input["life_table"])
return (
cluster_series,
cluster_to_countries,
cluster_map,
diet,
rr_df,
dr,
pop,
life_exp,
)
def _filter_and_prepare_data(
diet: pd.DataFrame,
dr: pd.DataFrame,
pop: pd.DataFrame,
rr_df: pd.DataFrame,
cfg_countries: list[str],
reference_year: int,
life_exp: pd.Series,
risk_factors: list[str],
risk_cause_map: dict[str, list[str]],
intake_age_min: int,
) -> tuple:
"""Filter datasets to reference year and compute derived quantities."""
# Filter dietary intake data to adult buckets and compute population-weighted means
adult_ages = {
age for age in diet["age"].unique() if _age_bucket_min(age) >= intake_age_min
}
diet = diet[
(diet["age"].isin(adult_ages))
& (diet["year"] == reference_year)
& (diet["country"].isin(cfg_countries))
].copy()
# Build relative risk lookup tables
rr_lookup, max_exposure_g_per_day = _build_rr_tables(
rr_df, risk_factors, risk_cause_map
)
# Filter mortality and population data
dr = dr[(dr["year"] == reference_year) & (dr["country"].isin(cfg_countries))].copy()
pop = pop[
(pop["year"] == reference_year) & (pop["country"].isin(cfg_countries))
].copy()
valid_ages = life_exp.index
dr = dr[dr["age"].isin(valid_ages)].copy()
pop_age = pop[pop["age"].isin(valid_ages)].copy()
pop_total = (
pop[pop["age"] == "all-a"]
.groupby("country")["value"]
.sum()
.astype(float)
.reindex(cfg_countries)
)
# Determine relevant risk-cause pairs
risk_to_causes = {risk: list(risk_cause_map[risk]) for risk in risk_factors}
relevant_causes = sorted(
{cause for causes in risk_to_causes.values() for cause in causes}
)
dr = dr[dr["cause"].isin(relevant_causes)].copy()
# Map diet items to risk factors
item_to_risk = {
"whole_grains": "whole_grains",
"legumes": "legumes",
"soybeans": "legumes",
"nuts_seeds": "nuts_seeds",
"vegetables": "vegetables",
"fruits_trop": "fruits",
"fruits_temp": "fruits",
"fruits_starch": "fruits",
"fruits": "fruits",
"beef": "red_meat",
"lamb": "red_meat",
"pork": "red_meat",
"red_meat": "red_meat",
"prc_meat": "prc_meat",
"shellfish": "fish",
"fish_freshw": "fish",
"fish_pelag": "fish",
"fish_demrs": "fish",
"fish": "fish",
"sugar": "sugar",
}
diet["risk_factor"] = diet["item"].map(item_to_risk)
diet = diet.dropna(subset=["risk_factor"])
# Population-weighted adult intakes per country and risk factor
# The dietary intake file is already aggregated to adult bands ("11-74 years", "75+ years").
# Population file is per narrow age band, so collapse to total adult population per country.
pop_adult = (
pop_age[pop_age["age"].isin(adult_ages)]
.groupby("country")["value"]
.sum()
.astype(float)
.rename("population_adult")
)
if pop_adult.isna().any() or (pop_adult <= 0).any():
raise ValueError("Adult population totals are missing or non-positive")
diet = diet.rename(columns={"value": "intake"})
diet["intake"] = pd.to_numeric(diet["intake"], errors="coerce")
if diet["intake"].isna().any():
raise ValueError("Dietary intake contains non-numeric values")
    # For each country/risk factor, take the mean intake across adult age bands;
    # population weighting across countries happens later, per health cluster
diet_grouped = (
diet.groupby(["country", "risk_factor"])["intake"].mean().rename("intake_mean")
)
intake_by_country = diet_grouped.unstack(fill_value=0.0).reindex(
cfg_countries, fill_value=0.0
)
return (
dr,
pop_age,
pop_total,
rr_lookup,
max_exposure_g_per_day,
relevant_causes,
risk_to_causes,
intake_by_country,
)
def _compute_baseline_health_metrics(
dr: pd.DataFrame,
pop_age: pd.DataFrame,
life_exp: pd.Series,
) -> pd.DataFrame:
"""Compute baseline death counts and YLL statistics by country."""
pop_age = pop_age.rename(columns={"value": "population"})
dr = dr.rename(columns={"value": "death_rate"})
combo = dr.merge(pop_age, on=["age", "country", "year"], how="left").merge(
life_exp.rename("life_exp"), left_on="age", right_index=True, how="left"
)
combo["population"] = combo["population"].fillna(0.0)
combo["death_rate"] = combo["death_rate"].fillna(0.0)
combo["death_count"] = combo["death_rate"] * combo["population"]
combo["yll"] = combo["death_count"] * combo["life_exp"]
return combo
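# Worked example of the chained arithmetic (illustrative numbers; population is
# in thousands after _load_input_data): death_rate 0.002 * population 500
# -> death_count 1.0 (thousand deaths); with life_exp 12 years, yll = 12.0
# (thousand life-years lost).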
def _build_intake_caps(
max_exposure_g_per_day: dict[str, float],
intake_cap_limit: float,
) -> dict[str, float]:
"""Apply a uniform generous intake cap across all risk factors."""
if intake_cap_limit <= 0:
return dict(max_exposure_g_per_day)
caps = dict(max_exposure_g_per_day)
for risk in list(caps.keys()):
caps[risk] = max(caps[risk], float(intake_cap_limit))
return caps
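# Example (hypothetical values): with max exposures {"fruits": 400.0,
# "sugar": 150.0} and intake_cap_limit = 500.0, both caps become 500.0;
# with intake_cap_limit <= 0 the empirical maxima are returned unchanged.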
def _process_health_clusters(
cluster_to_countries: dict[int, list[str]],
pop_total: pd.Series,
combo: pd.DataFrame,
risk_factors: list[str],
intake_by_country: pd.DataFrame,
intake_caps_g_per_day: dict[str, float],
rr_lookup: RelativeRiskTable,
risk_to_causes: dict[str, list[str]],
relevant_causes: list[str],
tmrel_g_per_day: dict[str, float],
) -> tuple:
"""Process each health cluster to compute baseline metrics and intakes."""
cluster_summary_rows = []
cluster_cause_rows = []
cluster_risk_baseline_rows = []
baseline_intake_registry: dict[str, set] = {risk: set() for risk in risk_factors}
for cluster_id, members in cluster_to_countries.items():
pop_weights = pop_total.reindex(members).fillna(0.0)
total_pop_thousand = float(pop_weights.sum())
if total_pop_thousand <= 0:
continue
total_population_persons = total_pop_thousand * 1_000.0
cluster_combo = combo[combo["country"].isin(members)]
yll_by_cause_cluster = cluster_combo.groupby("cause")["yll"].sum()
cluster_summary_rows.append(
{
"health_cluster": int(cluster_id),
"population_persons": total_population_persons,
}
)
log_rr_ref_totals: dict[str, float] = dict.fromkeys(relevant_causes, 0.0)
log_rr_baseline_totals: dict[str, float] = dict.fromkeys(relevant_causes, 0.0)
for risk in risk_factors:
if risk not in intake_by_country.columns:
baseline_intake = 0.0
else:
baseline_intake = (
intake_by_country[risk].reindex(members).fillna(0.0) * pop_weights
).sum() / total_pop_thousand
baseline_intake = float(baseline_intake)
if not math.isfinite(baseline_intake):
baseline_intake = 0.0
max_exposure = float(intake_caps_g_per_day[risk])
baseline_intake = max(0.0, min(baseline_intake, max_exposure))
baseline_intake_registry.setdefault(risk, set()).add(baseline_intake)
cluster_risk_baseline_rows.append(
{
"health_cluster": int(cluster_id),
"risk_factor": risk,
"baseline_intake_g_per_day": baseline_intake,
}
)
# Use TMREL intake as reference point for health cost calculations.
# rr_ref is calculated at TMREL, ensuring that in the solver, health cost
# is zero when RR = RR_ref (i.e., when intake is at optimal levels).
# This implements Cost = V * YLL_base * (RR/RR_ref - 1).
tmrel_intake = float(tmrel_g_per_day[risk])
if not math.isfinite(tmrel_intake):
tmrel_intake = 0.0
tmrel_intake = max(0.0, min(tmrel_intake, max_exposure))
causes = risk_to_causes[risk]
for cause in causes:
if (risk, cause) not in rr_lookup:
continue
rr_ref = _evaluate_rr(rr_lookup, risk, cause, tmrel_intake)
log_rr_ref_totals[cause] = log_rr_ref_totals[cause] + math.log(rr_ref)
rr_base = _evaluate_rr(rr_lookup, risk, cause, baseline_intake)
log_rr_baseline_totals[cause] = log_rr_baseline_totals[
cause
] + math.log(rr_base)
for cause in relevant_causes:
log_rr_baseline = log_rr_baseline_totals[cause]
rr_baseline = math.exp(log_rr_baseline)
rr_ref = math.exp(log_rr_ref_totals[cause])
paf = (
0.0 if rr_baseline <= 0 else 1.0 - rr_ref / rr_baseline
) # burden relative to TMREL
paf = max(0.0, min(1.0, paf))
yll_total = yll_by_cause_cluster.get(cause, 0.0)
yll_diet_attrib = yll_total * paf
cluster_cause_rows.append(
{
"health_cluster": int(cluster_id),
"cause": cause,
"log_rr_total_ref": log_rr_ref_totals[cause],
"log_rr_total_baseline": log_rr_baseline,
"paf_baseline": paf,
"yll_total": yll_total,
"yll_base": yll_diet_attrib,
}
)
cluster_summary = pd.DataFrame(
cluster_summary_rows,
columns=["health_cluster", "population_persons"],
)
cluster_cause_baseline = pd.DataFrame(
cluster_cause_rows,
columns=[
"health_cluster",
"cause",
"log_rr_total_ref",
"log_rr_total_baseline",
"paf_baseline",
"yll_total",
"yll_base",
],
)
cluster_risk_baseline = pd.DataFrame(
cluster_risk_baseline_rows,
columns=["health_cluster", "risk_factor", "baseline_intake_g_per_day"],
)
return (
cluster_summary,
cluster_cause_baseline,
cluster_risk_baseline,
baseline_intake_registry,
)
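# PAF arithmetic sketch (made-up numbers): with log_rr_total_ref = 0.0 and
# log_rr_total_baseline = ln(1.25) ~ 0.223, paf = 1 - 1.0 / 1.25 = 0.2, so a
# fifth of the cause's YLL burden (yll_base = 0.2 * yll_total) is attributed
# to diet relative to TMREL.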
def _generate_breakpoint_tables(
risk_factors: list[str],
intake_caps_g_per_day: dict[str, float],
baseline_intake_registry: dict[str, set],
intake_grid_points: int,
rr_lookup: RelativeRiskTable,
risk_to_causes: dict[str, list[str]],
relevant_causes: list[str],
log_rr_points: int,
tmrel_g_per_day: dict[str, float],
) -> tuple:
"""Generate SOS2 linearization breakpoint tables for risks and causes.
Intake grids:
- Evenly spaced `intake_grid_points` over the empirical RR data range
(min→max exposure in RR table, expanded to include 0).
- Always include all empirical exposure points, TMREL, baseline intakes,
and the global intake cap for feasibility beyond the data range.
- The generous cap is *added* as a knot but does not stretch the
linspace; this keeps knot density high where RR actually changes.
Cause grids:
- `log_rr_points` evenly spaced between aggregated min/max log(RR)
implied by the risk grids above.
"""
risk_breakpoint_rows = []
cause_log_min: dict[str, float] = dict.fromkeys(relevant_causes, 0.0)
cause_log_max: dict[str, float] = dict.fromkeys(relevant_causes, 0.0)
for risk in risk_factors:
cap = float(intake_caps_g_per_day[risk])
if cap <= 0:
continue
causes = risk_to_causes[risk]
        # Empirical exposure domain from RR table (may vary by cause; take union)
        exposures: list[float] = []
        for cause in causes:
            key = (risk, cause)
            if key in rr_lookup:
                exposures.extend(float(x) for x in rr_lookup[key]["exposures"])
if not exposures:
continue
lo = min(0.0, float(min(exposures)))
hi_empirical = float(max(exposures))
# Even spacing only over the empirical RR range
lin = np.linspace(lo, hi_empirical, max(intake_grid_points, 2))
grid_points = {float(x) for x in lin}
grid_points.update(float(x) for x in exposures)
grid_points.add(0.0)
grid_points.add(hi_empirical)
for val in baseline_intake_registry[risk]:
grid_points.add(float(val))
# Include TMREL as a breakpoint for accurate interpolation at optimal intake
if risk in tmrel_g_per_day:
grid_points.add(float(tmrel_g_per_day[risk]))
# Add the generous cap without stretching the linspace range
grid_points.add(cap)
grid = sorted(grid_points)
for cause in causes:
key = (risk, cause)
if key not in rr_lookup:
continue
log_values: list[float] = []
for intake in grid:
rr_val = _evaluate_rr(rr_lookup, risk, cause, intake)
log_rr = math.log(rr_val)
log_values.append(log_rr)
risk_breakpoint_rows.append(
{
"risk_factor": risk,
"cause": cause,
"intake_g_per_day": float(intake),
"log_rr": log_rr,
}
)
if log_values:
cause_log_min[cause] += min(log_values)
cause_log_max[cause] += max(log_values)
risk_breakpoints = pd.DataFrame(risk_breakpoint_rows)
cause_breakpoint_rows = []
for cause in relevant_causes:
min_total = cause_log_min[cause]
max_total = cause_log_max[cause]
if not math.isfinite(min_total):
min_total = 0.0
if not math.isfinite(max_total):
max_total = 0.0
if max_total < min_total:
min_total, max_total = max_total, min_total
if abs(max_total - min_total) < 1e-6:
log_vals = np.array([min_total])
else:
log_vals = np.linspace(min_total, max_total, max(log_rr_points, 2))
for log_val in log_vals:
cause_breakpoint_rows.append(
{
"cause": cause,
"log_rr_total": float(log_val),
"rr_total": math.exp(float(log_val)),
}
)
cause_log_breakpoints = pd.DataFrame(cause_breakpoint_rows)
return risk_breakpoints, cause_log_breakpoints
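# Grid construction sketch (hypothetical inputs): with intake_grid_points = 5
# and an empirical RR range of [0, 100] g/day, the linspace gives
# {0, 25, 50, 75, 100}; empirical knots, baseline intakes, the TMREL and a
# generous cap of e.g. 500 are then added as extra breakpoints, so the cap
# does not dilute knot density inside the data range.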
def main() -> None:
"""Main entry point for health cost preparation."""
logger = logging.getLogger(__name__)
cfg_countries: list[str] = list(snakemake.params["countries"])
health_cfg = snakemake.params["health"]
configured_risk_factors: list[str] = list(health_cfg["risk_factors"])
# Filter risk factors to only those with foods mapped in food_groups.csv
food_groups_df = pd.read_csv(snakemake.input.food_groups)
available_risk_factors = set(food_groups_df["group"].unique())
risk_factors: list[str] = [
rf for rf in configured_risk_factors if rf in available_risk_factors
]
excluded_risk_factors = set(configured_risk_factors) - set(risk_factors)
if excluded_risk_factors:
logger.warning(
"Risk factors configured but not in food_groups.csv (no foods mapped): %s. "
"These will be excluded from health cost calculations.",
sorted(excluded_risk_factors),
)
risk_cause_map: dict[str, list[str]] = {
str(risk): list(health_cfg["risk_cause_map"][risk]) for risk in risk_factors
}
reference_year = int(health_cfg["reference_year"])
intake_grid_points = int(health_cfg["intake_grid_points"])
log_rr_points = int(health_cfg["log_rr_points"])
intake_cap_limit = float(health_cfg["intake_cap_g_per_day"])
intake_age_min = int(health_cfg["intake_age_min"])
# Load input data
(
_cluster_series,
cluster_to_countries,
cluster_map,
diet,
rr_df,
dr,
pop,
life_exp,
) = _load_input_data(snakemake, cfg_countries, reference_year)
# Filter and prepare datasets
(
dr,
pop_age,
pop_total,
rr_lookup,
max_exposure_g_per_day,
relevant_causes,
risk_to_causes,
intake_by_country,
) = _filter_and_prepare_data(
diet,
dr,
pop,
rr_df,
cfg_countries,
reference_year,
life_exp,
risk_factors,
risk_cause_map,
intake_age_min,
)
tmrel_g_per_day = _derive_tmrel_from_rr(rr_lookup, risk_to_causes)
logger.info(
"Derived TMREL from RR curves for %d risks (ignoring configured TMREL values)",
len(tmrel_g_per_day),
)
intake_caps_g_per_day = _build_intake_caps(max_exposure_g_per_day, intake_cap_limit)
# Compute baseline health metrics
combo = _compute_baseline_health_metrics(
dr,
pop_age,
life_exp,
)
# Process health clusters
(
cluster_summary,
cluster_cause_baseline,
cluster_risk_baseline,
baseline_intake_registry,
) = _process_health_clusters(
cluster_to_countries,
pop_total,
combo,
risk_factors,
intake_by_country,
intake_caps_g_per_day,
rr_lookup,
risk_to_causes,
relevant_causes,
tmrel_g_per_day,
)
# Generate breakpoint tables for SOS2 linearization
risk_breakpoints, cause_log_breakpoints = _generate_breakpoint_tables(
risk_factors,
intake_caps_g_per_day,
baseline_intake_registry,
intake_grid_points,
rr_lookup,
risk_to_causes,
relevant_causes,
log_rr_points,
tmrel_g_per_day,
)
# Write outputs
output_dir = Path(snakemake.output["risk_breakpoints"]).parent
output_dir.mkdir(parents=True, exist_ok=True)
risk_breakpoints.sort_values(["risk_factor", "cause", "intake_g_per_day"]).to_csv(
snakemake.output["risk_breakpoints"], index=False
)
cluster_cause_baseline.sort_values(["health_cluster", "cause"]).to_csv(
snakemake.output["cluster_cause"], index=False
)
cause_log_breakpoints.sort_values(["cause", "log_rr_total"]).to_csv(
snakemake.output["cause_log"], index=False
)
cluster_summary.sort_values("health_cluster").to_csv(
snakemake.output["cluster_summary"], index=False
)
cluster_map.to_csv(snakemake.output["clusters"], index=False)
cluster_risk_baseline.sort_values(["health_cluster", "risk_factor"]).to_csv(
snakemake.output["cluster_risk_baseline"], index=False
)
if __name__ == "__main__":
# Configure logging
logger = setup_script_logging(log_file=snakemake.log[0] if snakemake.log else None)
main()