Skip to content

Commit

Permalink
Adjust AAclust test
Browse files Browse the repository at this point in the history
  • Loading branch information
breimanntools committed Jun 30, 2024
1 parent d731fb6 commit f372ef3
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 148 deletions.
8 changes: 4 additions & 4 deletions aaanalysis/_utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def plot_legend_(ax=None, dict_color=None, list_cat=None, labels=None,
lw, edgecolor, linestyle[i], hatch[i], hatchcolor)
for i, cat in enumerate(list_cat)]
# Create new legend
legend = ax.legend(handles=handles, **args)
ax.add_artist(legend)
legend = ax.legend(handles=handles, labels=labels, **args)
if title_align_left:
legend._legend_box.align = "left"
# Add the legend as an artist if add_legend is True
if keep_legend:
# Add the legend as an artist (must be inside plot)
if keep_legend and old_legend:
print("here")
ax.add_artist(old_legend)
return ax
1 change: 0 additions & 1 deletion aaanalysis/explainable_ai_pro/_shap_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def check_is_selected(is_selected=None, n_feat=None):

def check_match_labels_fuzzy_labeling(labels=None, fuzzy_labeling=False, verbose=True):
"""Check if only on label is fuzzy labeled and that the remaining sample balanced (best training scenario)"""
# TODO adjust for multi-class
if not fuzzy_labeling:
return # Skip check if fuzzy labeling is not enabled
# Check for one fuzzy label and balance among other labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _sort_X_labels_names(X, labels=None, names=None):
names = [names[i] for i in sorted_order]
return X, labels, names


def _get_df_corr(X=None, X_ref=None):
"""Get df with correlations"""
# Temporary labels to avoid any confusion with potential duplicates
Expand Down
Empty file.
39 changes: 27 additions & 12 deletions aaanalysis/feature_engineering/_numerical_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,32 @@ def check_match_df_scales_letter_new(df_scales=None, letter_new=None):
raise ValueError(f"Letter '{letter_new}' already exists in alphabet of 'df_scales': {alphabet}")


# TODO add corr filtering, scale_coverage ...
# TODO (see AAclust.comp_coverage, AAclust.comp_correlation),
# TODO filter_correlation)
# TODO add filter_coverage
"""
def get_select_scales(list_subcat_ref=None):
""""""
print("top")
df_scales = aa.load_scales()
df_cat = aa.load_scales(name="scales_cat")
df_cat = df_cat[df_cat["subcategory"].isin(list_subcat_ref)]
df_scales = df_scales[df_cat["scale_id"].to_list()]
aac = aa.AAclust()
n = len(list_subcat_ref)
coverage = 0
scales = []
while coverage != 100:
X = np.array(df_scales).T
scales = aac.fit(X, names=list(df_scales), n_clusters=n).medoid_names_
list_subcat = df_cat[df_cat["scale_id"].isin(scales)]["subcategory"].to_list()
coverage = aac.comp_coverage(names=list_subcat, names_ref=list_subcat_ref)
#print(len(list_subcat), coverage, [x for x in list_subcat_ref if x not in list_subcat])
n += 1
df_scales = df_scales[scales]
return df_scales
"""
# II Main Functions
class NumericalFeature:
"""
Expand All @@ -34,8 +59,7 @@ def comp_correlation():
@staticmethod
def filter_correlation():
@staticmethod
def scale_coverage():
"""

