# Source code for dataframe_expectations.core.aggregation_expectation
from abc import abstractmethod
from typing import List, Optional, Union
from dataframe_expectations.core.types import DataFrameLike, DataFrameType
from dataframe_expectations.core.expectation import DataFrameExpectation
from dataframe_expectations.result_message import (
DataFrameExpectationFailureMessage,
DataFrameExpectationResultMessage,
)
class DataFrameAggregationExpectation(DataFrameExpectation):
    """
    Base class for DataFrame aggregation expectations.

    This class is designed to first aggregate data and then validate the
    aggregation results. Subclasses implement the backend-specific hooks
    ``aggregate_and_validate_pandas`` / ``aggregate_and_validate_pyspark``;
    the public ``validate_pandas`` / ``validate_pyspark`` entry points wrap
    those hooks with an automatic column-existence check so implementations
    never need to repeat it.
    """

    def __init__(
        self,
        expectation_name: str,
        column_names: List[str],
        description: str,
        tags: Optional[List[str]] = None,
    ):
        """
        Template for implementing DataFrame aggregation expectations, where data
        is first aggregated and then the aggregation results are validated.

        :param expectation_name: The name of the expectation. This will be used
            during logging.
        :param column_names: The list of column names to aggregate on. An empty
            list means the expectation is DataFrame-level and the column check
            is skipped.
        :param description: A description of the expectation used in logging.
        :param tags: Optional tags as list of strings in "key:value" format.
            Example: ["priority:high", "env:test"]
        """
        super().__init__(tags=tags)
        self.expectation_name = expectation_name
        self.column_names = column_names
        self.description = description

    def get_expectation_name(self) -> str:
        """Returns the expectation name."""
        return self.expectation_name

    def get_description(self) -> str:
        """Returns a description of the expectation."""
        return self.description

    @abstractmethod
    def aggregate_and_validate_pandas(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Aggregate and validate a pandas DataFrame against the expectation.

        Note: This method should NOT check for column existence - that's handled
        automatically by the validate_pandas method.
        """
        raise NotImplementedError(
            f"aggregate_and_validate_pandas method must be implemented for {self.__class__.__name__}"
        )

    @abstractmethod
    def aggregate_and_validate_pyspark(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Aggregate and validate a PySpark DataFrame against the expectation.

        Note: This method should NOT check for column existence - that's handled
        automatically by the validate_pyspark method.
        """
        raise NotImplementedError(
            f"aggregate_and_validate_pyspark method must be implemented for {self.__class__.__name__}"
        )

    def validate_pandas(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a pandas DataFrame against the expectation.

        Automatically checks column existence before calling the implementation;
        a missing column short-circuits into a failure message rather than
        letting the aggregation raise.
        """
        column_error = self._check_columns_exist(data_frame)
        if column_error:
            return DataFrameExpectationFailureMessage(
                expectation_str=str(self),
                data_frame_type=DataFrameType.PANDAS,
                message=column_error,
            )
        # All required columns are present; delegate to the subclass hook.
        return self.aggregate_and_validate_pandas(data_frame, **kwargs)

    def validate_pyspark(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a PySpark DataFrame against the expectation.

        Automatically checks column existence before calling the implementation;
        a missing column short-circuits into a failure message rather than
        letting the aggregation raise.
        """
        column_error = self._check_columns_exist(data_frame)
        if column_error:
            return DataFrameExpectationFailureMessage(
                expectation_str=str(self),
                data_frame_type=DataFrameType.PYSPARK,
                message=column_error,
            )
        # All required columns are present; delegate to the subclass hook.
        return self.aggregate_and_validate_pyspark(data_frame, **kwargs)

    def _check_columns_exist(self, data_frame: DataFrameLike) -> Optional[str]:
        """
        Check if all required columns exist in the DataFrame.

        Relies on ``data_frame.columns``, which both pandas and PySpark
        DataFrames expose.

        :param data_frame: The DataFrame whose columns are inspected.
        :return: An error message if columns are missing, None otherwise.
        """
        # Skip column check if no columns are required (e.g., for
        # DataFrame-level expectations).
        if not self.column_names:
            return None
        missing_columns = [col for col in self.column_names if col not in data_frame.columns]
        if missing_columns:
            if len(missing_columns) == 1:
                return f"Column '{missing_columns[0]}' does not exist in the DataFrame."
            else:
                missing_columns_str = ", ".join([f"'{col}'" for col in missing_columns])
                return f"Columns [{missing_columns_str}] do not exist in the DataFrame."
        return None