diff --git a/README.md b/README.md index 907d387..0bd1a92 100644 --- a/README.md +++ b/README.md @@ -311,31 +311,51 @@ assert_df_equality(df1, df2, allow_nan_equality=True) ## Customize formatting -*Available in chispa 0.10+*. - You can specify custom formats for the printed error messages as follows: ```python -@dataclass -class MyFormats: - mismatched_rows = ["light_yellow"] - matched_rows = ["cyan", "bold"] - mismatched_cells = ["purple"] - matched_cells = ["blue"] +from chispa import FormattingConfig + +formats = FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) -assert_basic_rows_equality(df1.collect(), df2.collect(), formats=MyFormats()) +assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) +``` + +or similarly: + +```python +from chispa import FormattingConfig, Color, Style + +formats = FormattingConfig( + mismatched_rows={"color": Color.LIGHT_YELLOW}, + matched_rows={"color": Color.CYAN, "style": Style.BOLD}, + mismatched_cells={"color": Color.PURPLE}, + matched_cells={"color": Color.BLUE}, + ) + +assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) ``` You can also define these formats in `conftest.py` and inject them via a fixture: ```python @pytest.fixture() -def my_formats(): - return MyFormats() +def chispa_formats(): + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) -def test_shows_assert_basic_rows_equality(my_formats): +def test_shows_assert_basic_rows_equality(chispa_formats): ... - assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats) + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=chispa_formats) ``` ![custom_formats](https://github.com/MrPowers/chispa/blob/main/images/custom_formats.png) diff --git a/chispa/__init__.py b/chispa/__init__.py index 2d26547..cef3d31 100644 --- a/chispa/__init__.py +++ b/chispa/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from glob import glob @@ -28,6 +30,7 @@ exit(-1) from chispa.default_formats import DefaultFormats +from chispa.formatting import Color, Format, FormattingConfig, Style from .column_comparer import ( ColumnsNotEqualError, @@ -43,8 +46,14 @@ class Chispa: - def __init__(self, formats=DefaultFormats(), default_output=None): - self.formats = formats + def __init__(self, formats: FormattingConfig | None = None, default_output=None): + if not formats: + self.formats = FormattingConfig() + elif isinstance(formats, FormattingConfig): + self.formats = formats + else: + self.formats = FormattingConfig._from_arbitrary_dataclass(formats) + self.default_outputs = default_output def assert_df_equality( @@ -81,6 +90,10 @@ def assert_df_equality( "assert_column_equality", "assert_approx_column_equality", "assert_basic_rows_equality", - "DefaultFormats", + "Style", + "Color", + "FormattingConfig", + "Format", "Chispa", + "DefaultFormats", ) diff --git a/chispa/bcolors.py b/chispa/bcolors.py index 3c531f2..7b77215 100644 --- a/chispa/bcolors.py +++ b/chispa/bcolors.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + class bcolors: NC = "\033[0m" # No Color, reset all diff --git a/chispa/column_comparer.py b/chispa/column_comparer.py index a3fc43c..e6a380c 100644 --- a/chispa/column_comparer.py +++ b/chispa/column_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from prettytable import PrettyTable from chispa.bcolors import bcolors diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 9c23178..167ceb2 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from functools import reduce -from chispa.default_formats import DefaultFormats +from chispa.formatting import FormattingConfig from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced from chispa.rows_comparer import ( assert_basic_rows_equality, @@ -25,8 +27,13 @@ def assert_df_equality( ignore_row_order=False, underline_cells=False, ignore_metadata=False, - formats=DefaultFormats(), + formats: FormattingConfig | None = None, ): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + if transforms is None: transforms = [] if ignore_column_order: @@ -71,8 +78,13 @@ def assert_approx_df_equality( allow_nan_equality=False, ignore_column_order=False, ignore_row_order=False, - formats=DefaultFormats(), + formats: FormattingConfig | None = None, ): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + if transforms is None: transforms = [] if ignore_column_order: diff --git a/chispa/default_formats.py b/chispa/default_formats.py index 433db11..ab1f292 100644 --- a/chispa/default_formats.py +++ b/chispa/default_formats.py @@ -1,9 +1,21 @@ -from dataclasses import dataclass +from __future__ import annotations + +import warnings +from dataclasses import dataclass, field @dataclass class DefaultFormats: - mismatched_rows = ["red"] - matched_rows = ["blue"] - mismatched_cells = ["red", "underline"] - matched_cells = ["blue"] + """ + This class is now deprecated and should be removed in a future release. + """ + + mismatched_rows: list[str] = field(default_factory=lambda: ["red"]) + matched_rows: list[str] = field(default_factory=lambda: ["blue"]) + mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"]) + matched_cells: list[str] = field(default_factory=lambda: ["blue"]) + + def __post_init__(self): + warnings.warn( + "DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.", DeprecationWarning + ) diff --git a/chispa/formatting/__init__.py b/chispa/formatting/__init__.py new file mode 100644 index 0000000..6d107b0 --- /dev/null +++ b/chispa/formatting/__init__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from chispa.formatting.format_string import format_string +from chispa.formatting.formats import RESET, Color, Format, Style +from chispa.formatting.formatting_config import FormattingConfig + +__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET") diff --git a/chispa/formatting/format_string.py b/chispa/formatting/format_string.py new file mode 100644 index 0000000..0725744 --- /dev/null +++ b/chispa/formatting/format_string.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from chispa.formatting.formats import RESET, Format + + +def format_string(input_string: str, format: Format) -> str: + if not format.color and not format.style: + return input_string + + formatted_string = input_string + codes = [] + + if format.style: + for style in format.style: + codes.append(style.value) + + if format.color: + codes.append(format.color.value) + + formatted_string = "".join(codes) + formatted_string + RESET + return formatted_string diff --git a/chispa/formatting/formats.py b/chispa/formatting/formats.py new file mode 100644 index 0000000..76064d8 --- /dev/null +++ b/chispa/formatting/formats.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +RESET = "\033[0m" + + +class Color(str, Enum): + """ + Enum for terminal colors. + Each color is represented by its corresponding ANSI escape code. + """ + + BLACK = "\033[30m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + PURPLE = "\033[35m" + CYAN = "\033[36m" + LIGHT_GRAY = "\033[37m" + DARK_GRAY = "\033[90m" + LIGHT_RED = "\033[91m" + LIGHT_GREEN = "\033[92m" + LIGHT_YELLOW = "\033[93m" + LIGHT_BLUE = "\033[94m" + LIGHT_PURPLE = "\033[95m" + LIGHT_CYAN = "\033[96m" + WHITE = "\033[97m" + + +class Style(str, Enum): + """ + Enum for text styles. + Each style is represented by its corresponding ANSI escape code. + """ + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + BLINK = "\033[5m" + INVERT = "\033[7m" + HIDE = "\033[8m" + + +@dataclass +class Format: + """ + Data class to represent text formatting with color and style. + + Attributes: + color (Color | None): The color for the text. + style (list[Style] | None): A list of styles for the text. + """ + + color: Color | None = None + style: list[Style] | None = None + + @classmethod + def from_dict(cls, format_dict: dict) -> Format: + """ + Create a Format instance from a dictionary. + + Args: + format_dict (dict): A dictionary with keys 'color' and/or 'style'. + """ + if not isinstance(format_dict, dict): + raise ValueError("Input must be a dictionary") + + valid_keys = {"color", "style"} + invalid_keys = set(format_dict) - valid_keys + if invalid_keys: + raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}") + + color = cls._get_color_enum(format_dict.get("color")) + style = format_dict.get("style") + if isinstance(style, str): + styles = [cls._get_style_enum(style)] + elif isinstance(style, list): + styles = [cls._get_style_enum(s) for s in style] + else: + styles = None + + return cls(color=color, style=styles) + + @classmethod + def from_list(cls, values: list[str]) -> Format: + """ + Create a Format instance from a list of strings. + + Args: + values (list[str]): A list of strings representing colors and styles. + """ + if not all(isinstance(value, str) for value in values): + raise ValueError("All elements in the list must be strings") + + color = None + styles = [] + valid_colors = [c.name.lower() for c in Color] + valid_styles = [s.name.lower() for s in Style] + + for value in values: + if value in valid_colors: + color = Color[value.upper()] + elif value in valid_styles: + styles.append(Style[value.upper()]) + else: + raise ValueError( + f"Invalid value: {value}. Valid values are colors: {valid_colors} and styles: {valid_styles}" + ) + + return cls(color=color, style=styles if styles else None) + + @staticmethod + def _get_color_enum(color: Color | str | None) -> Color | None: + if isinstance(color, Color): + return color + elif isinstance(color, str): + try: + return Color[color.upper()] + except KeyError: + valid_colors = [c.name.lower() for c in Color] + raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") + return None + + @staticmethod + def _get_style_enum(style: Style | str | None) -> Style | None: + if isinstance(style, Style): + return style + elif isinstance(style, str): + try: + return Style[style.upper()] + except KeyError: + valid_styles = [f.name.lower() for f in Style] + raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}") + return None diff --git a/chispa/formatting/formatting_config.py b/chispa/formatting/formatting_config.py new file mode 100644 index 0000000..055428f --- /dev/null +++ b/chispa/formatting/formatting_config.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import warnings +from typing import Any, ClassVar + +from chispa.default_formats import DefaultFormats +from chispa.formatting.formats import Color, Format, Style + + +class FormattingConfig: + """ + Class to manage and parse formatting configurations. + """ + + VALID_KEYS: ClassVar = {"color", "style"} + + def __init__( + self, + mismatched_rows: Format | dict = Format(Color.RED), + matched_rows: Format | dict = Format(Color.BLUE), + mismatched_cells: Format | dict = Format(Color.RED, [Style.UNDERLINE]), + matched_cells: Format | dict = Format(Color.BLUE), + ): + """ + Initializes the FormattingConfig with given or default formatting. + + Each of the arguments can be provided as a `Format` object or a dictionary with the following keys: + - 'color': A string representing a color name, which should be one of the valid colors: + ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', + 'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', + 'light_purple', 'light_cyan', 'white']. + - 'style': A string or list of strings representing styles, which should be one of the valid styles: + ['bold', 'underline', 'blink', 'invert', 'hide']. + + Args: + mismatched_rows (Format | dict): Format or dictionary for mismatched rows. + matched_rows (Format | dict): Format or dictionary for matched rows. + mismatched_cells (Format | dict): Format or dictionary for mismatched cells. + matched_cells (Format | dict): Format or dictionary for matched cells. + + Raises: + ValueError: If the dictionary contains invalid keys or values. + """ + self.mismatched_rows: Format = self._parse_format(mismatched_rows) + self.matched_rows: Format = self._parse_format(matched_rows) + self.mismatched_cells: Format = self._parse_format(mismatched_cells) + self.matched_cells: Format = self._parse_format(matched_cells) + + def _parse_format(self, format: Format | dict) -> Format: + if isinstance(format, Format): + return format + elif isinstance(format, dict): + return Format.from_dict(format) + raise ValueError("Invalid format type. Must be Format or dict.") + + @classmethod + def _from_arbitrary_dataclass(cls, instance: Any) -> FormattingConfig: + """ + Converts an instance of an arbitrary class with specified fields to a FormattingConfig instance. + This method is purely for backwards compatibility and should be removed in a future release, + together with the `DefaultFormats` class. + """ + + if not isinstance(instance, DefaultFormats): + warnings.warn( + "Using an arbitrary dataclass is deprecated. Use `chispa.formatting.FormattingConfig` instead.", + DeprecationWarning, + ) + + mismatched_rows = Format.from_list(getattr(instance, "mismatched_rows")) + matched_rows = Format.from_list(getattr(instance, "matched_rows")) + mismatched_cells = Format.from_list(getattr(instance, "mismatched_cells")) + matched_cells = Format.from_list(getattr(instance, "matched_cells")) + + return cls( + mismatched_rows=mismatched_rows, + matched_rows=matched_rows, + mismatched_cells=mismatched_cells, + matched_cells=matched_cells, + ) diff --git a/chispa/number_helpers.py b/chispa/number_helpers.py index 818d3f5..a49c84a 100644 --- a/chispa/number_helpers.py +++ b/chispa/number_helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math diff --git a/chispa/row_comparer.py b/chispa/row_comparer.py index a1bf519..7835e5d 100644 --- a/chispa/row_comparer.py +++ b/chispa/row_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math from pyspark.sql import Row diff --git a/chispa/rows_comparer.py b/chispa/rows_comparer.py index c7f338e..30fdc96 100644 --- a/chispa/rows_comparer.py +++ b/chispa/rows_comparer.py @@ -1,13 +1,19 @@ +from __future__ import annotations + from itertools import zip_longest from prettytable import PrettyTable import chispa -from chispa.default_formats import DefaultFormats -from chispa.terminal_str_formatter import format_string +from chispa.formatting import FormattingConfig, format_string + +def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: FormattingConfig | None = None): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) -def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=DefaultFormats()): if rows1 != rows2: t = PrettyTable(["df1", "df2"]) zipped = list(zip_longest(rows1, rows2)) @@ -15,10 +21,10 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa for r1, r2 in zipped: if r1 is None and r2 is not None: - t.add_row([None, format_string(r2, formats.mismatched_rows)]) + t.add_row([None, format_string(str(r2), formats.mismatched_rows)]) all_rows_equal = False elif r1 is not None and r2 is None: - t.add_row([format_string(r1, formats.mismatched_rows), None]) + t.add_row([format_string(str(r1), formats.mismatched_rows), None]) all_rows_equal = False else: r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) @@ -46,8 +52,13 @@ def assert_generic_rows_equality( row_equality_fun, row_equality_fun_args, underline_cells=False, - formats=DefaultFormats(), + formats: FormattingConfig | None = None, ): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + df1_rows = rows1 df2_rows = rows2 zipped = list(zip_longest(df1_rows, df2_rows)) @@ -55,11 +66,11 @@ def assert_generic_rows_equality( all_rows_equal = True for r1, r2 in zipped: # rows are not equal when one is None and the other isn't - if (r1 is not None and r2 is None) or (r2 is not None and r1 is None): + if (r1 is None) ^ (r2 is None): all_rows_equal = False t.add_row([ - format_string(r1, formats.mismatched_rows), - format_string(r2, formats.mismatched_rows), + format_string(str(r1), formats.mismatched_rows), + format_string(str(r2), formats.mismatched_rows), ]) # rows are equal elif row_equality_fun(r1, r2, *row_equality_fun_args): diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index f7cebfc..c2e2150 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from itertools import zip_longest from prettytable import PrettyTable diff --git a/chispa/structfield_comparer.py b/chispa/structfield_comparer.py index b0fd896..f1b4782 100644 --- a/chispa/structfield_comparer.py +++ b/chispa/structfield_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from chispa.schema_comparer import are_structfields_equal __all__ = ("are_structfields_equal",) diff --git a/chispa/terminal_str_formatter.py b/chispa/terminal_str_formatter.py deleted file mode 100644 index 69174d7..0000000 --- a/chispa/terminal_str_formatter.py +++ /dev/null @@ -1,30 +0,0 @@ -def format_string(input, formats): - formatting = { - "nc": "\033[0m", # No Color, reset all - "bold": "\033[1m", - "underline": "\033[4m", - "blink": "\033[5m", - "blue": "\033[34m", - "white": "\033[97m", - "red": "\033[31m", - "invert": "\033[7m", - "hide": "\033[8m", - "black": "\033[30m", - "green": "\033[32m", - "yellow": "\033[33m", - "purple": "\033[35m", - "cyan": "\033[36m", - "light_gray": "\033[37m", - "dark_gray": "\033[30m", - "light_red": "\033[31m", - "light_green": "\033[32m", - "light_yellow": "\033[93m", - "light_blue": "\033[34m", - "light_purple": "\033[35m", - "light_cyan": "\033[36m", - } - formatted = input - for format in formats: - s = formatting[format] - formatted = s + str(formatted) + s - return formatting["nc"] + str(formatted) + formatting["nc"] diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index edd357e..03becb6 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -3,6 +3,8 @@ https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages """ +from __future__ import annotations + from pathlib import Path import mkdocs_gen_files diff --git a/mkdocs.yml b/mkdocs.yml index 514205d..260b226 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,7 +22,7 @@ plugins: handlers: python: options: - docstring_style: sphinx + docstring_style: google docstring_options: show_if_no_docstring: true show_source: true diff --git a/pyproject.toml b/pyproject.toml index 00b247a..167bfda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,9 +71,7 @@ preview = true select = ["E", "F", "I", "RUF", "UP"] ignore = [ # Line too long - "E501", - # Mutable class attributes - "RUF012" + "E501" ] [tool.ruff.lint.flake8-type-checking] @@ -81,3 +79,6 @@ strict = true [tool.ruff.lint.per-file-ignores] "tests/*" = ["S101", "S603"] + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] diff --git a/tests/conftest.py b/tests/conftest.py index 514c2b6..5875d42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,25 @@ -from dataclasses import dataclass +from __future__ import annotations import pytest -from chispa import Chispa - - -@dataclass -class MyFormats: - mismatched_rows = ["light_yellow"] - matched_rows = ["cyan", "bold"] - mismatched_cells = ["purple"] - matched_cells = ["blue"] +from chispa.formatting import FormattingConfig @pytest.fixture() def my_formats(): - return MyFormats() + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) @pytest.fixture() def my_chispa(): - return Chispa(formats=MyFormats()) + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) diff --git a/tests/formatting/test_formats.py b/tests/formatting/test_formats.py new file mode 100644 index 0000000..99f0d3b --- /dev/null +++ b/tests/formatting/test_formats.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import re + +import pytest + +from chispa.formatting import Color, Format, Style + + +def test_format_from_dict_valid(): + format_dict = {"color": "blue", "style": ["bold", "underline"]} + format_instance = Format.from_dict(format_dict) + assert format_instance.color == Color.BLUE + assert format_instance.style == [Style.BOLD, Style.UNDERLINE] + + +def test_format_from_dict_invalid_color(): + format_dict = {"color": "invalid_color", "style": ["bold"]} + with pytest.raises(ValueError) as exc_info: + Format.from_dict(format_dict) + assert str(exc_info.value) == ( + "Invalid color name: invalid_color. Valid color names are " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white']" + ) + + +def test_format_from_dict_invalid_style(): + format_dict = {"color": "blue", "style": ["invalid_style"]} + with pytest.raises(ValueError) as exc_info: + Format.from_dict(format_dict) + assert str(exc_info.value) == ( + "Invalid style name: invalid_style. Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_format_from_dict_invalid_key(): + format_dict = {"invalid_key": "value"} + try: + Format.from_dict(format_dict) + except ValueError as e: + error_message = str(e) + assert re.match( + r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}", + error_message, + ) + + +def test_format_from_list_valid(): + values = ["blue", "bold", "underline"] + format_instance = Format.from_list(values) + assert format_instance.color == Color.BLUE + assert format_instance.style == [Style.BOLD, Style.UNDERLINE] + + +def test_format_from_list_invalid_color(): + values = ["invalid_color", "bold"] + with pytest.raises(ValueError) as exc_info: + Format.from_list(values) + assert str(exc_info.value) == ( + "Invalid value: invalid_color. Valid values are colors: " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_format_from_list_invalid_style(): + values = ["blue", "invalid_style"] + with pytest.raises(ValueError) as exc_info: + Format.from_list(values) + assert str(exc_info.value) == ( + "Invalid value: invalid_style. Valid values are colors: " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_format_from_list_non_string_elements(): + values = ["blue", 123] + with pytest.raises(ValueError) as exc_info: + Format.from_list(values) + assert str(exc_info.value) == "All elements in the list must be strings" + + +def test_format_from_dict_empty(): + format_dict = {} + format_instance = Format.from_dict(format_dict) + assert format_instance.color is None + assert format_instance.style is None + + +def test_format_from_list_empty(): + values = [] + format_instance = Format.from_list(values) + assert format_instance.color is None + assert format_instance.style is None diff --git a/tests/formatting/test_formatting_config.py b/tests/formatting/test_formatting_config.py new file mode 100644 index 0000000..3214ac8 --- /dev/null +++ b/tests/formatting/test_formatting_config.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import re + +import pytest + +from chispa.formatting import Color, FormattingConfig, Style + + +def test_default_mismatched_rows(): + config = FormattingConfig() + assert config.mismatched_rows.color == Color.RED + assert config.mismatched_rows.style is None + + +def test_default_matched_rows(): + config = FormattingConfig() + assert config.matched_rows.color == Color.BLUE + assert config.matched_rows.style is None + + +def test_default_mismatched_cells(): + config = FormattingConfig() + assert config.mismatched_cells.color == Color.RED + assert config.mismatched_cells.style == [Style.UNDERLINE] + + +def test_default_matched_cells(): + config = FormattingConfig() + assert config.matched_cells.color == Color.BLUE + assert config.matched_cells.style is None + + +def test_custom_mismatched_rows(): + config = FormattingConfig(mismatched_rows={"color": "green", "style": ["bold", "underline"]}) + assert config.mismatched_rows.color == Color.GREEN + assert config.mismatched_rows.style == [Style.BOLD, Style.UNDERLINE] + + +def test_custom_matched_rows(): + config = FormattingConfig(matched_rows={"color": "yellow"}) + assert config.matched_rows.color == Color.YELLOW + assert config.matched_rows.style is None + + +def test_custom_mismatched_cells(): + config = FormattingConfig(mismatched_cells={"color": "purple", "style": ["blink"]}) + assert config.mismatched_cells.color == Color.PURPLE + assert config.mismatched_cells.style == [Style.BLINK] + + +def test_custom_matched_cells(): + config = FormattingConfig(matched_cells={"color": "cyan", "style": ["invert", "hide"]}) + assert config.matched_cells.color == Color.CYAN + assert config.matched_cells.style == [Style.INVERT, Style.HIDE] + + +def test_invalid_color(): + with pytest.raises(ValueError) as exc_info: + FormattingConfig(mismatched_rows={"color": "invalid_color"}) + assert str(exc_info.value) == ( + "Invalid color name: invalid_color. Valid color names are " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white']" + ) + + +def test_invalid_style(): + with pytest.raises(ValueError) as exc_info: + FormattingConfig(mismatched_rows={"style": ["invalid_style"]}) + assert str(exc_info.value) == ( + "Invalid style name: invalid_style. Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_invalid_key(): + try: + FormattingConfig(mismatched_rows={"invalid_key": "value"}) + except ValueError as e: + error_message = str(e) + assert re.match( + r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}", + error_message, + ) diff --git a/tests/formatting/test_terminal_string_formatter.py b/tests/formatting/test_terminal_string_formatter.py new file mode 100644 index 0000000..d6d4959 --- /dev/null +++ b/tests/formatting/test_terminal_string_formatter.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from chispa.formatting import RESET, format_string +from chispa.formatting.formats import Color, Format, Style + + +def test_format_with_enum_inputs(): + format = Format(color=Color.BLUE, style=[Style.BOLD, Style.UNDERLINE]) + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Style.BOLD.value}{Style.UNDERLINE.value}{Color.BLUE.value}Hello, World!{RESET}" + assert formatted_string == expected_string + + +def test_format_with_no_style(): + format = Format(color=Color.GREEN, style=[]) + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Color.GREEN.value}Hello, World!{RESET}" + assert formatted_string == expected_string + + +def test_format_with_no_color(): + format = Format(color=None, style=[Style.BLINK]) + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Style.BLINK.value}Hello, World!{RESET}" + assert formatted_string == expected_string + + +def test_format_with_no_color_or_style(): + format = Format(color=None, style=[]) + formatted_string = format_string("Hello, World!", format) + expected_string = "Hello, World!" + assert formatted_string == expected_string diff --git a/tests/spark.py b/tests/spark.py index c3bd294..b5df699 100644 --- a/tests/spark.py +++ b/tests/spark.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pyspark.sql import SparkSession spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() diff --git a/tests/test_column_comparer.py b/tests/test_column_comparer.py index 825d3c3..1101545 100644 --- a/tests/test_column_comparer.py +++ b/tests/test_column_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from chispa import ColumnsNotEqualError, assert_approx_column_equality, assert_column_equality diff --git a/tests/test_dataframe_comparer.py b/tests/test_dataframe_comparer.py index 97035cf..9f14c20 100644 --- a/tests/test_dataframe_comparer.py +++ b/tests/test_dataframe_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math import pytest diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py new file mode 100644 index 0000000..a38b83d --- /dev/null +++ b/tests/test_deprecated.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import warnings +from dataclasses import dataclass + +import pytest + +from chispa import DataFramesNotEqualError, assert_basic_rows_equality +from chispa.default_formats import DefaultFormats +from chispa.formatting import FormattingConfig + +from .spark import spark + + +def test_default_formats_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DefaultFormats() + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "DefaultFormats is deprecated" in str(w[-1].message) + + +def test_that_default_formats_still_works(): + data1 = [(1, "jose"), (2, "li"), (3, "laura")] + df1 = spark.createDataFrame(data1, ["num", "expected_name"]) + data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] + df2 = spark.createDataFrame(data2, ["name", "expected_name"]) + with pytest.raises(DataFramesNotEqualError): + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=DefaultFormats()) + + +def test_deprecated_arbitrary_dataclass(): + data1 = [(1, "jose"), (2, "li"), (3, "laura")] + df1 = spark.createDataFrame(data1, ["num", "expected_name"]) + data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] + df2 = spark.createDataFrame(data2, ["name", "expected_name"]) + + @dataclass + class CustomFormats: + mismatched_rows = ["green"] # noqa: RUF012 + matched_rows = ["yellow"] # noqa: RUF012 + mismatched_cells = ["purple", "bold"] # noqa: RUF012 + matched_cells = ["cyan"] # noqa: RUF012 + + with warnings.catch_warnings(record=True) as w: + try: + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=CustomFormats()) + # should not reach the line below due to the raised error. + # pytest.raises does not work as expected since then we cannot verify the warning. + assert False + except DataFramesNotEqualError: + warnings.simplefilter("always") + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "Using an arbitrary dataclass is deprecated." in str(w[-1].message) + + +def test_invalid_value_in_default_formats(): + @dataclass + class InvalidFormats: + mismatched_rows = ["green"] # noqa: RUF012 + matched_rows = ["yellow"] # noqa: RUF012 + mismatched_cells = ["purple", "invalid"] # noqa: RUF012 + matched_cells = ["cyan"] # noqa: RUF012 + + with pytest.raises(ValueError): + FormattingConfig._from_arbitrary_dataclass(InvalidFormats()) diff --git a/tests/test_readme_examples.py b/tests/test_readme_examples.py index bd02ba7..fe28ee0 100644 --- a/tests/test_readme_examples.py +++ b/tests/test_readme_examples.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pyspark.sql.functions as F import pytest from pyspark.sql import SparkSession diff --git a/tests/test_row_comparer.py b/tests/test_row_comparer.py index a8d13ee..fe9b48c 100644 --- a/tests/test_row_comparer.py +++ b/tests/test_row_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pyspark.sql import Row from chispa.row_comparer import are_rows_approx_equal, are_rows_equal, are_rows_equal_enhanced diff --git a/tests/test_rows_comparer.py b/tests/test_rows_comparer.py index eda4d2c..34b06fc 100644 --- a/tests/test_rows_comparer.py +++ b/tests/test_rows_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from chispa import DataFramesNotEqualError, assert_basic_rows_equality diff --git a/tests/test_schema_comparer.py b/tests/test_schema_comparer.py index 0fa87f2..eee7d9b 100644 --- a/tests/test_schema_comparer.py +++ b/tests/test_schema_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType, StructField, StructType diff --git a/tests/test_structfield_comparer.py b/tests/test_structfield_comparer.py index 2acf8c7..a6d181d 100644 --- a/tests/test_structfield_comparer.py +++ b/tests/test_structfield_comparer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType from chispa.structfield_comparer import are_structfields_equal diff --git a/tests/test_terminal_str_formatter.py b/tests/test_terminal_str_formatter.py deleted file mode 100644 index b387524..0000000 --- a/tests/test_terminal_str_formatter.py +++ /dev/null @@ -1,9 +0,0 @@ -from chispa.terminal_str_formatter import format_string - - -def test_it_can_make_a_blue_string(): - print(format_string("hi", ["bold", "blink"])) - - -def test_it_works_with_no_formats(): - print(format_string("hi", []))