@staticmethod
Expand Down Expand Up @@ -87,12 +111,3 @@ def extend_alphabet(df_scales: pd.DataFrame = None,
# Add the new letter to the DataFrame
df_scales.loc[new_letter] = new_values
return df_scales

"""
@staticmethod
def merge_alphabet(df_scales : pd.DataFrame = None,
letters_to_merge : List[str] = None,
letter_new : str = None,
value_type: Literal["min", "mean", "median", "max"] = "mean"
) -> pd.DataFrame:
"""
2 changes: 1 addition & 1 deletion aaanalysis/plotting/_plot_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def plot_legend(ax: Optional[plt.Axes] = None,
title_align_left : bool, default=True
Whether to align the title to the left.
keep_legend: bool, default=False
If ``True``, keep existing legend and add a new one.
If ``True``, keep existing legend (must be within plot) and add a new one.
**kwargs
Further key word arguments for :attr:`matplotlib.axes.Axes.legend`.
Expand Down
50 changes: 25 additions & 25 deletions examples/feature_engineering/aac_plot_centers.ipynb

Large diffs are not rendered by default.

173 changes: 90 additions & 83 deletions examples/plotting/plot_legend.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions tests/unit/aaclust_tests/test_aaclust_comp_centers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,7 @@ def test_combination_valid_parameters(self, X):
elements=some.floats()))
def test_combination_invalid_parameters(self, X, labels):
"""Test combination of invalid parameters."""
with pytest.raises(ValueError):
aa.AAclust().comp_centers(X, labels)
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
with pytest.raises(ValueError):
aa.AAclust().comp_centers(X, labels)
55 changes: 35 additions & 20 deletions tests/unit/plotting_tests/test_plot_legend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
This is a script for testing the plot_set_legend function.
This is a script for testing the plot_legend function.
"""
from hypothesis import given, example, settings
from hypothesis import strategies as st
Expand All @@ -23,7 +23,7 @@ def dict_color():


class TestPlotSetLegend:
"""Test plot_set_legend function"""
"""Test plot_legend function"""

@pytest.fixture(autouse=True)
def create_fig_and_ax(self):
Expand Down Expand Up @@ -62,31 +62,31 @@ def test_invalid_marker(self, marker, dict_color):

@settings(max_examples=5, deadline=500)
@given(n_cols=st.integers(min_value=1, max_value=10))
def test_plot_set_legend_n_cols(self, n_cols, dict_color):
def test_plot_legend_n_cols(self, n_cols, dict_color):
"""Test the 'n_cols' parameter."""
ax = plt.gca()
result = aa.plot_legend(ax=ax, n_cols=n_cols, dict_color=dict_color)
assert isinstance(result, plt.Axes)

@settings(max_examples=5, deadline=500)
@given(labelspacing=st.floats(0, 5))
def test_plot_set_legend_labelspacing(self, labelspacing, dict_color):
def test_plot_legend_labelspacing(self, labelspacing, dict_color):
"""Test the 'labelspacing' parameter."""
ax = plt.gca()
result = aa.plot_legend(ax=ax, labelspacing=labelspacing, dict_color=dict_color)
assert isinstance(result, plt.Axes)

@settings(max_examples=5, deadline=500)
@given(columnspacing=st.floats(0, 5))
def test_plot_set_legend_columnspacing(self, columnspacing, dict_color):
def test_plot_legend_columnspacing(self, columnspacing, dict_color):
"""Test the 'columnspacing' parameter."""
ax = plt.gca()
result = aa.plot_legend(ax=ax, columnspacing=columnspacing, dict_color=dict_color)
assert isinstance(result, plt.Axes)

@settings(max_examples=5, deadline=500)
@given(handletextpad=st.floats(0, 5))
def test_plot_set_legend_handletextpad(self, handletextpad, dict_color):
def test_plot_legend_handletextpad(self, handletextpad, dict_color):
"""Test the 'handletextpad' parameter."""
ax = plt.gca()
result = aa.plot_legend(ax=ax, handletextpad=handletextpad, dict_color=dict_color)
Expand All @@ -98,28 +98,28 @@ def test_legend_positioning(self, dict_color):
result = aa.plot_legend(ax=ax, loc_out=True, x=0, y=0, dict_color=dict_color)
assert isinstance(result, plt.Axes)

def test_plot_set_legend_loc_outside(self, dict_color):
def test_plot_legend_loc_outside(self, dict_color):
"""Test 'loc_out' parameter."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 2], label="Sample Line")
aa.plot_legend(ax=ax, loc_out=True, dict_color=dict_color)
assert ax.get_legend().get_bbox_to_anchor().y0 <= 0

