#!/usr/bin/env python
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright (C) 2024-2026 SYMFLUENCE Team <dev@symfluence.org>
# -*- coding: utf-8 -*-
"""
Base Model Evaluator
This module provides the abstract base class for different evaluation variables.
"""
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
import numpy as np
import pandas as pd
import xarray as xr
from symfluence.core.mixins import ConfigurableMixin
from symfluence.evaluation import metrics
if TYPE_CHECKING:
from symfluence.core.config.models import SymfluenceConfig
[docs]
class ModelEvaluator(ConfigurableMixin, ABC):
"""
Abstract base class for hydrological model evaluation.
Provides standardized infrastructure for comparing simulated and observed
data across different hydrological variables (streamflow, snow, ET, etc.).
Handles time series alignment, period-based evaluation (calibration/validation),
and multi-metric calculation using the centralized metrics module.
Subclasses must implement:
- get_simulation_files(): Locate model output files
- extract_simulated_data(): Parse simulation results
- get_observed_data_path(): Locate observation files
- needs_routing(): Whether mizuRoute output is required
- _get_observed_data_column(): Identify data column in obs files
Attributes:
config: SymfluenceConfig instance with typed access
calibration_period: Tuple of (start, end) timestamps for calibration
evaluation_period: Tuple of (start, end) timestamps for validation
eval_timestep: Target timestep for comparison ('native', 'hourly', 'daily')
"""
def __init__(
    self,
    config: 'SymfluenceConfig',
    project_dir: Optional[Path] = None,
    logger: Optional[logging.Logger] = None,
):
    """Store the configuration and derive evaluation periods and timestep.

    Args:
        config: Configuration object. Values are read through
            ``self._get_config_value`` with a typed accessor and a
            ``dict_key`` fallback.
        project_dir: Project directory; ``Path(".")`` is used when falsy.
        logger: Logger instance, stored on ``self._logger``.
    """
    self.config = config
    self._project_dir = project_dir or Path(".")
    self._logger = logger

    # Raw period strings (dict_key fallback covers worker dicts).
    raw_calibration = self._get_config_value(
        lambda: self.config.domain.calibration_period,
        default='',
        dict_key='CALIBRATION_PERIOD',
    )
    raw_evaluation = self._get_config_value(
        lambda: self.config.domain.evaluation_period,
        default='',
        dict_key='EVALUATION_PERIOD',
    )
    self.calibration_period: Tuple[Optional[pd.Timestamp], Optional[pd.Timestamp]] = (
        self._parse_date_range(raw_calibration)
    )
    self.evaluation_period: Tuple[Optional[pd.Timestamp], Optional[pd.Timestamp]] = (
        self._parse_date_range(raw_evaluation)
    )

    # Timestep used for comparison; anything unrecognised falls back to 'native'.
    requested_timestep = self._get_config_value(
        lambda: self.config.optimization.calibration_timestep,
        default='native',
        dict_key='CALIBRATION_TIMESTEP',
    )
    self.eval_timestep = requested_timestep.lower()
    if self.eval_timestep not in ('native', 'hourly', 'daily'):
        self.logger.warning(
            f"Invalid calibration_timestep '{self.eval_timestep}'. "
            "Using 'native'. Valid options: 'native', 'hourly', 'daily'"
        )
        self.eval_timestep = 'native'
    elif self.eval_timestep != 'native':
        self.logger.debug(f"Evaluation will use {self.eval_timestep} timestep")
@property
def variable_type(self) -> str:
    """Classify the evaluated variable for resampling purposes.

    The value selects the aggregation used by ``_resample_to_timestep``:
    ``'flux'`` aggregates with a sum, anything else with a mean. Subclasses
    evaluating flux/accumulation variables (precipitation, ET) should
    override this property.

    Returns:
        ``'state'`` in this base implementation.
    """
    default_kind = 'state'
    return default_kind
[docs]
def evaluate(self, sim: Any, obs: Optional[pd.Series] = None,
             mizuroute_dir: Optional[Path] = None,
             calibration_only: bool = True) -> Optional[Dict[str, float]]:
    """Compute metrics for this target; thin alias of ``calculate_metrics``.

    All four arguments are forwarded positionally and unchanged, and the
    result of ``calculate_metrics`` is returned as-is.
    """
    outcome = self.calculate_metrics(sim, obs, mizuroute_dir, calibration_only)
    return outcome
