Skip to content

Commit f5532e0

Browse files
committed
add pandas-numpy equality test
1 parent fe1a473 commit f5532e0

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

kmodes/tests/test_kmodes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import unittest
77

88
import numpy as np
9+
import pandas as pd
910

1011
from kmodes.kmodes import KModes
1112
from kmodes.util.dissim import ng_dissim, jaccard_dissim_binary, jaccard_dissim_label
@@ -575,3 +576,9 @@ def test_kmodes_fit_predict(self):
575576
data1 = kmodes.fit_predict(TEST_DATA, sample_weight=sample_weight)
576577
data2 = kmodes.fit(TEST_DATA, sample_weight=sample_weight).predict(TEST_DATA)
577578
assert_cluster_splits_equal(data1, data2)
579+
580+
def test_pandas_numpy_equality(self):
581+
kmodes = KModes(n_clusters=4, init='Cao', random_state=42)
582+
result_np = kmodes.fit_predict(SOYBEAN)
583+
result_pd = kmodes.fit_predict(pd.DataFrame(SOYBEAN))
584+
np.testing.assert_array_equal(result_np, result_pd)

kmodes/tests/test_kprototypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import unittest
77

88
import numpy as np
9+
import pandas as pd
910

1011
from kmodes import kprototypes
1112
from kmodes.tests.test_kmodes import assert_cluster_splits_equal
@@ -421,3 +422,9 @@ def test_kmodes_fit_predict_equality(self):
421422
data1 = model1.predict(STOCKS, categorical=[1, 2])
422423
data2 = kproto.fit_predict(STOCKS, categorical=[1, 2], sample_weight=sample_weight)
423424
assert_cluster_splits_equal(data1, data2)
425+
426+
def test_pandas_numpy_equality(self):
427+
kproto = kprototypes.KPrototypes(n_clusters=4, init='Cao', random_state=42)
428+
result_np = kproto.fit_predict(STOCKS, categorical=[1, 2])
429+
result_pd = kproto.fit_predict(pd.DataFrame(STOCKS), categorical=[1, 2])
430+
np.testing.assert_array_equal(result_np, result_pd)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
'dev': [
3535
'pytest',
3636
'pytest-cov',
37+
'pandas'
3738
]
3839
},
3940
classifiers=['Development Status :: 3 - Alpha',

0 commit comments

Comments
 (0)