diff --git a/dataframely/_native.pyi b/dataframely/_native.pyi
index 1d670c39..2fea95da 100644
--- a/dataframely/_native.pyi
+++ b/dataframely/_native.pyi
@@ -1,6 +1,9 @@
 from typing import overload
 
-def format_rule_failures(failures: list[tuple[str, int]]) -> str:
+def format_rule_failures(
+    failures: list[tuple[str, int]],
+    examples: dict[str, list[str]] | None = None,
+) -> str:
     """
     Format rule failures with the same logic that produces validation errors from the
     polars plugin.
@@ -8,6 +11,9 @@ def format_rule_failures(failures: list[tuple[str, int]]) -> str:
     Args:
         failures: The name of the failures and their counts. This should only include
             failures with a count of at least 1.
+        examples: Optional mapping from rule name to a list of example row strings.
+            When provided, up to ``len(examples[rule])`` distinct examples are included
+            in the formatted message for each rule.
 
     Returns:
         The formatted rule failures.
diff --git a/dataframely/_plugin.py b/dataframely/_plugin.py
index 7616af39..4bab60b2 100644
--- a/dataframely/_plugin.py
+++ b/dataframely/_plugin.py
@@ -58,6 +58,7 @@ def all_rules_required(
     *,
     null_is_valid: bool = True,
     schema_name: str,
+    data_columns: Iterable[IntoExpr] | None = None,
 ) -> pl.Expr:
     """Execute :mod:`~polars.all_horizontal` and `.all` for a set of rules.
 