[docs]
def calculate_metrics(self, sim: Any, obs: Optional[pd.Series] = None,
                      mizuroute_dir: Optional[Path] = None,
                      calibration_only: bool = True) -> Optional[Dict[str, float]]:
    """
    Calculate performance metrics for this target.

    Args:
        sim: Either a path (``str``/``Path``) to a simulation directory or a
            pre-loaded pd.Series of simulated values.
        obs: Optional pre-loaded pd.Series of observations. If None, loads
            from file via ``_load_observed_data``.
        mizuroute_dir: mizuRoute simulation directory; used instead of
            ``sim`` when ``needs_routing()`` is true and this is set.
        calibration_only: If True, only calculate calibration period metrics.

    Returns:
        Dictionary of metrics. Calibration metrics appear both with their
        ``Calib_`` prefix and unprefixed; evaluation metrics carry the
        ``Eval_`` prefix. When no period is configured, metrics for the
        full overlap are returned without prefix. ``None`` is returned when
        data cannot be loaded or validated, or on any unexpected error
        (this method never raises).
    """
    try:
        # 1. Prepare simulated data
        if isinstance(sim, (str, Path)):
            sim_dir = Path(sim)

            # Routed output takes precedence only when a routing dir was given
            if self.needs_routing() and mizuroute_dir:
                output_dir = mizuroute_dir
            else:
                output_dir = sim_dir

            sim_files = self.get_simulation_files(output_dir)
            if not sim_files:
                self.logger.error(f"No simulation files found in {output_dir}")
                return None

            sim_data = self.extract_simulated_data(sim_files)
            # Only report a count when extraction produced something;
            # len(None) would otherwise mask the dedicated error below.
            if sim_data is not None:
                self.logger.debug(f"Extracted {len(sim_data)} simulated data points from {len(sim_files)} file(s)")
        else:
            sim_data = sim

        if sim_data is None:
            self.logger.error("Failed to extract simulated data")
            return None

        is_valid, error_msg = self._validate_data(sim_data, 'simulated')
        if not is_valid:
            self.logger.error(error_msg)
            return None

        # 2. Prepare observed data
        if obs is None:
            obs_data = self._load_observed_data()
        else:
            obs_data = obs

        if obs_data is None or len(obs_data) == 0:
            self.logger.error("Failed to load observed data (check path and column names)")
            return None

        is_valid, error_msg = self._validate_data(obs_data, 'observed')
        if not is_valid:
            self.logger.error(error_msg)
            return None

        self.logger.debug(f"Loaded {len(obs_data)} observed data points")

        # 3. Align time series and calculate metrics
        metrics_dict = {}

        # Always calculate metrics for calibration period if available
        if self.calibration_period[0] and self.calibration_period[1]:
            calib_metrics = self._calculate_period_metrics(
                obs_data, sim_data, self.calibration_period, "Calib"
            )
            metrics_dict.update(calib_metrics)

            # Also expose unprefixed names for the primary (calibration)
            # period, for runners/loggers expecting simple metric names.
            calib_prefix = "Calib_"
            for k, v in calib_metrics.items():
                # Strip the leading prefix only, never inner occurrences
                unprefixed_key = k[len(calib_prefix):] if k.startswith(calib_prefix) else k
                if unprefixed_key not in metrics_dict:
                    metrics_dict[unprefixed_key] = v

        # Only calculate evaluation period metrics if requested (final evaluation)
        if not calibration_only and self.evaluation_period[0] and self.evaluation_period[1]:
            eval_metrics = self._calculate_period_metrics(
                obs_data, sim_data, self.evaluation_period, "Eval"
            )
            metrics_dict.update(eval_metrics)

        # If no specific periods, calculate for full overlap (fallback)
        if not metrics_dict:
            full_metrics = self._calculate_period_metrics(obs_data, sim_data, (None, None), "")
            metrics_dict.update(full_metrics)

        return metrics_dict

    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error calculating metrics for {self.__class__.__name__}: {str(e)}")
        return None
[docs]
@abstractmethod
def get_simulation_files(self, sim_dir: Path) -> List[Path]:
    """Return the simulation output files relevant to this target.

    Abstract hook with no base behavior; ``calculate_metrics`` passes the
    directory it selected and treats an empty result as "no files found".

    Args:
        sim_dir: Directory to look in for simulation output.
    """
[docs]
@abstractmethod
def get_observed_data_path(self) -> Path:
    """Return the path of the observed-data file for this target.

    Abstract hook with no base behavior; ``_load_observed_data`` reads the
    file this method points to.
    """