def test_plot_set_legend_invalid_fontsize(self, dict_color):
def test_plot_legend_invalid_fontsize(self, dict_color):
"""Test with 'fontsize' less than 0."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 2], label="Sample Line")
with pytest.raises(ValueError):
aa.plot_legend(ax=ax, dict_color=dict_color, fontsize=-5)

def test_plot_set_legend_invalid_marker_size(self, dict_color):
def test_plot_legend_invalid_marker_size(self, dict_color):
"""Test with negative 'marker_size'."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 2], label="Sample Line")
with pytest.raises(ValueError):
aa.plot_legend(ax=ax, dict_color=dict_color, marker_size=-10)

def test_plot_set_legend_color_and_category(self, dict_color):
def test_plot_legend_color_and_category(self, dict_color):
"""Test 'dict_color' and 'list_cat' parameters together."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 2], label="Sample Line")
Expand All @@ -129,7 +129,7 @@ def test_plot_set_legend_color_and_category(self, dict_color):
legend_texts = [text.get_text() for text in ax.get_legend().get_texts()]
assert set(categories) == set(legend_texts)

def test_plot_set_legend_invalid_color_and_category(self, dict_color):
def test_plot_legend_invalid_color_and_category(self, dict_color):
"""Test with invalid 'dict_color' and 'list_cat'."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 2], label="Sample Line")
Expand Down Expand Up @@ -158,29 +158,44 @@ def test_invalid_list_cat(self, random_cat):
else:
aa.plot_legend(ax=ax, dict_color={'A': 'red'}, list_cat=["B"])

def test_plot_keep_legend(self, dict_color):
"""Test the 'labelspacing' parameter."""
ax = plt.gca()
for keep_legend in [False, True]:
result = aa.plot_legend(ax=ax, dict_color=dict_color, keep_legend=keep_legend)
assert isinstance(result, plt.Axes)

def test_plot_keep_legend_invalid(self, dict_color):
"""Test the 'labelspacing' parameter."""
ax = plt.gca()
for keep_legend in [None, 1, "invalid", []]:
with pytest.raises(ValueError):
result = aa.plot_legend(ax=ax, dict_color=dict_color, keep_legend=keep_legend)


# II. Complex Cases Test Class
class TestPlotSetLegendComplex:
"""Test plot_set_legend function with complex scenarios."""
class TestPlotLegendComplex:
"""Test plot_legend function with complex scenarios."""

@settings(max_examples=5, deadline=500)
@given(st.floats(1, 10))
def test_plot_set_legend_n_cols(self, n_cols):
def test_plot_legend_n_cols(self, n_cols):
ax = plt.gca()
dict_color = {str(i): "r" for i in range(0, 10)}
result = aa.plot_legend(ax=ax, n_cols=int(n_cols), dict_color=dict_color)
assert isinstance(result, plt.Axes)

@settings(max_examples=5, deadline=500)
@given(st.lists(st.text(), min_size=2, max_size=5))
def test_plot_set_legend_custom_labels(self, labels):
def test_plot_legend_custom_labels(self, labels):
fig, ax = plt.subplots()
labels = list(set([x.replace("_", "X") for x in labels if len(x) > 1]))
# Use the plot_legend function to create the legend
ax = aa.plot_legend(ax=ax, dict_color={label: "red" for label in labels})
# Extract and verify the legend labels
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
assert set(legend_labels) == set(labels)
if len(labels) > 1:
# Use the plot_legend function to create the legend
ax = aa.plot_legend(ax=ax, dict_color={label: "red" for label in labels})
# Extract and verify the legend labels
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
assert set(legend_labels) == set(labels)

def test_handles_generation(self, dict_color):
"""Test handles based on dict_color and list_cat."""
Expand Down

0 comments on commit f372ef3

Please sign in to comment.