2025-04-24 18:16:11 +05:30

119 lines
3.2 KiB
Python

from typing import Callable, List
import pytest
from pandas import DataFrame, Series
class Assumption:
def __init__(self):
self.not_ = False
def validate(self, df: DataFrame):
if self.not_:
self.assume_false(df)
else:
self.assume_positive(df)
def __invert__(self):
return self.copy(not_=not self.not_)
def assume_positive(self, df: DataFrame):
raise NotImplementedError
def assume_false(self, df: DataFrame):
with pytest.raises(AssertionError):
self.assume_positive(df)
def copy(self, **kwargs):
new = self.__class__.__new__(self.__class__)
new.__dict__.update(self.__dict__)
new.__dict__.update(kwargs)
return new
class AssumeUnqiue(Assumption):
def __init__(self, column: str):
super().__init__()
self.column = column
def assume_positive(self, df: DataFrame):
assert df[self.column].nunique() == len(df)
class AssumeNotNull(Assumption):
def __init__(self, column: str):
super().__init__()
self.column = column
def assume_positive(self, df: DataFrame):
assert df[self.column].notnull().all()
class AssumeMatchRegex(Assumption):
def __init__(self, column: str, regex: str):
super().__init__()
self.column = column
self.regex = regex
def assume_positive(self, df: DataFrame):
assert df[self.column].str.match(self.regex).all()
class AssumeBetween(Assumption):
def __init__(self, column: str, min_value: int, max_value: int):
super().__init__()
self.column = column
self.min_value = min_value
self.max_value = max_value
def assume_positive(self, df: DataFrame):
assert (df[self.column] >= self.min_value).all()
assert (df[self.column] <= self.max_value).all()
class AssumeLengthBetween(Assumption):
def __init__(self, column: str, min_length: int, max_length: int):
super().__init__()
self.column = column
self.min_length = min_length
self.max_length = max_length
def assume_positive(self, df: DataFrame):
assert (df[self.column].str.len() >= self.min_length).all()
assert (df[self.column].str.len() <= self.max_length).all()
class AssumeColumnValuesIn(Assumption):
def __init__(self, column: str, allowed_values: List[str]):
super().__init__()
self.column = column
self.allowed_values = allowed_values
def assume_positive(self, df: DataFrame):
assert df[self.column].isin(self.allowed_values).all()
class AssumeArbitrary(Assumption):
def __init__(self, column: str, fn: Callable[[Series], Series]):
super().__init__()
self.column = column
self.fn = fn
def assume_positive(self, df: DataFrame):
assert self.fn(
df[self.column]
).all(), f"failed test {self.__class__.__name__} for column {self.column}"
class Assumptions:
notnull = AssumeNotNull
unique = AssumeUnqiue
match_regex = AssumeMatchRegex
between = AssumeBetween
length_between = AssumeLengthBetween
column_values_in = AssumeColumnValuesIn
arbitrary = AssumeArbitrary
assume = Assumptions()