remove dependency on six, drop support for python 3.7, update dev dependencies

update pytest describe

add flexible terminal string formatter

start swapping in default formats

start phasing out underline_cells

wip

wip

demonstrate how custom formatting can be injected via pytest fixtures

wip

make sure format_string works with no formats

add new custom formatter to readme

bump to v0.10.0

New Chispa interface (#94)

* add formats to dataframe comparer

* add new chispa interface

lock

run tests for multiple python versions

small fix

add runs-on argument

fix

reset ci
fpgmaas committed Jul 14, 2024
1 parent 663dea1 commit d5f4c01
Showing 22 changed files with 684 additions and 1,372 deletions.
49 changes: 49 additions & 0 deletions .github/actions/setup-poetry-env/action.yml
@@ -0,0 +1,49 @@
+name: "setup-poetry-env"
+description: "Composite action to setup the Python and poetry environment."
+
+inputs:
+  python-version:
+    required: false
+    description: "The python version to use"
+    default: "3.11"
+
+runs:
+  using: "composite"
+  steps:
+    - name: Set up python
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ inputs.python-version }}
+
+    - name: Install Poetry
+      env:
+        # renovate: datasource=pypi depName=poetry
+        POETRY_VERSION: "1.5.1"
+      run: curl -sSL https://install.python-poetry.org | python - -y
+      shell: bash
+
+    - name: Add Poetry to Path
+      run: echo "$HOME/.local/bin" >> $GITHUB_PATH
+      if: ${{ matrix.os != 'Windows' }}
+      shell: bash
+
+    - name: Add Poetry to Path
+      run: echo "$APPDATA\Python\Scripts" >> $GITHUB_PATH
+      if: ${{ matrix.os == 'Windows' }}
+      shell: bash
+
+    - name: Configure Poetry virtual environment in project
+      run: poetry config virtualenvs.in-project true
+      shell: bash
+
+    - name: Load cached venv
+      id: cached-poetry-dependencies
+      uses: actions/cache@v3
+      with:
+        path: .venv
+        key: venv-${{ runner.os }}-${{ inputs.python-version }}-${{ hashFiles('poetry.lock') }}
+
+    - name: Install dependencies
+      if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
+      run: poetry install --no-interaction
+      shell: bash
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ chispa.egg-info/
 tmp/
 .idea/
 .DS_Store
+.python_version

 # Byte-compiled / optimized / DLL files
 __pycache__/
1 change: 0 additions & 1 deletion .python-version

This file was deleted.

36 changes: 26 additions & 10 deletions README.md
@@ -303,23 +303,42 @@ nan2 = float('nan')
 nan1 == nan2 # False
 ```

-Pandas, a popular DataFrame library, does consider NaN values to be equal by default.
-
-This library requires you to set a flag to consider two NaN values to be equal.
+pandas considers NaN values to be equal by default, but this library requires you to set a flag to consider two NaN values to be equal.

 ```python
 assert_df_equality(df1, df2, allow_nan_equality=True)
 ```

-### Underline differences within rows
+## Customize formatting
+
+*Available in chispa 0.10+*.
+
+You can specify custom formats for the printed error messages as follows:
+
+```python
+@dataclass
+class MyFormats:
+    mismatched_rows = ["light_yellow"]
+    matched_rows = ["cyan", "bold"]
+    mismatched_cells = ["purple"]
+    matched_cells = ["blue"]

+assert_basic_rows_equality(df1.collect(), df2.collect(), formats=MyFormats())
+```
+
-You can choose to underline columns within a row that are different by setting `underline_cells` to True, i.e.:
+You can also define these formats in `conftest.py` and inject them via a fixture:

 ```python
-assert_df_equality(df1, df2, underline_cells=True)
+@pytest.fixture()
+def my_formats():
+    return MyFormats()
+
+def test_shows_assert_basic_rows_equality(my_formats):
+    ...
+    assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats)
 ```

-![DfsNotEqualUnderlined](https://github.com/MrPowers/chispa/blob/main/images/df_not_equal_underlined.png)
+![custom_formats](https://github.com/MrPowers/chispa/blob/main/images/custom_formats.png)

 ## Approximate column equality

@@ -456,9 +475,6 @@ TODO: Need to benchmark these methods vs. the spark-testing-base ones

 ## Vendored dependencies

-These dependencies are vendored:
-
-* [six](https://github.com/benjaminp/six)
 * [PrettyTable](https://github.com/jazzband/prettytable)

 The dependencies are vendored to save you from dependency hell.
20 changes: 20 additions & 0 deletions chispa/__init__.py
@@ -28,3 +28,23 @@
 from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality
 from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality
 from .rows_comparer import assert_basic_rows_equality
+from chispa.default_formats import DefaultFormats
+
+class Chispa():
+    def __init__(self, formats=DefaultFormats(), default_output=None):
+        self.formats = formats
+        self.default_outputs = default_output
+
+    def assert_df_equality(self, df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
+                           ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False):
+        return assert_df_equality(
+            df1,
+            df2,
+            ignore_nullable,
+            transforms,
+            allow_nan_equality,
+            ignore_column_order,
+            ignore_row_order,
+            underline_cells,
+            ignore_metadata,
+            self.formats)
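The pattern the `Chispa` class follows (store the formats once, forward them on every call) can be sketched without PySpark. This is a minimal illustration under stated assumptions, not chispa code: `compare`, `Comparer`, and the dict used for formats are hypothetical stand-ins. The sketch forwards options by keyword, so a later change in parameter order cannot silently shift an argument, and it builds the default inside `__init__` so instances do not share one default object.

```python
def compare(left, right, ignore_order=False, formats=None):
    # Hypothetical stand-in for a module-level assert function; it only
    # reports what it received so the forwarding can be inspected.
    return {"equal": left == right, "ignore_order": ignore_order, "formats": formats}


class Comparer:
    """Binds a formats object once and forwards it on every call."""

    def __init__(self, formats=None):
        # Create the default per instance rather than in the signature.
        self.formats = formats if formats is not None else {"mismatched_rows": ["red"]}

    def compare(self, left, right, ignore_order=False):
        # Keyword forwarding keeps each value attached to its parameter name.
        return compare(left, right, ignore_order=ignore_order, formats=self.formats)


custom = {"mismatched_rows": ["light_yellow"]}
result = Comparer(formats=custom).compare([1, 2], [1, 2])
print(result["formats"] is custom)
```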
17 changes: 9 additions & 8 deletions chispa/dataframe_comparer.py
@@ -1,6 +1,7 @@
 from chispa.schema_comparer import assert_schema_equality
-from chispa.row_comparer import *
+from chispa.default_formats import DefaultFormats
 from chispa.rows_comparer import assert_basic_rows_equality, assert_generic_rows_equality
+from chispa.row_comparer import are_rows_equal_enhanced, are_rows_approx_equal
 from functools import reduce


@@ -10,7 +11,7 @@ class DataFramesNotEqualError(Exception):


 def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
-                       ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False):
+                       ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False, formats=DefaultFormats()):
     if transforms is None:
         transforms = []
     if ignore_column_order:
@@ -22,10 +23,10 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n
     assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata)
     if allow_nan_equality:
         assert_generic_rows_equality(
-            df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells)
+            df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells, formats=formats)
     else:
         assert_basic_rows_equality(
-            df1.collect(), df2.collect(), underline_cells=underline_cells)
+            df1.collect(), df2.collect(), underline_cells=underline_cells, formats=formats)


 def are_dfs_equal(df1, df2):
@@ -37,7 +38,7 @@ def are_dfs_equal(df1, df2):


 def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transforms=None, allow_nan_equality=False,
-                              ignore_column_order=False, ignore_row_order=False):
+                              ignore_column_order=False, ignore_row_order=False, formats=DefaultFormats()):
     if transforms is None:
         transforms = []
     if ignore_column_order:
@@ -48,8 +49,8 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf
     df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
     assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
     if precision != 0:
-        assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality])
+        assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], formats)
     elif allow_nan_equality:
-        assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True])
+        assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats)
     else:
-        assert_basic_rows_equality(df1.collect(), df2.collect())
+        assert_basic_rows_equality(df1.collect(), df2.collect(), formats)
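One thing worth checking in `assert_approx_df_equality`: `formats` is passed positionally. If these calls resolve to the signatures in `chispa/rows_comparer.py` from this same commit, where `underline_cells` precedes `formats`, the value would bind to `underline_cells` and `formats` would fall back to its default. The sketch below uses a stand-in function with the same parameter order (it is not chispa code) to show the binding.

```python
def assert_rows(rows1, rows2, row_equality_fun, row_equality_fun_args,
                underline_cells=False, formats=None):
    # Stand-in with the same parameter order as assert_generic_rows_equality
    # in this commit; it only reports how the arguments were bound.
    return {"underline_cells": underline_cells, "formats": formats}


my_formats = {"mismatched_cells": ["red", "underline"]}  # placeholder value

# The fifth positional argument binds to underline_cells, not formats.
positional = assert_rows([], [], None, [True], my_formats)

# Naming the argument binds it to the intended parameter.
keyword = assert_rows([], [], None, [True], formats=my_formats)

print(positional)
print(keyword)
```

Passing `formats=formats` by keyword, as `assert_df_equality` already does in this file, avoids the ambiguity.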
8 changes: 8 additions & 0 deletions chispa/default_formats.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+@dataclass
+class DefaultFormats:
+    mismatched_rows = ["red"]
+    matched_rows = ["blue"]
+    mismatched_cells = ["red", "underline"]
+    matched_cells = ["blue"]
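A note on how `@dataclass` treats a class shaped like `DefaultFormats`: attributes without type annotations are not dataclass fields, so they stay ordinary class attributes and the generated `__init__` takes no arguments. The sketch below illustrates this with two hypothetical classes (`UnannotatedFormats` and `AnnotatedFormats` are names invented here, and the annotated variant is only one possible alternative, not what the commit does).

```python
from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class UnannotatedFormats:
    # Same shape as DefaultFormats: no annotations, so these are
    # ordinary class attributes shared by every instance.
    mismatched_rows = ["red"]
    matched_rows = ["blue"]


@dataclass
class AnnotatedFormats:
    # Hypothetical variant: annotated fields with per-instance defaults.
    mismatched_rows: List[str] = field(default_factory=lambda: ["red"])
    matched_rows: List[str] = field(default_factory=lambda: ["blue"])


print([f.name for f in fields(UnannotatedFormats)])  # []
print([f.name for f in fields(AnnotatedFormats)])    # ['mismatched_rows', 'matched_rows']

# Only the annotated variant accepts overrides through the constructor.
custom = AnnotatedFormats(mismatched_rows=["light_yellow"])
print(custom.mismatched_rows)  # ['light_yellow']
```

Attribute access such as `formats.mismatched_rows` works for both variants, which is all the comparers in this commit rely on.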
115 changes: 54 additions & 61 deletions chispa/rows_comparer.py
@@ -1,85 +1,78 @@
-import chispa.six as six
+from itertools import zip_longest
 from chispa.prettytable import PrettyTable
 from chispa.bcolors import *
 import chispa
 from pyspark.sql.types import Row
 from typing import List
+from chispa.terminal_str_formatter import format_string
+from chispa.default_formats import DefaultFormats


-def assert_basic_rows_equality(rows1, rows2, underline_cells=False):
-    if underline_cells:
-        row_column_names = rows1[0].__fields__
-        num_columns = len(row_column_names)
+def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=DefaultFormats()):
     if rows1 != rows2:
         t = PrettyTable(["df1", "df2"])
-        zipped = list(six.moves.zip_longest(rows1, rows2))
+        zipped = list(zip_longest(rows1, rows2))
+        all_rows_equal = True
+
         for r1, r2 in zipped:
-            if r1 == r2:
-                t.add_row([blue(r1), blue(r2)])
+            if r1 is None and r2 is not None:
+                t.add_row([None, format_string(r2, formats.mismatched_rows)])
+                all_rows_equal = False
+            elif r1 is not None and r2 is None:
+                t.add_row([format_string(r1, formats.mismatched_rows), None])
+                all_rows_equal = False
             else:
-                if underline_cells:
-                    t.add_row(__underline_cells_in_row(
-                        r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
-                else:
-                    t.add_row([r1, r2])
-        raise chispa.DataFramesNotEqualError("\n" + t.get_string())
+                r_zipped = list(zip_longest(r1.__fields__, r2.__fields__))
+                r1_string = []
+                r2_string = []
+                for r1_field, r2_field in r_zipped:
+                    if r1[r1_field] != r2[r2_field]:
+                        all_rows_equal = False
+                        r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
+                        r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
+                    else:
+                        r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
+                        r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
+                r1_res = ", ".join(r1_string)
+                r2_res = ", ".join(r2_string)
+
+                t.add_row([r1_res, r2_res])
+        if all_rows_equal == False:
+            raise chispa.DataFramesNotEqualError("\n" + t.get_string())


-def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False):
+def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False, formats=DefaultFormats()):
     df1_rows = rows1
     df2_rows = rows2
-    zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
+    zipped = list(zip_longest(df1_rows, df2_rows))
     t = PrettyTable(["df1", "df2"])
-    allRowsEqual = True
-    if underline_cells:
-        row_column_names = rows1[0].__fields__
-        num_columns = len(row_column_names)
+    all_rows_equal = True
     for r1, r2 in zipped:
         # rows are not equal when one is None and the other isn't
         if (r1 is not None and r2 is None) or (r2 is not None and r1 is None):
-            allRowsEqual = False
-            t.add_row([r1, r2])
+            all_rows_equal = False
+            t.add_row([format_string(r1, formats.mismatched_rows), format_string(r2, formats.mismatched_rows)])
         # rows are equal
         elif row_equality_fun(r1, r2, *row_equality_fun_args):
-            first = bcolors.LightBlue + str(r1) + bcolors.LightRed
-            second = bcolors.LightBlue + str(r2) + bcolors.LightRed
-            t.add_row([first, second])
+            r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__))
+            r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__))
+            t.add_row([format_string(r1_string, formats.matched_rows), format_string(r2_string, formats.matched_rows)])
         # otherwise, rows aren't equal
         else:
-            allRowsEqual = False
-            # Underline cells if requested
-            if underline_cells:
-                t.add_row(__underline_cells_in_row(
-                    r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
-            else:
-                t.add_row([r1, r2])
-    if allRowsEqual == False:
-        raise chispa.DataFramesNotEqualError("\n" + t.get_string())
-
-
-def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_columns=int) -> List[str]:
-    """
-    Takes two Row types, a list of column names for the Rows and the length of columns
-    Returns list of two strings, with underlined columns within rows that are different for PrettyTable
-    """
-    r1_string = "Row("
-    r2_string = "Row("
-    for index, column in enumerate(row_column_names):
-        if ((index+1) == num_columns):
-            append_str = ""
-        else:
-            append_str = ", "
-
-        if r1[column] != r2[column]:
-            r1_string += underline_text(
-                f"{column}='{r1[column]}'") + f"{append_str}"
-            r2_string += underline_text(
-                f"{column}='{r2[column]}'") + f"{append_str}"
-        else:
-            r1_string += f"{column}='{r1[column]}'{append_str}"
-            r2_string += f"{column}='{r2[column]}'{append_str}"
-
-    r1_string += ")"
-    r2_string += ")"
+            r_zipped = list(zip_longest(r1.__fields__, r2.__fields__))
+            r1_string = []
+            r2_string = []
+            for r1_field, r2_field in r_zipped:
+                if r1[r1_field] != r2[r2_field]:
+                    all_rows_equal = False
+                    r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
+                    r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
+                else:
+                    r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
+                    r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
+            r1_res = ", ".join(r1_string)
+            r2_res = ", ".join(r2_string)

-    return [bcolors.LightRed + r1_string, r2_string]
+            t.add_row([r1_res, r2_res])
+    if all_rows_equal == False:
+        raise chispa.DataFramesNotEqualError("\n" + t.get_string())
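The cell-by-cell comparison above can be tried without PySpark. The following is a rough sketch under stated assumptions: `namedtuple` rows stand in for `pyspark.sql.Row`, and `format_cell` is a hypothetical helper that maps a few format names to ANSI escape codes. The real `chispa.terminal_str_formatter.format_string` is not shown in this diff and may behave differently. Unlike the hunk above, the sketch treats a field that is missing on one side as a mismatch.

```python
from collections import namedtuple
from itertools import zip_longest

# ANSI SGR codes; only the few names used below are mapped.
ANSI = {"red": "31", "blue": "34", "bold": "1", "underline": "4"}
RESET = "\033[0m"


def format_cell(text, formats):
    # Hypothetical helper, not chispa's format_string: wrap text in the
    # escape codes for each known format name; unknown names are ignored.
    codes = [ANSI[name] for name in formats if name in ANSI]
    if not codes:
        return str(text)
    return "\033[" + ";".join(codes) + "m" + str(text) + RESET


def compare_rows(r1, r2, mismatched=("red", "underline"), matched=("blue",)):
    """Return (rows_equal, left_text, right_text) for two namedtuple rows."""
    rows_equal = True
    left, right = [], []
    for f1, f2 in zip_longest(r1._fields, r2._fields):
        # A field name of None means that row has fewer columns.
        v1 = getattr(r1, f1) if f1 is not None else None
        v2 = getattr(r2, f2) if f2 is not None else None
        same = f1 == f2 and v1 == v2
        fmt = matched if same else mismatched
        rows_equal = rows_equal and same
        left.append(format_cell(f"{f1}={v1}", fmt))
        right.append(format_cell(f"{f2}={v2}", fmt))
    return rows_equal, ", ".join(left), ", ".join(right)


Person = namedtuple("Person", ["name", "age"])
equal, left, right = compare_rows(Person("li", 30), Person("li", 31))
print(equal)
print(left)
print(right)
```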
12 changes: 6 additions & 6 deletions chispa/schema_comparer.py
@@ -1,6 +1,6 @@
 from chispa.prettytable import PrettyTable
 from chispa.bcolors import *
-import chispa.six as six
+from itertools import zip_longest


 class SchemasNotEqualError(Exception):
@@ -19,15 +19,15 @@ def assert_schema_equality_full(s1, s2, ignore_nullable=False, ignore_metadata=F
     def inner(s1, s2, ignore_nullable, ignore_metadata):
         if len(s1) != len(s2):
             return False
-        zipped = list(six.moves.zip_longest(s1, s2))
+        zipped = list(zip_longest(s1, s2))
         for sf1, sf2 in zipped:
             if not are_structfields_equal(sf1, sf2, ignore_nullable, ignore_metadata):
                 return False
         return True

     if not inner(s1, s2, ignore_nullable, ignore_metadata):
         t = PrettyTable(["schema1", "schema2"])
-        zipped = list(six.moves.zip_longest(s1, s2))
+        zipped = list(zip_longest(s1, s2))
         for sf1, sf2 in zipped:
             if are_structfields_equal(sf1, sf2, True):
                 t.add_row([blue(sf1), blue(sf2)])
@@ -42,7 +42,7 @@ def inner(s1, s2, ignore_nullable, ignore_metadata):
 def assert_basic_schema_equality(s1, s2):
     if s1 != s2:
         t = PrettyTable(["schema1", "schema2"])
-        zipped = list(six.moves.zip_longest(s1, s2))
+        zipped = list(zip_longest(s1, s2))
         for sf1, sf2 in zipped:
             if sf1 == sf2:
                 t.add_row([blue(sf1), blue(sf2)])
@@ -56,7 +56,7 @@ def assert_basic_schema_equality(s1, s2):
 def assert_schema_equality_ignore_nullable(s1, s2):
     if not are_schemas_equal_ignore_nullable(s1, s2):
         t = PrettyTable(["schema1", "schema2"])
-        zipped = list(six.moves.zip_longest(s1, s2))
+        zipped = list(zip_longest(s1, s2))
         for sf1, sf2 in zipped:
             if are_structfields_equal(sf1, sf2, True):
                 t.add_row([blue(sf1), blue(sf2)])
@@ -69,7 +69,7 @@ def assert_schema_equality_ignore_nullable(s1, s2):
 def are_schemas_equal_ignore_nullable(s1, s2):
     if len(s1) != len(s2):
         return False
-    zipped = list(six.moves.zip_longest(s1, s2))
+    zipped = list(zip_longest(s1, s2))
     for sf1, sf2 in zipped:
         if not are_structfields_equal(sf1, sf2, True):
             return False
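The swap from `six.moves.zip_longest` to `itertools.zip_longest` should not change behaviour on Python 3, where the `six` name is an alias for the standard-library function. A small illustration with plain lists standing in for schemas (the field names are made up for the example):

```python
from itertools import zip_longest

schema1 = ["name", "age", "city"]
schema2 = ["name", "age"]

# zip_longest pads the shorter input with None (or a chosen fillvalue),
# so every position of the longer input still gets reported.
pairs = list(zip_longest(schema1, schema2))
print(pairs)  # [('name', 'name'), ('age', 'age'), ('city', None)]

# Plain zip stops at the shorter input and would hide the extra field.
short_pairs = list(zip(schema1, schema2))
print(short_pairs)  # [('name', 'name'), ('age', 'age')]
```

This padding is why the comparison tables in these functions can show a row for a field that exists on only one side.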