Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions correlation_clustering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
"metadata": {},
"outputs": [],
"source": [
"X, y = make_blobs(n_samples=100, centers=5, n_features=2, cluster_std=0.4, random_state=0)"
"X, y = make_blobs(n_samples=100, centers=5, n_features=2, cluster_std=0.4, random_state=0)\n",
"\n",
"# test case for transitivity and unary coding clauses\n",
"# X, y = make_blobs(n_samples=25, centers=5, n_features=2, cluster_std=0.8, random_state=0)"
]
},
{
Expand Down Expand Up @@ -183,12 +186,13 @@
"\n",
" # Constraints\n",
" f.write('Subject To\\n')\n",
" constraint_i = 0\n",
" for i in range(dim-2):\n",
" for j in range(i+1, dim-1):\n",
" for k in range(j+1, dim):\n",
" constraint_i += 1\n",
" # Note: only one constraint \"x[i,j] & x[j,k] => x[i,k]\" is not enough!\n",
" f.write(' x{} + x{} - x{} <= 1\\n'.format(x[i,j], x[j,k], x[i,k]))\n",
" f.write(' x{} + x{} - x{} <= 1\\n'.format(x[i,j], x[i,k], x[j,k]))\n",
" f.write(' x{} + x{} - x{} <= 1\\n'.format(x[i,k], x[j,k], x[i,j]))\n",
"\n",
" # Binary\n",
" f.write('Binary\\n')\n",
Expand Down Expand Up @@ -560,7 +564,10 @@
"for i in range(dim-2):\n",
"    for j in range(i+1, dim-1):\n",
"        for k in range(j+1, dim):\n",
"            # Note: the single implication \"x[i,j] & x[j,k] => x[i,k]\" is not enough (see the test case at the top),\n",
"            # so one clause is emitted per pair of the triple: any two pairs together imply the third.\n",
"            hard_clauses.extend(\n",
"                '{} -{} -{} {} 0\\n'.format(hard_clause_weight, int_map[p], int_map[q], int_map[r])\n",
"                for p, q, r in (((i,j), (j,k), (i,k)), ((i,j), (i,k), (j,k)), ((i,k), (j,k), (i,j)))\n",
"            )\n",
"\n",
"# Soft clauses\n",
"soft_clauses= []\n",
Expand Down Expand Up @@ -673,6 +680,25 @@
" if len(cluster) > 0:\n",
" clusters.append(cluster)\n",
"\n",
"# make `x` matrix: for i < j, x[i,j] is 1 when `inc_dict` holds a truthy value for the pair's variable, else 0\n",
"x = np.zeros(W.shape, dtype=np.int32)\n",
"for i in range(dim-1):\n",
"    for j in range(i+1, dim):\n",
"        x[i,j] = 1 if inc_dict[int_map[i,j]] else 0\n",
"\n",
"# test the transitivity of `x`: within a triple i < j < k, exactly two of the three\n",
"# pair indicators being set means the remaining pair is missing\n",
"for i in range(dim-2):\n",
"    for j in range(i+1, dim-1):\n",
"        for k in range(j+1, dim):\n",
"            if x[i,j] + x[j,k] + x[i,k] == 2:\n",
"                print(f\"non-transitive: {i} {j} {k}\")\n",
"\n",
"\n",
"print('Found {} clusters'.format(len(clusters)))"
]
},
Expand Down Expand Up @@ -784,6 +810,11 @@
" hard_clauses.append('{} -{} {} 0\\n'.format(hard_clause_weight, y[i][k], s[i][k]))\n",
" hard_clauses.append('{} -{} {} 0\\n'.format(hard_clause_weight, s[i][k-1], s[i][k]))\n",
" hard_clauses.append('{} -{} -{} 0\\n'.format(hard_clause_weight, y[i][k], s[i][k-1]))\n",
" # Also require that sum(y[i][:]) >= 1 (see the test case at the top)\n",
" or_clause = '{}'.format(hard_clause_weight)\n",
" for k in range(K):\n",
" or_clause += ' {}'.format(y[i][k])\n",
" hard_clauses.append(or_clause + ' 0\\n')\n",
"\n",
"for i in range(dim-1):\n",
" for j in range(i+1, dim):\n",
Expand Down