Commit
FPGrowth/FPMax and Association Rules with the existence of missing values (rasbt#1004) (rasbt#1106)

* Updated FPGrowth/FPMax and Association Rules with the existence of missing values

* Re-structure and document code

* Update unit tests

* Update CHANGELOG.md

* Modify the corresponding documentation in Jupyter notebooks

* Final modifications
zazass8 committed Jan 25, 2025
1 parent 4864804 commit 7d08390
Showing 2 changed files with 160 additions and 61 deletions.
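
For context: the change operates on one-hot transaction frames in which an unknown entry is encoded as NaN. The names `disabled` and `df_orig` appear in the diff below; how they are built is not shown there, so the following is only a minimal sketch, assuming that representation (item labels mirror the test data, the values are made up for illustration):

import numpy as np
import pandas as pd

# hypothetical one-hot transactions; NaN marks an unknown entry
df = pd.DataFrame({3: [1, 1, np.nan], 5: [1, 1, 1], 10: [np.nan, 1, 1]}, dtype=float)

disabled = df.isna().astype(int)    # 1 where the entry is missing ("disabled")
df_orig = df.fillna(0).astype(int)  # 1 where the item is known to be present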
51 changes: 0 additions & 51 deletions mlxtend/frequent_patterns/association_rules.py
@@ -285,57 +285,6 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
                 try:
                     sA = frequent_items_dict[antecedent]
                     sC = frequent_items_dict[consequent]
-
-                    # if the input dataframe is complete
-                    if not null_values:
-                        disAC, disA, disC, dis_int, dis_int_ = 0, 0, 0, 0, 0
-
-                    else:
-                        an = list(antecedent)
-                        con = list(consequent)
-                        an.extend(con)
-
-                        # select data of antecedent, consequent and combined from disabled
-                        dec = disabled.loc[:, an]
-                        _dec = disabled.loc[:, list(antecedent)]
-                        __dec = disabled.loc[:, list(consequent)]
-
-                        # select data of antecedent and consequent from original
-                        dec_ = df_orig.loc[:, list(antecedent)]
-                        dec__ = df_orig.loc[:, list(consequent)]
-
-                        # disabled counts
-                        disAC, disA, disC, dis_int, dis_int_ = 0, 0, 0, 0, 0
-                        for i in range(len(dec.index)):
-                            # select the i-th itemset from the disabled dataset
-                            item_comb = list(dec.iloc[i, :])
-                            item_dis_an = list(_dec.iloc[i, :])
-                            item_dis_con = list(__dec.iloc[i, :])
-
-                            # select the i-th itemset from the original dataset
-                            item_or_an = list(dec_.iloc[i, :])
-                            item_or_con = list(dec__.iloc[i, :])
-
-                            # check and keep count if there is a null value in the combined, antecedent, or consequent columns
-                            if 1 in set(item_comb):
-                                disAC += 1
-                            if 1 in set(item_dis_an):
-                                disA += 1
-                            if 1 in item_dis_con:
-                                disC += 1
-
-                            # check and keep count if there is a null value in the consequent AND all antecedent items are present
-                            if (1 in item_dis_con) and all(
-                                j == 1 for j in item_or_an
-                            ):
-                                dis_int += 1
-
-                            # check and keep count if there is a null value in the antecedent AND all consequent items are present
-                            if (1 in item_dis_an) and all(
-                                j == 1 for j in item_or_con
-                            ):
-                                dis_int_ += 1
-
                 except KeyError as e:
                     s = (
                         str(e) + "You are likely getting this error"
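
The loop removed above walks the `disabled` and `df_orig` frames row by row for every antecedent/consequent pair. Purely as an illustration, and not the refactored code that replaced it, here is a hedged vectorized sketch of the same five counters; `null_counts` is a hypothetical helper name, while `disabled`, `df_orig`, `antecedent`, `consequent`, and the counter names come from the diff (1 in `disabled` is assumed to mark a null, 1 in `df_orig` a present item):

import pandas as pd

def null_counts(disabled: pd.DataFrame, df_orig: pd.DataFrame, antecedent, consequent):
    # Boolean row masks: does a row contain a null in the antecedent/consequent,
    # and are all antecedent/consequent items present in the original data?
    dis_an = (disabled[list(antecedent)] == 1).any(axis=1)
    dis_con = (disabled[list(consequent)] == 1).any(axis=1)
    orig_an = (df_orig[list(antecedent)] == 1).all(axis=1)
    orig_con = (df_orig[list(consequent)] == 1).all(axis=1)
    return (
        int((dis_an | dis_con).sum()),   # disAC: null anywhere in antecedent or consequent
        int(dis_an.sum()),               # disA: null somewhere in the antecedent
        int(dis_con.sum()),              # disC: null somewhere in the consequent
        int((dis_con & orig_an).sum()),  # dis_int: null in consequent, antecedent fully present
        int((dis_an & orig_con).sum()),  # dis_int_: null in antecedent, consequent fully present
    )

Each row contributes at most one to each counter, mirroring the separate if-checks in the removed loop.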
170 changes: 160 additions & 10 deletions mlxtend/frequent_patterns/tests/test_association_rules.py
@@ -103,17 +103,167 @@ def test_nullability():

     expect = pd.DataFrame(
         [
-            [(10, 3), (5,), 0.667, 1.0, 0.667, 1.0, 1.0, 0.6, 0.0, np.inf, 0, 0.667, 0, 0.833],
-            [(10, 5), (3,), 0.667, 1.0, 0.667, 1.0, 1.0, 0.6, 0.0, np.inf, 0, 0.667, 0.0, 0.833],
-            [(10,), (3, 5), 0.75, 1.0, 0.667, 1.0, 1.0, 0.6, -0.083, np.inf, -0.333, 0.615, 0.0, 0.833],
-            [(10,), (3,), 0.75, 1.0, 0.667, 1.0, 1.0, 0.6, -0.083, np.inf, -0.333, 0.615, 0.0, 0.833],
-            [(10,), (5,), 0.75, 1.0, 0.667, 1.0, 1.0, 0.6, -0.083, np.inf, -0.333, 0.615, 0, 0.833],
-            [(3, 5), (10,), 1.0, 0.75, 0.667, 0.667, 0.889, 0.6, -0.083, 0.75, -1.0, 0.615, -0.333, 0.833],
-            [(3,), (10, 5), 1.0, 0.667, 0.667, 0.667, 1.0, 0.6, 0.0, 1.0, 0, 0.667, 0.0, 0.833],
-            [(3,), (10,), 1.0, 0.75, 0.667, 0.667, 0.889, 0.6, -0.083, 0.75, -1.0, 0.615, -0.333, 0.833],
+            [
+                (10, 3),
+                (5,),
+                0.667,
+                1.0,
+                0.667,
+                1.0,
+                1.0,
+                0.6,
+                0.0,
+                np.inf,
+                0,
+                0.667,
+                0,
+                0.833,
+            ],
+            [
+                (10, 5),
+                (3,),
+                0.667,
+                1.0,
+                0.667,
+                1.0,
+                1.0,
+                0.6,
+                0.0,
+                np.inf,
+                0,
+                0.667,
+                0.0,
+                0.833,
+            ],
+            [
+                (10,),
+                (3, 5),
+                0.75,
+                1.0,
+                0.667,
+                1.0,
+                1.0,
+                0.6,
+                -0.083,
+                np.inf,
+                -0.333,
+                0.615,
+                0.0,
+                0.833,
+            ],
+            [
+                (10,),
+                (3,),
+                0.75,
+                1.0,
+                0.667,
+                1.0,
+                1.0,
+                0.6,
+                -0.083,
+                np.inf,
+                -0.333,
+                0.615,
+                0.0,
+                0.833,
+            ],
+            [
+                (10,),
+                (5,),
+                0.75,
+                1.0,
+                0.667,
+                1.0,
+                1.0,
+                0.6,
+                -0.083,
+                np.inf,
+                -0.333,
+                0.615,
+                0,
+                0.833,
+            ],
+            [
+                (3, 5),
+                (10,),
+                1.0,
+                0.75,
+                0.667,
+                0.667,
+                0.889,
+                0.6,
+                -0.083,
+                0.75,
+                -1.0,
+                0.615,
+                -0.333,
+                0.833,
+            ],
+            [
+                (3,),
+                (10, 5),
+                1.0,
+                0.667,
+                0.667,
+                0.667,
+                1.0,
+                0.6,
+                0.0,
+                1.0,
+                0,
+                0.667,
+                0.0,
+                0.833,
+            ],
+            [
+                (3,),
+                (10,),
+                1.0,
+                0.75,
+                0.667,
+                0.667,
+                0.889,
+                0.6,
+                -0.083,
+                0.75,
+                -1.0,
+                0.615,
+                -0.333,
+                0.833,
+            ],
             [(3,), (5,), 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.0, np.inf, 0, 1.0, 0, 1.0],
-            [(5,), (10, 3), 1.0, 0.667, 0.667, 0.667, 1.0, 0.6, 0.0, 1.0, 0, 0.667, 0, 0.833],
-            [(5,), (10,), 1.0, 0.75, 0.667, 0.667, 0.889, 0.6, -0.083, 0.75, -1.0, 0.615, -0.333, 0.833],
+            [
+                (5,),
+                (10, 3),
+                1.0,
+                0.667,
+                0.667,
+                0.667,
+                1.0,
+                0.6,
+                0.0,
+                1.0,
+                0,
+                0.667,
+                0,
+                0.833,
+            ],
+            [
+                (5,),
+                (10,),
+                1.0,
+                0.75,
+                0.667,
+                0.667,
+                0.889,
+                0.6,
+                -0.083,
+                0.75,
+                -1.0,
+                0.615,
+                -0.333,
+                0.833,
+            ],
             [(5,), (3,), 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.0, np.inf, 0, 1.0, 0.0, 1.0],
         ],
         columns=columns_ordered,
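
The literals in the expected frame above are rounded to three decimals. The rest of the test body is not shown here, so the following is only a hypothetical illustration of that kind of rounded comparison, not the assertion the test actually uses; the toy frames and column names are made up:

import pandas as pd
from pandas.testing import assert_frame_equal

# toy stand-ins for the computed rule metrics and the expected literals
computed = pd.DataFrame({"support": [2 / 3, 0.75], "confidence": [1.0, 0.889]})
expected = pd.DataFrame({"support": [0.667, 0.75], "confidence": [1.0, 0.889]})

# round the computed metrics to the precision of the literals before comparing
assert_frame_equal(computed.round(3), expected, check_dtype=False)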