[docs]
@abstractmethod
def needs_routing(self) -> bool:
    """Tell whether this target requires mizuRoute routing output.

    Abstract hook with no base behavior; ``calculate_metrics`` reads
    simulation files from the mizuRoute directory when this is true and
    such a directory was supplied.
    """
def _load_observed_data(self) -> Optional[pd.Series]:
    """Load the observations from the file named by ``get_observed_data_path``.

    Returns:
        Whatever ``_load_observed_data_from_path`` returns, or ``None``
        (after logging an error) if locating or reading the file raises.
    """
    try:
        return self._load_observed_data_from_path(self.get_observed_data_path())
    except Exception as exc:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error loading observed data: {str(exc)}")
    return None
def _load_observed_data_from_path(self, obs_path: Path) -> Optional[pd.Series]:
    """Load observed data from a specific path.

    ``.nc`` files are delegated to ``_load_observed_data_from_netcdf``.
    Any other file is read as CSV in two attempts:

    1. First column as a date index (handles GRACE/TWS files where the
       date is the first column).
    2. Fallback: search the columns for a date/time column by name.

    Args:
        obs_path: Path of the observation file.

    Returns:
        Series of the column chosen by ``_get_observed_data_column``
        indexed by datetime, or ``None`` if the file is missing or the
        date/data columns cannot be identified.
    """
    if not obs_path.exists():
        self.logger.error(f"Observed data file not found: {obs_path}")
        return None

    # Model-ready NetCDF store
    if obs_path.suffix == '.nc':
        return self._load_observed_data_from_netcdf(obs_path)

    # Attempt 1: first column is the date index. pd.to_datetime either
    # returns a DatetimeIndex or raises, so success needs no further check.
    try:
        obs_df = pd.read_csv(obs_path, index_col=0)
        obs_df.index = pd.to_datetime(obs_df.index, format='mixed')
        data_col = self._get_observed_data_column(obs_df.columns)
        if data_col:
            self.logger.debug(f"Loaded {obs_path} with date index, data column: {data_col}")
            return obs_df[data_col]
    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.debug(f"Could not parse {obs_path} with index_col=0: {e}")

    # Attempt 2: standard CSV read with explicit date column search
    obs_df = pd.read_csv(obs_path)

    # str() guards against non-string column labels
    date_col = next((col for col in obs_df.columns
                     if any(term in str(col).lower() for term in ['date', 'time', 'datetime'])), None)
    data_col = self._get_observed_data_column(obs_df.columns)

    if not date_col or not data_col:
        self.logger.error(f"Could not identify date/data columns in {obs_path}. "
                          f"Columns: {list(obs_df.columns)}")
        return None

    obs_df['DateTime'] = pd.to_datetime(obs_df[date_col])
    obs_df.set_index('DateTime', inplace=True)
    return obs_df[data_col]
def _load_observed_data_from_netcdf(self, nc_path: Path) -> Optional[pd.Series]:
    """Load observed data from the model-ready grouped NetCDF store.

    Opens the group named by ``_get_observation_group`` and returns the
    first data variable that is not an identifier variable
    (``gauge_id``, ``hru_id``, ``station_id``, ``basin_id``).

    Args:
        nc_path: Path of the NetCDF file.

    Returns:
        Series named after the selected variable with NaN values dropped,
        or ``None`` (after logging an error) if no data variable exists or
        reading fails for any reason.
    """
    try:
        import xarray as xr
        group = self._get_observation_group()
        # The context manager closes the dataset on every exit path,
        # including exceptions raised while converting to a Series.
        with xr.open_dataset(nc_path, group=group) as ds:
            # Take the first non-coordinate data variable
            data_vars = [v for v in ds.data_vars if v not in ('gauge_id', 'hru_id', 'station_id', 'basin_id')]
            if not data_vars:
                self.logger.error(f"No data variables in {nc_path} group {group}")
                return None
            var_name = data_vars[0]
            da = ds[var_name]
            # Select the first element along every non-time dimension so the
            # result is a plain time series rather than a MultiIndex.
            spatial_dims = [d for d in da.dims if d != 'time']
            if spatial_dims:
                da = da.isel({d: 0 for d in spatial_dims})
            # Materialize while the file is still open
            series = da.to_series().dropna()
        series.name = var_name
        return series
    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error reading NetCDF observations from {nc_path}: {e}")
        return None
