Skip to content

Commit

Permalink
improve estimation of gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
nicodv committed Sep 6, 2022
1 parent f5532e0 commit db581f2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion kmodes/kprototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def k_prototypes(X, categorical, n_clusters, max_iter, num_dissim, cat_dissim,
# Estimate a good value for gamma, which determines the weighing of
# categorical values in clusters (see Huang [1997]).
if gamma is None:
gamma = 0.5 * Xnum.std()
gamma = 0.5 * np.mean(Xnum.std(axis=0))

results = []
seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init)
Expand Down
20 changes: 20 additions & 0 deletions kmodes/tests/test_kprototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,23 @@ def test_pandas_numpy_equality(self):
result_np = kproto.fit_predict(STOCKS, categorical=[1, 2])
result_pd = kproto.fit_predict(pd.DataFrame(STOCKS), categorical=[1, 2])
np.testing.assert_array_equal(result_np, result_pd)

def test_gamma_estimation(self):
data = np.hstack([
np.array([
[0.0],
[0.0],
[0.0],
[1.0],
[1.0],
[1.0],
[2.0],
[2.0],
[2.0],
[3.0],
[4.0],
[5.0],
]), STOCKS])
kproto = kprototypes.KPrototypes(n_clusters=4, init='Cao', random_state=42)
kproto_fitted = kproto.fit(data, categorical=[2, 3])
self.assertEqual(kproto_fitted.gamma, 35.33525036439546)

0 comments on commit db581f2

Please sign in to comment.