Source code for dataframe_expectations.core.column_expectation

from typing import Callable, List, Optional

from dataframe_expectations.core.expectation import DataFrameExpectation
from dataframe_expectations.core.types import DataFrameLike, DataFrameType
from dataframe_expectations.result_message import (
    DataFrameExpectationFailureMessage,
    DataFrameExpectationResultMessage,
    DataFrameExpectationSuccessMessage,
)


class DataFrameColumnExpectation(DataFrameExpectation):
    """
    Base class for DataFrame column expectations.

    This class validates a specific column in a DataFrame against a condition
    defined by the `fn_violations_pandas` and `fn_violations_pyspark` functions.
    """

    def __init__(
        self,
        expectation_name: str,
        column_name: str,
        fn_violations_pandas: Callable,
        fn_violations_pyspark: Callable,
        description: str,
        error_message: str,
        tags: Optional[List[str]] = None,
    ):
        """
        Template for implementing DataFrame column expectations, where a column
        value is tested against a condition. The conditions are defined by the
        `fn_violations_pandas` and `fn_violations_pyspark` functions.

        :param expectation_name: The name of the expectation, used during logging.
        :param column_name: The name of the column to check.
        :param fn_violations_pandas: Function to find violations in a pandas DataFrame.
        :param fn_violations_pyspark: Function to find violations in a PySpark DataFrame.
        :param description: A description of the expectation used in logging.
        :param error_message: The error message to return if the expectation fails.
        :param tags: Optional tags as a list of strings in "key:value" format.
            Example: ["priority:high", "env:test"]
        """
        super().__init__(tags=tags)
        self.column_name = column_name
        self.expectation_name = expectation_name
        self.fn_violations_pandas = fn_violations_pandas
        self.fn_violations_pyspark = fn_violations_pyspark
        self.description = description
        self.error_message = error_message

    def get_expectation_name(self) -> str:
        """
        Returns the expectation name.
        """
        return self.expectation_name

    def get_description(self) -> str:
        """
        Returns a description of the expectation.
        """
        return self.description

    def row_validation(
        self,
        data_frame_type: DataFrameType,
        data_frame: DataFrameLike,
        fn_violations: Callable,
        **kwargs,
    ) -> DataFrameExpectationResultMessage:
        """
        Validate the DataFrame against the expectation.

        :param data_frame_type: The type of DataFrame (Pandas or PySpark).
        :param data_frame: The DataFrame to validate.
        :param fn_violations: The function to find violations.
        :return: DataFrameExpectationResultMessage indicating success or failure.
        """
        if self.column_name not in data_frame.columns:
            return DataFrameExpectationFailureMessage(
                expectation_str=str(self),
                data_frame_type=data_frame_type,
                message=f"Column '{self.column_name}' does not exist in the DataFrame.",
            )

        violations = fn_violations(data_frame)
        # calculate the number of violations based on the DataFrame type
        num_violations = self.num_data_frame_rows(violations)

        if num_violations == 0:
            return DataFrameExpectationSuccessMessage(expectation_name=self.get_expectation_name())

        return DataFrameExpectationFailureMessage(
            expectation_str=str(self),
            data_frame_type=data_frame_type,
            violations_data_frame=violations,
            message=f"Found {num_violations} row(s) where {self.error_message}",
        )

    def validate_pandas(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        return self.row_validation(
            data_frame_type=DataFrameType.PANDAS,
            data_frame=data_frame,
            fn_violations=self.fn_violations_pandas,
            **kwargs,
        )

    def validate_pyspark(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        return self.row_validation(
            data_frame_type=DataFrameType.PYSPARK,
            data_frame=data_frame,
            fn_violations=self.fn_violations_pyspark,
            **kwargs,
        )
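
Usage sketch (illustrative, not part of the module): the example below shows how a concrete column expectation might be wired on top of this class, assuming `DataFrameColumnExpectation` can be instantiated directly and that the `DataFrameExpectation` base class provides the `num_data_frame_rows` helper used in `row_validation`. The expectation name, column, and violation functions are hypothetical.

import pandas as pd

# Hypothetical expectation: every value in the "age" column must be >= 0.
# Each violation function returns only the offending rows for its backend.
expect_age_non_negative = DataFrameColumnExpectation(
    expectation_name="ExpectAgeNonNegative",
    column_name="age",
    fn_violations_pandas=lambda df: df[df["age"] < 0],
    fn_violations_pyspark=lambda df: df.filter(df["age"] < 0),
    description="Values in column 'age' must be non-negative.",
    error_message="'age' is negative",
    tags=["priority:high", "env:test"],
)

result = expect_age_non_negative.validate_pandas(
    pd.DataFrame({"age": [25, -3, 40]})
)
# One violating row, so `result` is a DataFrameExpectationFailureMessage whose
# message reads "Found 1 row(s) where 'age' is negative"; with no violations,
# a DataFrameExpectationSuccessMessage would be returned instead.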