def _get_observation_group(self) -> str:
    """Return the NetCDF group name for this evaluator type.

    Subclasses can override this to point to their observation group.
    The default mapping looks for keywords in the lowercased class name.
    The short ``'et'`` keyword is tested last because it is a substring of
    many unrelated words (e.g. "target", "dataset") and would otherwise
    shadow the more specific keywords.

    Returns:
        The matching group name, or ``'streamflow'`` when nothing matches.
    """
    class_name = self.__class__.__name__.lower()
    # Ordered from most to least specific keyword
    keyword_groups = (
        (('streamflow',), 'streamflow'),
        (('snow',), 'snow'),
        (('soil',), 'soil_moisture'),
        (('tws', 'groundwater'), 'terrestrial_water_storage'),
        (('evapotranspiration', 'et'), 'et'),
    )
    for keywords, group in keyword_groups:
        if any(keyword in class_name for keyword in keywords):
            return group
    return 'streamflow'
def _validate_data(
    self,
    data: pd.Series,
    data_name: str,
    min_valid_points: int = 10
) -> Tuple[bool, Optional[str]]:
    """Check a series for basic quality problems before metric calculation.

    Checks, in order:
        1. ``data`` is not None
        2. ``data`` has at least one element
        3. at least one value is non-NaN
        4. at least ``min_valid_points`` values are non-NaN
    A series whose non-NaN values are all identical passes, but a warning
    is logged because some metrics may be undefined for constant data.

    Args:
        data: pandas Series to validate
        data_name: Human-readable name used in messages (e.g., 'simulated')
        min_valid_points: Minimum number of non-NaN points required

    Returns:
        ``(True, None)`` when all checks pass, otherwise
        ``(False, "error description")`` for the first failing check.
    """
    problem: Optional[str] = None

    if data is None:
        problem = f"{data_name} data is None"
    elif len(data) == 0:
        problem = f"{data_name} data is empty"
    else:
        valid_count = data.notna().sum()
        if valid_count == 0:
            problem = f"{data_name} data contains only NaN values"
        elif valid_count < min_valid_points:
            problem = (
                f"{data_name} data has insufficient valid points: "
                f"{valid_count} < {min_valid_points}"
            )

    if problem is not None:
        return False, problem

    # Constant data is accepted but flagged (some metrics undefined)
    present = data.dropna()
    if present.nunique() == 1:
        self.logger.warning(
            f"{data_name} data is constant (all values = {present.iloc[0]:.4g}). "
            "Some metrics (e.g., correlation, NSE) may be undefined."
        )
    return True, None
@abstractmethod
def _get_observed_data_column(self, columns: List[str]) -> Optional[str]:
    """Pick the data column from the columns of an observation file.

    Abstract hook with no base behavior; the CSV loader treats a falsy
    result as "no data column found".

    Args:
        columns: Column labels of the loaded observation table.
    """
