Source code for dataframe_expectations.core.expectation
from abc import ABC, abstractmethod
from typing import List, Optional, cast

from pandas import DataFrame as PandasDataFrame
from pyspark.sql import DataFrame as PySparkDataFrame

# Import the connect DataFrame type for Spark Connect
try:
    from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
except ImportError:
    # Fallback for older PySpark versions that don't ship Spark Connect
    PySparkConnectDataFrame = None  # type: ignore[misc,assignment]

from dataframe_expectations.core.types import DataFrameLike, DataFrameType
from dataframe_expectations.core.tagging import TagSet
from dataframe_expectations.result_message import (
    DataFrameExpectationResultMessage,
)
class DataFrameExpectation(ABC):
    """
    Base class for DataFrame expectations.
    """
    def __init__(self, tags: Optional[List[str]] = None):
        """
        Initialize the base expectation with optional tags.

        :param tags: Optional tags as a list of strings in "key:value" format,
            e.g. ["priority:high", "env:test"].
        """
        self.__tags = TagSet(tags)
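    # Tagging example (comment sketch): tags are parsed by TagSet from
    # "key:value" strings. ExpectationNotEmpty is a hypothetical subclass,
    # sketched in the usage example at the end of this module:
    #
    #     expectation = ExpectationNotEmpty(tags=["priority:high", "env:test"])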
    def get_expectation_name(self) -> str:
        """
        Returns the class name as the expectation name.
        """
        return type(self).__name__
    @abstractmethod
    def get_description(self) -> str:
        """
        Returns a description of the expectation.
        """
        raise NotImplementedError(
            f"get_description method must be implemented for {self.__class__.__name__}"
        )
    def __str__(self):
        """
        Returns a string representation of the expectation.
        """
        return f"{self.get_expectation_name()} ({self.get_description()})"
    @classmethod
    def infer_data_frame_type(cls, data_frame: DataFrameLike) -> DataFrameType:
        """
        Infer the DataFrame type based on the provided DataFrame.
        """
        match data_frame:
            case PandasDataFrame():
                return DataFrameType.PANDAS
            case PySparkDataFrame():
                return DataFrameType.PYSPARK
            case _ if PySparkConnectDataFrame is not None and isinstance(
                data_frame, PySparkConnectDataFrame
            ):
                # Spark Connect DataFrames are validated through the PySpark path
                return DataFrameType.PYSPARK
            case _:
                raise ValueError(f"Unsupported DataFrame type: {type(data_frame)}")
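    # Inference example (comment sketch; assumes pandas imported as `pd` and an
    # active SparkSession named `spark`):
    #
    #     DataFrameExpectation.infer_data_frame_type(pd.DataFrame({"a": [1]}))
    #     # -> DataFrameType.PANDAS
    #     DataFrameExpectation.infer_data_frame_type(spark.range(3))
    #     # -> DataFrameType.PYSPARK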
    def validate(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate the DataFrame against the expectation, dispatching to the
        implementation that matches the inferred DataFrame type.
        """
        data_frame_type = self.infer_data_frame_type(data_frame)
        match data_frame_type:
            case DataFrameType.PANDAS:
                return self.validate_pandas(data_frame=data_frame, **kwargs)
            case DataFrameType.PYSPARK:
                return self.validate_pyspark(data_frame=data_frame, **kwargs)
            case _:
                raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")
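    # Dispatch example (comment sketch, reusing the hypothetical
    # ExpectationNotEmpty subclass from the usage sketch at the end of this
    # module):
    #
    #     ExpectationNotEmpty().validate(pd.DataFrame({"a": [1]}))  # pandas path
    #     ExpectationNotEmpty().validate(spark.range(1))            # PySpark path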
    @abstractmethod
    def validate_pandas(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a pandas DataFrame against the expectation.
        """
        raise NotImplementedError(
            f"validate_pandas method must be implemented for {self.__class__.__name__}"
        )
    @abstractmethod
    def validate_pyspark(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a PySpark DataFrame against the expectation.
        """
        raise NotImplementedError(
            f"validate_pyspark method must be implemented for {self.__class__.__name__}"
        )
    @classmethod
    def num_data_frame_rows(cls, data_frame: DataFrameLike) -> int:
        """
        Count the number of rows in the DataFrame.
        """
        data_frame_type = cls.infer_data_frame_type(data_frame)
        if data_frame_type == DataFrameType.PANDAS:
            # Cast to PandasDataFrame since we know it's a pandas DataFrame at this point
            return len(cast(PandasDataFrame, data_frame))
        elif data_frame_type == DataFrameType.PYSPARK:
            # Cast to PySparkDataFrame since we know it's a PySpark DataFrame at this point
            return cast(PySparkDataFrame, data_frame).count()
        else:
            raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")
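
# Usage sketch (illustrative comment, not part of the module). The
# ExpectationNotEmpty subclass and the
# DataFrameExpectationResultMessage(success=..., message=...) constructor are
# assumptions made for this example; the real result-message signature may
# differ.
#
#     class ExpectationNotEmpty(DataFrameExpectation):
#         def get_description(self) -> str:
#             return "DataFrame must contain at least one row"
#
#         def validate_pandas(self, data_frame, **kwargs):
#             ok = self.num_data_frame_rows(data_frame) > 0
#             return DataFrameExpectationResultMessage(success=ok, message=str(self))  # hypothetical signature
#
#         def validate_pyspark(self, data_frame, **kwargs):
#             ok = self.num_data_frame_rows(data_frame) > 0
#             return DataFrameExpectationResultMessage(success=ok, message=str(self))  # hypothetical signature
#
#     expectation = ExpectationNotEmpty(tags=["priority:high"])
#     result = expectation.validate(pd.DataFrame({"a": [1, 2]}))  # dispatches to validate_pandas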