@@ -70,15 +71,25 @@ def all_rules_required(
         schema_name: The name of the schema being validated. This is used to produce
             better error messages.
         null_is_valid: Whether to treat null values as valid (i.e., `true`).
+        data_columns: Optional data columns to include for generating example rows in
+            error messages. If provided, up to 5 distinct example rows are included
+            for each failing rule.
 
     Returns:
         A scalar boolean expression.
     """
+    rules_list = [rules] if isinstance(rules, pl.Expr) else list(rules)
+    num_rule_columns = len(rules_list)
+    data_columns_list = list(data_columns) if data_columns is not None else []
     return register_plugin_function(
         plugin_path=PLUGIN_PATH,
         function_name="all_rules_required",
-        args=rules,
-        kwargs={"null_is_valid": null_is_valid, "schema_name": schema_name},
+        args=[*rules_list, *data_columns_list],
+        kwargs={
+            "null_is_valid": null_is_valid,
+            "schema_name": schema_name,
+            "num_rule_columns": num_rule_columns,
+        },
         use_abs_path=True,
         returns_scalar=True,
     )
diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py
index 665b85ae..c822fb68 100644
--- a/dataframely/collection/collection.py
+++ b/dataframely/collection/collection.py
@@ -410,7 +410,9 @@ def validate(
         filtered, failures = cls.filter(data, cast=cast, eager=True)
         if any(len(failure) > 0 for failure in failures.values()):
             errors = {
-                member: format_rule_failures(list(failure.counts().items()))
+                member: format_rule_failures(
+                    list(failure.counts().items()), failure.examples()
+                )
                 for member, failure in failures.items()
                 if len(failure) > 0
             }
diff --git a/dataframely/filter_result.py b/dataframely/filter_result.py
index f2f88b75..469d580f 100644
--- a/dataframely/filter_result.py
+++ b/dataframely/filter_result.py
@@ -146,6 +146,22 @@ def counts(self) -> dict[str, int]:
         """
         return _compute_counts(self._df, self._rule_columns)
 
+    def examples(self, max_examples: int = 5) -> dict[str, list[str]]:
+        """Example rows for each failing rule.
+
+        For each rule that has at least one failure, returns up to `max_examples`
+        distinct example rows (as formatted strings) from the original data columns.
+
+        Args:
+            max_examples: The maximum number of distinct example rows to return per
+                rule.
+
+        Returns:
+            A mapping from rule name to a list of example row strings. Rules with no
+            failures are not included.
+        """
+        return _compute_examples(self._df, self._rule_columns, max_examples)
+
     def cooccurrence_counts(self) -> dict[frozenset[str], int]:
         """The number of validation failures per co-occurring rule validation failure.
@@ -409,6 +425,28 @@ def _compute_counts(df: pl.DataFrame, rule_columns: list[str]) -> dict[str, int]
     }
 
 
+def _compute_examples(
+    df: pl.DataFrame, rule_columns: list[str], max_examples: int
+) -> dict[str, list[str]]:
+    if len(rule_columns) == 0:
+        return {}
+
+    data_columns = [c for c in df.columns if c not in rule_columns]
+    if not data_columns:
+        return {}
+
+    result = {}
+    for rule_name in rule_columns:
+        failing = df.filter(pl.col(rule_name).not_())
+        if len(failing) == 0:
+            continue
+        examples_df = (
+            failing.select(data_columns).unique(maintain_order=True).head(max_examples)
+        )
+        result[rule_name] = [str(row) for row in examples_df.to_dicts()]
+    return result
+
+
 def _compute_cooccurrence_counts(
     df: pl.DataFrame, rule_columns: list[str]
 ) -> dict[frozenset[str], int]:
diff --git a/dataframely/schema.py b/dataframely/schema.py
index b64f0616..d1aa0834 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -573,7 +573,9 @@ def validate(
             out, failure = cls.filter(df, cast=cast, eager=True)
             if len(failure) > 0:
                 raise ValidationError(
-                    format_rule_failures(list(failure.counts().items()))
+                    format_rule_failures(
+                        list(failure.counts().items()), failure.examples()
+                    )
                 )
             return out
         else:
@@ -583,7 +585,13 @@ def validate(
         if rules := cls._validation_rules(with_cast=False):
             lf = (
                 lf.pipe(with_evaluation_rules, rules)
-                .filter(all_rules_required(rules.keys(), schema_name=cls.__name__))
+                .filter(
+                    all_rules_required(
+                        rules.keys(),
+                        schema_name=cls.__name__,
+                        data_columns=cls.column_names(),
+                    )
+                )
                 .drop(rules.keys())
             )
         return lf  # type: ignore
diff --git a/src/polars_plugin/mod.rs b/src/polars_plugin/mod.rs
index 06cb9b99..4d20a9c7 100644
--- a/src/polars_plugin/mod.rs
+++ b/src/polars_plugin/mod.rs
@@ -2,6 +2,8 @@ mod
rule_failure;
 mod utils;
 mod validation_error;
+use std::collections::{HashMap, HashSet};
+
 use polars::prelude::*;
 use polars_core::POOL;
 use pyo3_polars::derive::polars_expr;
@@ -62,18 +64,71 @@ pub fn all_rules(inputs: &[Series]) -> PolarsResult<Series> {
 struct RequiredValidationKwargs {
     schema_name: String,
     null_is_valid: bool,
+    #[serde(default)]
+    num_rule_columns: Option<usize>,
+}
+
+/// The maximum number of distinct example rows included in validation error messages.
+const MAX_EXAMPLES: usize = 5;
+
+/// Format a single data row (at `row_idx`) from the given data series as a Python-like dict string.
+fn format_example_row(data_series: &[Series], row_idx: usize) -> String {
+    let kvs: Vec<String> = data_series
+        .iter()
+        .map(|s| {
+            let val = s.get(row_idx).unwrap_or(AnyValue::Null);
+            format!("'{}': {}", s.name(), val)
+        })
+        .collect();
+    format!("{{{}}}", kvs.join(", "))
+}
+
+/// Compute up to `max_examples` distinct example rows for a failing rule.
+fn compute_examples(
+    bool_ca: &BooleanChunked,
+    null_is_valid: bool,
+    data_series: &[Series],
+    max_examples: usize,
+) -> Vec<String> {
+    let mut seen: HashSet<String> = HashSet::new();
+    let mut examples: Vec<String> = Vec::new();
+
+    for (i, val) in bool_ca.iter().enumerate() {
+        let is_failure = match val {
+            Some(false) => true,
+            None => !null_is_valid,
+            _ => false,
+        };
+        if is_failure {
+            let row_str = format_example_row(data_series, i);
+            if seen.insert(row_str.clone()) {
+                examples.push(row_str);
+                if examples.len() >= max_examples {
+                    break;
+                }
+            }
+        }
+    }
+
+    examples
 }
 
 /// Reduce a set of boolean columns into a single boolean scalar, AND-ing all values.
 /// Null values are treated as `true`.
 /// In contrast to `all_rules`, this function raises an error if the returned value would be
 /// `false`, including details about the `false` values (i.e. "rules" that failed).
+/// The first `num_rule_columns` inputs are boolean rule columns; any remaining inputs are
+/// data columns used to generate example rows in error messages.
 #[polars_expr(output_type=Boolean)]
 pub fn all_rules_required(
     inputs: &[Series],
     kwargs: RequiredValidationKwargs,
 ) -> PolarsResult<Series> {
-    let failures = compute_rule_failures(inputs, kwargs.null_is_valid)?;
+    let num_rule = kwargs.num_rule_columns.unwrap_or(inputs.len());
+    let rule_inputs = &inputs[..num_rule];
+    let data_inputs = &inputs[num_rule..];
+
+    let failures = compute_rule_failures(rule_inputs, kwargs.null_is_valid)?;
 
     // If there's any failure, we know that validation failed and use the failure object for an
     // informative error message. If no failure exists, we simply return a series with a single
@@ -84,7 +139,26 @@ pub fn all_rules_required(
         return Ok(BooleanChunked::new(PlSmallStr::EMPTY, [true]).into_series());
     }
 
+    // Compute examples for each failing rule using the data columns.
+    let examples: HashMap<String, Vec<String>> = if data_inputs.is_empty() {
+        HashMap::new()
+    } else {
+        failures
+            .iter()
+            .map(|failure| {
+                let rule_series = rule_inputs
+                    .iter()
+                    .find(|s| s.name().as_str() == failure.rule)
+                    .expect("failing rule not found in inputs");
+                let bool_ca = as_bool(rule_series)?;
+                let examples =
+                    compute_examples(bool_ca, kwargs.null_is_valid, data_inputs, MAX_EXAMPLES);
+                Ok((failure.rule.to_string(), examples))
+            })
+            .collect::<PolarsResult<HashMap<String, Vec<String>>>>()?
+    };
+
+    // Aggregate failure counts into a validation error.
     let error = RuleValidationError::new(failures);
-    Err(polars_err!(ComputeError: format!("\n{}", error.to_string(Some(&kwargs.schema_name)))))
+    Err(polars_err!(ComputeError: format!("\n{}", error.to_string(Some(&kwargs.schema_name), Some(&examples)))))
 }
diff --git a/src/polars_plugin/validation_error.rs b/src/polars_plugin/validation_error.rs
index b2ca7187..c82e2c4e 100644
--- a/src/polars_plugin/validation_error.rs
+++ b/src/polars_plugin/validation_error.rs
@@ -2,6 +2,7 @@ use itertools::Itertools;
 use num_format::{Locale, ToFormattedString};
 use polars::prelude::*;
 use pyo3::{create_exception, exceptions::PyException, prelude::*};
+use std::collections::HashMap;
 
 use super::RuleFailure;
 
@@ -39,7 +40,11 @@ impl<'a> RuleValidationError<'a> {
         }
     }
 
-    pub fn to_string(&self, schema: Option<&str>) -> String {
+    pub fn to_string(
+        &self,
+        schema: Option<&str>,
+        examples: Option<&HashMap<String, Vec<String>>>,
+    ) -> String {
         let mut result = if let Some(schema) = schema {
             format!(
                 "{} rules failed validation for schema '{schema}':",
@@ -49,10 +54,12 @@ impl<'a> RuleValidationError<'a> {
             format!("{} rules failed validation:", self.num_rule_failures)
         };
         self.schema_errors.iter().for_each(|failure| {
+            let examples_str = format_examples(failure.rule, examples);
             result += format!(
-                "\n - '{}' failed for {} rows",
+                "\n - '{}' failed for {} rows{}",
                 failure.rule,
-                failure.count.to_formatted_string(&Locale::en)
+                failure.count.to_formatted_string(&Locale::en),
+                examples_str,
             )
             .as_str();
         });
@@ -63,10 +70,13 @@ impl<'a> RuleValidationError<'a> {
             )
             .as_str();
             errors.iter().for_each(|failure| {
+                let full_rule = format!("{}|{}", column, failure.rule);
+                let examples_str = format_examples(&full_rule, examples);
                 result += format!(
-                    "\n - '{}' failed for {} rows",
+                    "\n - '{}' failed for {} rows{}",
                     failure.rule,
-                    failure.count.to_formatted_string(&Locale::en)
+                    failure.count.to_formatted_string(&Locale::en),
+                    examples_str,
                 )
                 .as_str();
             });
@@ -75,8 +85,26 @@ impl<'a>
RuleValidationError<'a> {
     }
 }
 
+fn format_examples(rule: &str, examples: Option<&HashMap<String, Vec<String>>>) -> String {
+    match examples.and_then(|ex| ex.get(rule)) {
+        Some(ex) if !ex.is_empty() => {
+            let suffix = if ex.len() == 1 {
+                "example".to_string()
+            } else {
+                "examples".to_string()
+            };
+            format!(" with {} distinct {}: [{}]", ex.len(), suffix, ex.join(", "))
+        }
+        _ => String::new(),
+    }
+}
+
 #[pyfunction]
-pub fn format_rule_failures(failures: Vec<(String, IdxSize)>) -> String {
+#[pyo3(signature = (failures, examples=None))]
+pub fn format_rule_failures(
+    failures: Vec<(String, IdxSize)>,
+    examples: Option<HashMap<String, Vec<String>>>,
+) -> String {
     let validation_error = RuleValidationError::new(
         failures
             .iter()
@@ -86,5 +114,5 @@ pub fn format_rule_failures(failures: Vec<(String, IdxSize)>) -> String {
             })
             .collect(),
     );
-    return validation_error.to_string(None);
+    return validation_error.to_string(None, examples.as_ref());
 }
diff --git a/tests/schema/test_validate.py b/tests/schema/test_validate.py
index d4967273..3d5a2b82 100644
--- a/tests/schema/test_validate.py
+++ b/tests/schema/test_validate.py
@@ -130,8 +130,9 @@ def test_invalid_primary_key(
     with pytest.raises(
         ValidationError if eager else plexc.ComputeError,
         match=r"1 rules failed validation",
-    ):
+    ) as exc_info:
         _validate_and_collect(MySchema, df, eager=eager)
+    exc_info.match(r"with 2 distinct examples")
 
     assert not MySchema.is_valid(df)