def _calculate_period_metrics(self, obs_data: pd.Series, sim_data: pd.Series,
                              period: Tuple, prefix: str) -> Dict[str, float]:
    """Calculate metrics for a specific time period with explicit filtering.

    Both series are processed identically: copied, index coerced to a
    ``DatetimeIndex``, rounded to the hour and stripped of timezone
    information. Timezones are removed *before* the period filter so that a
    tz-aware index (e.g. UTC data from NetCDF) can be compared with the
    period bounds (assumes the bounds are tz-naive — TODO confirm against
    config parsing). The caller's series are never modified.

    When both period bounds are set the series are filtered to the period
    and the configured number of spinup years is dropped from the start.
    The series are then optionally resampled to ``self.eval_timestep``,
    intersected on their indices and passed to the metric calculators.

    Args:
        obs_data: Observed series.
        sim_data: Simulated series.
        period: ``(start, end)``; filtering is skipped if either is falsy.
        prefix: Prefix for metric names (``"Calib"`` -> ``"Calib_KGE"``);
            an empty string leaves names unprefixed.

    Returns:
        Metric dictionary, or ``{}`` when the series share no timestamps or
        any error occurs (this method never raises).
    """
    def _prepare(series: pd.Series) -> pd.Series:
        # Work on a copy so the caller's index is left untouched
        prepared = series.copy()
        if not isinstance(prepared.index, pd.DatetimeIndex):
            prepared.index = pd.to_datetime(prepared.index)
        # Round BOTH indices consistently (fixes misalignment with daily data)
        prepared.index = prepared.index.round('h')
        # Some datasets come from NetCDF (UTC) while others from CSV (naive)
        if prepared.index.tz is not None:
            prepared.index = prepared.index.tz_localize(None)
        return prepared

    try:
        obs_period = _prepare(obs_data)
        sim_period = _prepare(sim_data)

        if period[0] and period[1]:
            # EXPLICIT filtering for both datasets (consistent with parallel worker)
            obs_period = obs_period[
                (obs_period.index >= period[0]) & (obs_period.index <= period[1])
            ]
            sim_period = sim_period[
                (sim_period.index >= period[0]) & (sim_period.index <= period[1])
            ]

            # Apply spinup removal if configured
            spinup_years = self._get_config_value(
                lambda: self.config.evaluation.spinup_years,
                default=0,
                dict_key='EVALUATION_SPINUP_YEARS'
            )
            try:
                spinup_years = int(float(spinup_years))
            except (TypeError, ValueError):
                spinup_years = 0
            if spinup_years > 0 and not obs_period.empty:
                cutoff = obs_period.index.min() + pd.DateOffset(years=spinup_years)
                obs_period = obs_period[obs_period.index >= cutoff]
                sim_period = sim_period[sim_period.index >= cutoff]
                self.logger.debug(f"Applied {spinup_years} year spinup removal, cutoff: {cutoff}")

            # Log filtering results for debugging
            self.logger.debug(f"{prefix} period filtering: {period[0]} to {period[1]}")
            self.logger.debug(f"{prefix} observed points in period: {len(obs_period)}")
            self.logger.debug(f"{prefix} simulated points in period: {len(sim_period)}")

        # Resample to evaluation timestep if specified in config
        if self.eval_timestep != 'native':
            self.logger.debug(f"Resampling data to {self.eval_timestep} timestep")

            if len(obs_period) > 0 and len(sim_period) > 0:
                self.logger.debug(f"BEFORE resample - obs: {obs_period.min():.3f} to {obs_period.max():.3f}, sim: {sim_period.min():.3f} to {sim_period.max():.3f}")

            obs_period = self._resample_to_timestep(obs_period, self.eval_timestep)
            sim_period = self._resample_to_timestep(sim_period, self.eval_timestep)

            self.logger.debug(f"After resampling - obs points: {len(obs_period)}, sim points: {len(sim_period)}")

            if len(obs_period) > 0 and len(sim_period) > 0:
                self.logger.debug(f"AFTER resample - obs: {obs_period.min():.3f} to {obs_period.max():.3f}, sim: {sim_period.min():.3f} to {sim_period.max():.3f}")

            # Final check: ensure both are midnight-aligned if daily
            if self.eval_timestep == 'daily':
                obs_period.index = obs_period.index.normalize()
                sim_period.index = sim_period.index.normalize()

        # Find common time indices
        common_idx = obs_period.index.intersection(sim_period.index)

        if len(common_idx) == 0:
            self.logger.debug(f"No common time indices for {prefix} period")
            if len(obs_period) > 0 and len(sim_period) > 0:
                self.logger.debug(f"Obs index sample: {obs_period.index[0]} to {obs_period.index[-1]} (type: {obs_period.index.dtype})")
                self.logger.debug(f"Sim index sample: {sim_period.index[0]} to {sim_period.index[-1]} (type: {sim_period.index.dtype})")
            return {}

        obs_common = obs_period.loc[common_idx]
        sim_common = sim_period.loc[common_idx]

        # Log final aligned data for debugging
        self.logger.debug(f"{prefix} aligned data points: {len(common_idx)}")
        self.logger.debug(f"{prefix} obs mean: {obs_common.mean():.4f}, range: {obs_common.min():.4f} to {obs_common.max():.4f}")
        self.logger.debug(f"{prefix} sim mean: {sim_common.mean():.4f}, range: {sim_common.min():.4f} to {sim_common.max():.4f}")

        base_metrics = self._calculate_performance_metrics(obs_common, sim_common)

        # Optionally add the log-likelihood (empty dict when not enabled)
        likelihood_metrics = self._calculate_likelihood_metrics(obs_common, sim_common)
        if likelihood_metrics:
            base_metrics.update(likelihood_metrics)

        if prefix:
            return {f"{prefix}_{k}": v for k, v in base_metrics.items()}
        else:
            return base_metrics

    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error calculating period metrics: {str(e)}")
        import traceback
        self.logger.debug(traceback.format_exc())
        return {}
