Final modifications
zazass8 committed Oct 22, 2024
1 parent 97ae9cf commit 99dad90
Showing 4 changed files with 80 additions and 48 deletions.
30 changes: 20 additions & 10 deletions mlxtend/frequent_patterns/association_rules.py
@@ -9,10 +9,13 @@
# License: BSD 3 clause

from itertools import combinations
from typing import Optional

import numpy as np
import pandas as pd

from ..frequent_patterns import fpcommon as fpc

_metrics = [
"antecedent support",
"consequent support",
@@ -31,8 +34,8 @@

def association_rules(
df: pd.DataFrame,
df_or: pd.DataFrame,
num_itemsets: int,
df_orig: Optional[pd.DataFrame] = None,
null_values=False,
metric="confidence",
min_threshold=0.8,
@@ -48,13 +51,13 @@ def association_rules(
pandas DataFrame of frequent itemsets
with columns ['support', 'itemsets']
df_or : pandas DataFrame
DataFrame with original input data
df_orig : pandas DataFrame (default: None)
DataFrame with original input data. Only provided when null_values exist
num_itemsets : int
Number of transactions in original input data
null_values : bool (default: True)
null_values : bool (default: False)
In case there are null values as NaNs in the original input data
metric : string (default: 'confidence')
@@ -112,6 +115,13 @@ def association_rules(
https://rasbt.github.io/mlxtend/user_guide/frequent_patterns/association_rules/
"""
# if null values exist, df_orig must be provided
if null_values and df_orig is None:
raise TypeError("If null values exist, df_orig must be provided.")

# check for valid input
fpc.valid_input_check(df_orig, null_values)

if not df.shape[0]:
raise ValueError(
"The input DataFrame `df` containing " "the frequent itemsets is empty."
@@ -125,8 +135,8 @@
)

def kulczynski_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
conf_AC = sAC / sA
conf_CA = sAC / sC
conf_AC = sAC * (num_itemsets - disAC) / (sA * (num_itemsets - disA) - dis_int)
conf_CA = sAC * (num_itemsets - disAC) / (sC * (num_itemsets - disC) - dis_int_)
kulczynski = (conf_AC + conf_CA) / 2
return kulczynski
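
A note on the change above: kulczynski_helper no longer uses the plain ratios sAC / sA and sAC / sC; each confidence term is corrected for transactions disabled by missing values. A rough standalone sketch of the adjusted term (treating the dis* counts as plain integers here is an assumption for illustration; in the code they are accumulated per rule):

import numpy as np  # not strictly needed, shown for context with the surrounding code

def null_adjusted_confidence(sAC, sA, disAC, disA, dis_int, num_itemsets):
    # sAC, sA: support of the combined itemset and of the antecedent
    # disAC, disA: transactions disabled (containing NaNs) for those itemsets
    # dis_int: correction term subtracted from the denominator
    return sAC * (num_itemsets - disAC) / (sA * (num_itemsets - disA) - dis_int)

kulczynski is then the mean of this quantity computed in both directions (antecedent to consequent and consequent to antecedent).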

@@ -234,13 +244,13 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
rule_supports = []

# Define the disabled df, assign columns from original df to be the same on the disabled.
disabled = df_or.copy()
if null_values:
disabled = df_orig.copy()
disabled = np.where(pd.isna(disabled), 1, np.nan) + np.where(
(disabled == 0) | (disabled == 1), np.nan, 0
)
disabled = pd.DataFrame(disabled)
disabled.columns = df_or.columns
disabled.columns = df_orig.columns

# iterate over all frequent itemsets
for k in frequent_items_dict.keys():
@@ -280,8 +290,8 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
__dec = disabled.loc[:, list(consequent)]

# select data of antecedent and consequent from original
dec_ = df_or.loc[:, list(antecedent)]
dec__ = df_or.loc[:, list(consequent)]
dec_ = df_orig.loc[:, list(antecedent)]
dec__ = df_orig.loc[:, list(consequent)]

# disabled counts
disAC, disA, disC, dis_int, dis_int_ = 0, 0, 0, 0, 0
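
For reference, the disabled matrix built earlier in this function marks where the original data was missing: cells that were NaN in df_orig become 1 and every observed 0/1 cell becomes NaN, so the per-rule counts can skip incomplete transactions. A tiny self-contained illustration (the toy DataFrame is made up for this example):

import numpy as np
import pandas as pd

df_orig = pd.DataFrame({"Apple": [1, np.nan, 0], "Milk": [0, 1, np.nan]})
disabled = np.where(pd.isna(df_orig), 1, np.nan) + np.where(
    (df_orig == 0) | (df_orig == 1), np.nan, 0
)
disabled = pd.DataFrame(disabled, columns=df_orig.columns)
print(disabled)  # 1 where df_orig had a NaN, NaN everywhere else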
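Taken together, the association_rules changes reorder the public signature: the frequent-itemsets frame comes first, then num_itemsets (the number of transactions), and the original data moves to an optional df_orig that is only required, and only validated, when null_values=True. A minimal usage sketch under those assumptions (the toy data and the min_support value are illustrative, not taken from the commit):

import numpy as np
import pandas as pd
from mlxtend.frequent_patterns import fpgrowth, association_rules

# one-hot transactions with a few unknown (NaN) entries
df = pd.DataFrame({
    "Apple": [1, 0, np.nan, 1],
    "Milk":  [1, 1, 1, np.nan],
    "Bread": [0, 1, 1, 1],
})

freq = fpgrowth(df, min_support=0.5, null_values=True, use_colnames=True)

# new argument order: itemsets, transaction count, then the original data
rules = association_rules(freq, num_itemsets=len(df), df_orig=df, null_values=True)

# without missing values, df_orig can be omitted:
# rules = association_rules(freq, num_itemsets=len(df))
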
30 changes: 25 additions & 5 deletions mlxtend/frequent_patterns/fpcommon.py
@@ -68,18 +68,15 @@ def setup_fptree(df, min_support):
return tree, disabled, rank


def generate_itemsets(
generator, df_or, disabled, min_support, num_itemsets, colname_map
):
def generate_itemsets(generator, df, disabled, min_support, num_itemsets, colname_map):
itemsets = []
supports = []
df = df_or.copy().values
for sup, iset in generator:
itemsets.append(frozenset(iset))
# select data of iset from disabled dataset
dec = disabled[:, iset]
# select data of iset from original dataset
_dec = df[:, iset]
_dec = df.values[:, iset]

# case if iset only has one element
if len(iset) == 1:
@@ -122,6 +119,10 @@ def generate_itemsets(


def valid_input_check(df, null_values=False):
# Return early if df is None
if df is None:
return

if f"{type(df)}" == "<class 'pandas.core.frame.SparseDataFrame'>":
msg = (
"SparseDataFrame support has been deprecated in pandas 1.0,"
@@ -163,6 +164,19 @@ def valid_input_check(df, null_values=False):
"Please use a DataFrame with bool type",
DeprecationWarning,
)

# If null_values is True but no NaNs are found, raise an error
has_nans = pd.isna(df).any().any()
if null_values and not has_nans:
raise ValueError(
"null_values=True is not permitted when there are no NaN values in the DataFrame."
)
# If null_values is False but NaNs are found, raise an error
if not null_values and has_nans:
raise ValueError(
"NaN values are not permitted in the DataFrame when null_values=False."
)

# Pandas is much slower than numpy, so use np.where on Numpy arrays
if hasattr(df, "sparse"):
if df.size == 0:
@@ -185,6 +199,12 @@
"The allowed values for a DataFrame"
" are True, False, 0, 1. Found value %s" % (val)
)

if null_values:
s = (
"The allowed values for a DataFrame"
" are True, False, 0, 1, NaN. Found value %s" % (val)
)
raise ValueError(s)


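The added checks in valid_input_check tie the null_values flag to the data: a mismatch in either direction is rejected, and a None DataFrame (the new default for df_orig) skips validation entirely. A hedged illustration (the toy frames are made up for this example):

import numpy as np
import pandas as pd
from mlxtend.frequent_patterns import fpcommon as fpc

df_without_nans = pd.DataFrame({"Apple": [1, 0, 1], "Milk": [0, 1, 1]})
df_with_nans = pd.DataFrame({"Apple": [1, np.nan, 1], "Milk": [0, 1, np.nan]})

fpc.valid_input_check(None)                                # returns early, nothing to validate
fpc.valid_input_check(df_without_nans, null_values=False)  # passes for a clean 0/1 DataFrame
# fpc.valid_input_check(df_without_nans, null_values=True)   would raise ValueError: no NaNs present
# fpc.valid_input_check(df_with_nans, null_values=False)     would raise ValueError: NaNs not permitted
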
2 changes: 1 addition & 1 deletion mlxtend/frequent_patterns/fpgrowth.py
@@ -44,7 +44,7 @@ def fpgrowth(
The support is computed as the fraction
transactions_where_item(s)_occur / total_transactions.
null_values : bool (default: True)
null_values : bool (default: False)
In case there are null values as NaNs in the original input data
use_colnames : bool (default: False)
66 changes: 34 additions & 32 deletions mlxtend/frequent_patterns/tests/test_association_rules.py
@@ -55,7 +55,7 @@

# fmt: off
def test_default():
res_df = association_rules(df_freq_items, df, len(df))
res_df = association_rules(df_freq_items, len(df), df)
res_df["antecedents"] = res_df["antecedents"].apply(lambda x: str(frozenset(x)))
res_df["consequents"] = res_df["consequents"].apply(lambda x: str(frozenset(x)))
res_df.sort_values(columns_ordered, inplace=True)
@@ -86,7 +86,7 @@ def test_default():


def test_datatypes():
res_df = association_rules(df_freq_items, df, len(df))
res_df = association_rules(df_freq_items, len(df), df)
for i in res_df["antecedents"]:
assert isinstance(i, frozenset) is True

@@ -101,7 +101,7 @@ def test_datatypes():
lambda x: set(x)
)

res_df = association_rules(df_freq_items, df, len(df))
res_df = association_rules(df_freq_items, len(df), df)
for i in res_df["antecedents"]:
assert isinstance(i, frozenset) is True

@@ -111,17 +111,17 @@

def test_no_support_col():
df_no_support_col = df_freq_items.loc[:, ["itemsets"]]
numpy_assert_raises(ValueError, association_rules, df_no_support_col, df, len(df))
numpy_assert_raises(ValueError, association_rules, df_no_support_col, len(df), df)


def test_no_itemsets_col():
df_no_itemsets_col = df_freq_items.loc[:, ["support"]]
numpy_assert_raises(ValueError, association_rules, df_no_itemsets_col, df, len(df))
numpy_assert_raises(ValueError, association_rules, df_no_itemsets_col, len(df), df)


def test_wrong_metric():
numpy_assert_raises(
ValueError, association_rules, df_freq_items, df, len(df), False, "unicorn"
ValueError, association_rules, df_freq_items, len(df), df, False, "unicorn"
)


@@ -144,68 +144,68 @@ def test_empty_result():
"kulczynski",
]
)
res_df = association_rules(df_freq_items, df, len(df), min_threshold=2)
res_df = association_rules(df_freq_items, len(df), df, min_threshold=2)
assert res_df.equals(expect)


def test_leverage():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=0.1, metric="leverage"
df_freq_items, len(df), df, min_threshold=0.1, metric="leverage"
)
assert res_df.values.shape[0] == 6

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.1, metric="leverage"
df_freq_items_with_colnames, len(df), df, min_threshold=0.1, metric="leverage"
)
assert res_df.values.shape[0] == 6


def test_conviction():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=1.5, metric="conviction"
df_freq_items, len(df), df, min_threshold=1.5, metric="conviction"
)
assert res_df.values.shape[0] == 11

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=1.5, metric="conviction"
df_freq_items_with_colnames, len(df), df, min_threshold=1.5, metric="conviction"
)
assert res_df.values.shape[0] == 11


def test_lift():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=1.1, metric="lift"
df_freq_items, len(df), df, min_threshold=1.1, metric="lift"
)
assert res_df.values.shape[0] == 6

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=1.1, metric="lift"
df_freq_items_with_colnames, len(df), df, min_threshold=1.1, metric="lift"
)
assert res_df.values.shape[0] == 6


def test_confidence():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=0.8, metric="confidence"
df_freq_items, len(df), df, min_threshold=0.8, metric="confidence"
)
assert res_df.values.shape[0] == 9

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.8, metric="confidence"
df_freq_items_with_colnames, len(df), df, min_threshold=0.8, metric="confidence"
)
assert res_df.values.shape[0] == 9


def test_representativity():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=1.0, metric="representativity"
df_freq_items, len(df), df, min_threshold=1.0, metric="representativity"
)
assert res_df.values.shape[0] == 16

res_df = association_rules(
df_freq_items_with_colnames,
df,
len(df),
df,
min_threshold=1.0,
metric="representativity",
)
@@ -214,42 +214,42 @@ def test_representativity():

def test_jaccard():
res_df = association_rules(
df_freq_items, df, len(df), min_threshold=0.7, metric="jaccard"
df_freq_items, len(df), df, min_threshold=0.7, metric="jaccard"
)
assert res_df.values.shape[0] == 8

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.7, metric="jaccard"
df_freq_items_with_colnames, len(df), df, min_threshold=0.7, metric="jaccard"
)
assert res_df.values.shape[0] == 8


def test_certainty():
res_df = association_rules(
df_freq_items, df, len(df), metric="certainty", min_threshold=0.6
df_freq_items, len(df), df, metric="certainty", min_threshold=0.6
)
assert res_df.values.shape[0] == 3

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), metric="certainty", min_threshold=0.6
df_freq_items_with_colnames, len(df), df, metric="certainty", min_threshold=0.6
)
assert res_df.values.shape[0] == 3


def test_kulczynski():
res_df = association_rules(
df_freq_items, df, len(df), metric="kulczynski", min_threshold=0.9
df_freq_items, len(df), df, metric="kulczynski", min_threshold=0.9
)
assert res_df.values.shape[0] == 2

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), metric="kulczynski", min_threshold=0.6
df_freq_items_with_colnames, len(df), df, metric="kulczynski", min_threshold=0.6
)
assert res_df.values.shape[0] == 16


def test_frozenset_selection():
res_df = association_rules(df_freq_items, df, len(df))
res_df = association_rules(df_freq_items, len(df), df)

sel = res_df[res_df["consequents"] == frozenset((3, 5))]
assert sel.values.shape[0] == 1
@@ -266,18 +266,18 @@ def test_frozenset_selection():

def test_override_metric_with_support():
res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.8
df_freq_items_with_colnames, len(df), df, min_threshold=0.8
)
# default metric is confidence
assert res_df.values.shape[0] == 9

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.8, metric="support"
df_freq_items_with_colnames, len(df), df, min_threshold=0.8, metric="support"
)
assert res_df.values.shape[0] == 2

res_df = association_rules(
df_freq_items_with_colnames, df, len(df), min_threshold=0.8, support_only=True
df_freq_items_with_colnames, len(df), df, min_threshold=0.8, support_only=True
)
assert res_df.values.shape[0] == 2

@@ -308,9 +308,9 @@ def test_on_df_with_missing_entries():
],
}

df = pd.DataFrame(dict)
df_missing = pd.DataFrame(dict)

numpy_assert_raises(KeyError, association_rules, df, df, len(df))
numpy_assert_raises(KeyError, association_rules, df_missing, len(df), df)


def test_on_df_with_missing_entries_support_only():
@@ -339,8 +339,10 @@ def test_on_df_with_missing_entries_support_only():
],
}

df = pd.DataFrame(dict)
df_result = association_rules(df, df, len(df), support_only=True, min_threshold=0.1)
df_missing = pd.DataFrame(dict)
df_result = association_rules(
df_missing, len(df), df, support_only=True, min_threshold=0.1
)

assert df_result["support"].shape == (18,)
assert int(np.isnan(df_result["support"].values).any()) != 1
@@ -349,4 +351,4 @@ def test_on_df_with_missing_entries_support_only():
def test_with_empty_dataframe():
df_freq = df_freq_items_with_colnames.iloc[:0]
with pytest.raises(ValueError):
association_rules(df_freq, df, len(df))
association_rules(df_freq, len(df), df)