def _resample_to_timestep(self, data: pd.Series, target_timestep: str) -> pd.Series:
    """
    Resample time series data to target timestep.

    Aggregation (fine → coarse) uses mean for state variables and sum for
    flux variables, as selected by ``self.variable_type``. Upsampling
    (coarse → fine) is rejected as it creates synthetic data through
    interpolation, which is inappropriate for observations.

    The current timestep is taken from the spacing of the first two
    timestamps. ``pd.infer_freq`` is consulted for logging only; its
    failure on short series is tolerated so that it cannot be mistaken for
    an upsampling rejection.

    Args:
        data: Time series data with DatetimeIndex
        target_timestep: Target timestep ('hourly' or 'daily')

    Returns:
        Resampled time series with NaN values dropped. The input is
        returned unchanged when it is None/empty, has a single point, is
        already at the target timestep, ``target_timestep`` is 'native',
        or an unexpected (non-ValueError) error occurs.

    Raises:
        ValueError: If upsampling is attempted (coarse to fine resolution)
    """
    if target_timestep == 'native' or data is None or len(data) == 0:
        return data

    if len(data) < 2:
        # No spacing can be measured from one timestamp
        self.logger.warning("Cannot infer frequency from single data point")
        return data

    try:
        # infer_freq raises ValueError for fewer than three timestamps;
        # treat that as "unknown" rather than letting it escape.
        try:
            inferred_freq = pd.infer_freq(data.index)
        except ValueError:
            inferred_freq = None

        # Determine current timestep from the first two timestamps
        time_diff = data.index[1] - data.index[0]
        if inferred_freq is None:
            self.logger.debug(f"Inferred time difference: {time_diff}")
        else:
            self.logger.debug(f"Inferred frequency: {inferred_freq}")

        # Check if already at target timestep (with tolerance)
        if target_timestep == 'hourly' and pd.Timedelta(minutes=45) <= time_diff <= pd.Timedelta(minutes=75):
            self.logger.debug("Data already at hourly timestep")
            return data
        elif target_timestep == 'daily' and pd.Timedelta(hours=20) <= time_diff <= pd.Timedelta(hours=28):
            self.logger.debug("Data already at daily timestep")
            return data

        # Determine aggregation function based on variable type
        agg_func = 'sum' if self.variable_type == 'flux' else 'mean'

        if target_timestep == 'hourly':
            if time_diff < pd.Timedelta(hours=1):
                # Aggregation: sub-hourly to hourly
                self.logger.debug(f"Aggregating {time_diff} data to hourly using {agg_func}")
                resampled = pd.Series(data.resample('h').agg(agg_func))
            elif time_diff > pd.Timedelta(hours=1):
                # Upsampling: daily/coarser to hourly - REJECT
                raise ValueError(
                    f"Cannot upsample {time_diff} data to hourly: "
                    f"interpolation creates synthetic observations. "
                    f"Use native timestep or aggregated (coarser) timestep instead."
                )
            else:
                resampled = data
        elif target_timestep == 'daily':
            if time_diff < pd.Timedelta(days=1):
                # Aggregation: hourly/sub-daily to daily
                self.logger.debug(f"Aggregating {time_diff} data to daily using {agg_func}")
                resampled = pd.Series(data.resample('D').agg(agg_func))
            elif time_diff > pd.Timedelta(days=1):
                # Upsampling: weekly/monthly to daily - REJECT
                raise ValueError(
                    f"Cannot upsample {time_diff} data to daily: "
                    f"interpolation creates synthetic observations."
                )
            else:
                resampled = data
        else:
            # Unrecognised target: no resampling, NaN values still dropped below
            resampled = data

        # Remove any NaN values introduced by resampling at edges
        resampled = resampled.dropna()

        self.logger.debug(
            f"Resampled from {len(data)} to {len(resampled)} points "
            f"(target: {target_timestep}, agg: {agg_func})"
        )
        return resampled

    except ValueError:
        # Re-raise ValueError for upsampling rejection
        raise
    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error resampling to {target_timestep}: {str(e)}")
        self.logger.warning("Returning original data without resampling")
        return data
def _calculate_likelihood_metrics(
    self,
    obs_common: pd.Series,
    sim_common: pd.Series,
) -> Dict[str, float]:
    """
    Optionally compute a Gaussian log-likelihood for the aligned series.

    Active only when the config provides a truthy LIKELIHOOD_FUNCTION
    value. Observation uncertainty comes from
    ``_load_observation_uncertainty`` and is used only when its length
    matches the observations. Model error is ``base + fraction * |sim|``
    when either configured term is positive.

    Returns:
        ``{'log_likelihood': value}``, or an empty dict when the feature is
        disabled or anything fails (failures are logged at debug level so
        the regular metrics are never affected).
    """
    def _number_or_zero(raw):
        # Configuration values that are not numeric count as 0.0
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    selected = self._get_config_value(
        lambda: self.config.optimization.likelihood_function,
        default='',
        dict_key='LIKELIHOOD_FUNCTION'
    )
    if not selected:
        return {}

    try:
        from symfluence.evaluation.likelihood import gaussian_log_likelihood

        # Observation uncertainty aligned to the common index
        uncertainty = self._load_observation_uncertainty(obs_common.index)

        fraction = _number_or_zero(self._get_config_value(
            lambda: self.config.optimization.model_error_fraction,
            default=0.0,
            dict_key='MODEL_ERROR_FRACTION'
        ))
        base = _number_or_zero(self._get_config_value(
            lambda: self.config.optimization.model_error_base,
            default=0.0,
            dict_key='MODEL_ERROR_BASE'
        ))

        observed = obs_common.values.astype(np.float64)
        simulated = sim_common.values.astype(np.float64)

        # sigma_model = base + fraction * |sim|, only if a term is configured
        if base > 0 or fraction > 0:
            sigma_model = base + fraction * np.abs(simulated)
        else:
            sigma_model = None

        # Uncertainty is ignored unless it lines up with the observations
        if uncertainty is not None and len(uncertainty) == len(observed):
            uncertainty_values = uncertainty.values.astype(np.float64)
        else:
            uncertainty_values = None

        return {
            'log_likelihood': gaussian_log_likelihood(
                observed, simulated,
                obs_uncertainty=uncertainty_values,
                model_error=sigma_model,
            )
        }

    except Exception as e:  # noqa: BLE001 — optional feature, must not break metrics
        self.logger.debug(f"Likelihood computation skipped: {e}")
        return {}
def _load_observation_uncertainty(
    self, time_index: pd.DatetimeIndex
) -> Optional[pd.Series]:
    """
    Provide observation uncertainty aligned to ``time_index``.

    Hook for subclasses that know where uncertainty data lives (e.g.,
    FLUXNET _uc columns). This base implementation ignores ``time_index``
    and reports that no uncertainty is available.

    Returns:
        Always ``None`` here.
    """
    no_uncertainty: Optional[pd.Series] = None
    return no_uncertainty
def _calculate_performance_metrics(self, observed: pd.Series, simulated: pd.Series) -> Dict[str, float]:
    """Calculate performance metrics between observed and simulated data.

    Both inputs are coerced to numeric (unparseable values become NaN) and
    passed to ``metrics.calculate_all_metrics``; a fixed subset of its
    results is returned.

    Returns:
        Dictionary with the keys ``KGE``, ``NSE``, ``RMSE``, ``PBIAS``,
        ``MAE``, ``correlation``, ``r``, ``alpha`` and ``beta``. On any
        error the same keys are returned with NaN values, so callers see
        one key set on both paths (this method never raises).
    """
    metric_names = (
        'KGE', 'NSE', 'RMSE', 'PBIAS', 'MAE',
        'correlation', 'r', 'alpha', 'beta',
    )
    try:
        # Clean data
        observed = pd.to_numeric(observed, errors='coerce')
        simulated = pd.to_numeric(simulated, errors='coerce')

        # Use centralized metrics module for all calculations
        result = metrics.calculate_all_metrics(observed, simulated)

        # Return subset of metrics for compatibility
        return {name: result[name] for name in metric_names}

    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.error(f"Error calculating performance metrics: {str(e)}")
        return dict.fromkeys(metric_names, np.nan)
def _parse_date_range(self, date_range_str: str) -> Tuple[Optional[pd.Timestamp], Optional[pd.Timestamp]]:
    """Turn a comma-separated "start, end" string into two timestamps.

    Only the first two comma-separated items are used; surrounding
    whitespace is ignored.

    Returns:
        ``(start, end)`` as ``pd.Timestamp`` objects, or ``(None, None)``
        when the input is falsy, holds fewer than two items, or cannot be
        parsed (a warning is logged in that last case).
    """
    if not date_range_str:
        return None, None

    try:
        pieces = date_range_str.split(',')
        if len(pieces) < 2:
            return None, None
        start_text = pieces[0].strip()
        end_text = pieces[1].strip()
        return pd.Timestamp(start_text), pd.Timestamp(end_text)
    except Exception as e:  # noqa: BLE001 — must-not-raise contract
        self.logger.warning(f"Could not parse date range '{date_range_str}': {str(e)}")
        return None, None
[docs]
def align_series(self, sim: pd.Series, obs: pd.Series) -> Tuple[pd.Series, pd.Series]:
    """Trim spinup years and restrict both series to their shared timestamps.

    The cutoff is the later of the two start times plus the configured
    number of spinup years (non-numeric or negative settings count as 0).

    Args:
        sim: Simulated series.
        obs: Observed series.

    Returns:
        ``(sim, obs)`` limited to the common index at or after the cutoff.
        If either input is empty both are returned unchanged; if nothing
        overlaps after trimming, the trimmed series are returned and a
        warning is logged.
    """
    configured = self._get_config_value(
        lambda: self.config.evaluation.spinup_years,
        default=0,
        dict_key='EVALUATION_SPINUP_YEARS'
    )
    try:
        years = int(float(configured))
    except (TypeError, ValueError):
        years = 0
    if years < 0:
        years = 0

    if sim.empty or obs.empty:
        return sim, obs

    cutoff = max(sim.index.min(), obs.index.min())
    if years:
        cutoff = cutoff + pd.DateOffset(years=years)

    sim_kept = sim[sim.index >= cutoff]
    obs_kept = obs[obs.index >= cutoff]

    shared = sim_kept.index.intersection(obs_kept.index)
    if shared.empty:
        self.logger.warning("No overlapping indices after alignment; returning trimmed series")
        return sim_kept, obs_kept

    return sim_kept.loc[shared], obs_kept.loc[shared]
def _collapse_spatial_dims(self, data_array: xr.DataArray, aggregate: str = 'mean') -> pd.Series:
"""
Collapse spatial dimensions from xarray DataArray to pandas Series.
Handles common spatial dimension patterns in SUMMA/FUSE/NGEN output:
- Single HRU/GRU: select index 0
- Multiple HRU/GRU: aggregate (mean by default)
- Other spatial dims: select first or aggregate
Args:
data_array: xarray DataArray with time and possibly spatial dimensions
aggregate: Aggregation method for multiple spatial units ('mean', 'sum', 'first')
Returns:
pandas Series with time index
"""
spatial_dims = ['hru', 'gru', 'param_set', 'latitude', 'longitude', 'seg', 'reachID']
result = data_array
for dim in spatial_dims:
if dim in result.dims:
dim_size = result.shape[result.dims.index(dim)]
if dim_size == 1:
result = result.isel({dim: 0})
elif aggregate == 'mean':
result = result.mean(dim=dim)
elif aggregate == 'sum':
result = result.sum(dim=dim)
elif aggregate == 'first':
result = result.isel({dim: 0})
# Handle any remaining non-time dimensions
non_time_dims = [dim for dim in result.dims if dim != 'time']
for dim in non_time_dims:
dim_size = result.shape[result.dims.index(dim)]
if dim_size == 1:
result = result.isel({dim: 0})
elif aggregate == 'mean':
result = result.mean(dim=dim)
elif aggregate == 'sum':
result = result.sum(dim=dim)
else:
result = result.isel({dim: 0})
return cast(pd.Series, result.to_pandas())
def _find_date_column(self, columns: List[str]) -> Optional[str]:
    """
    Locate the timestamp/date column among a table's column names.

    Exact names are tried first in a fixed priority order; if none is
    present, the first column whose lowercased name contains a date/time
    keyword is used.

    Args:
        columns: List of column names from DataFrame

    Returns:
        Name of date column, or None if not found
    """
    # Exact names, highest priority first
    preferred_names = (
        'timestamp', 'TIMESTAMP_START', 'TIMESTAMP_END',
        'datetime', 'DateTime', 'time', 'Time',
        'date', 'Date', 'DATE',
    )
    exact = next((name for name in preferred_names if name in columns), None)
    if exact is not None:
        return exact

    # Substring match on the lowercased column name
    keywords = ('timestamp', 'datetime', 'date', 'time')
    return next(
        (col for col in columns if any(word in col.lower() for word in keywords)),
        None,
    )