diff --git a/.travis.yml b/.travis.yml
index e7c84e8..ea3fe5e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,5 @@
 language: node_js
 node_js:
-  - '8'
   - '10'
 script:
   - yarn test
diff --git a/README.md b/README.md
index 111cbd2..24b58da 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,16 @@ for (let i = 0; i < nEpochs; i++) {
 const embedding = umap.getEmbedding();
 ```
 
+#### Supervised projection using labels
+
+```typescript
+import { UMAP } from 'umap-js';
+
+const umap = new UMAP();
+umap.setSupervisedProjection(labels);
+const embedding = umap.fit(data);
+```
+
 #### Parameters
 
 The UMAP constructor can accept a number of parameters via a `UMAPParameters` object:
diff --git a/src/matrix.ts b/src/matrix.ts
index beb66c9..96741a2 100644
--- a/src/matrix.ts
+++ b/src/matrix.ts
@@ -85,6 +85,10 @@ export class SparseMatrix {
     }
   }
 
+  getDims(): number[] {
+    return [this.nRows, this.nCols];
+  }
+
   getRows(): number[] {
     return [...this.rows];
   }
@@ -157,7 +161,10 @@ export function identity(size: number[]): SparseMatrix {
 /**
  * Element-wise multiplication of two matrices
  */
-export function dotMultiply(a: SparseMatrix, b: SparseMatrix): SparseMatrix {
+export function pairwiseMultiply(
+  a: SparseMatrix,
+  b: SparseMatrix
+): SparseMatrix {
   return elementWise(a, b, (x, y) => x * y);
 }
 
@@ -184,6 +191,89 @@ export function multiplyScalar(a: SparseMatrix, scalar: number): SparseMatrix {
   });
 }
 
+/**
+ * Returns a new matrix with zero entries removed.
+ */
+export function eliminateZeros(m: SparseMatrix) {
+  const zeroIndices = new Set();
+  const values = m.getValues();
+  const rows = m.getRows();
+  const cols = m.getCols();
+  for (let i = 0; i < values.length; i++) {
+    if (values[i] === 0) {
+      zeroIndices.add(i);
+    }
+  }
+  const removeByZeroIndex = (_, index: number) => !zeroIndices.has(index);
+  const nextValues = values.filter(removeByZeroIndex);
+  const nextRows = rows.filter(removeByZeroIndex);
+  const nextCols = cols.filter(removeByZeroIndex);
+
+  return new SparseMatrix(nextRows, nextCols, nextValues, m.getDims());
+}
+
+/**
+ * Normalization of a sparse matrix.
+ */
+export function normalize(m: SparseMatrix, normType = NormType.l2) {
+  const normFn = normFns[normType];
+
+  const colsByRow = new Map();
+  m.forEach((_, row, col) => {
+    const cols = colsByRow.get(row) || [];
+    cols.push(col);
+    colsByRow.set(row, cols);
+  });
+
+  const nextMatrix = new SparseMatrix([], [], [], m.getDims());
+
+  for (let row of colsByRow.keys()) {
+    const cols = colsByRow.get(row)!.sort();
+
+    const vals = cols.map(col => m.get(row, col));
+    const norm = normFn(vals);
+    for (let i = 0; i < norm.length; i++) {
+      nextMatrix.set(row, cols[i], norm[i]);
+    }
+  }
+
+  return nextMatrix;
+}
+
+/**
+ * Vector normalization functions
+ */
+type NormFns = { [key in NormType]: (v: number[]) => number[] };
+const normFns: NormFns = {
+  [NormType.max]: (xs: number[]) => {
+    let max = -Infinity;
+    for (let i = 0; i < xs.length; i++) {
+      max = xs[i] > max ? xs[i] : max;
+    }
+    return xs.map(x => x / max);
+  },
+  [NormType.l1]: (xs: number[]) => {
+    let sum = 0;
+    for (let i = 0; i < xs.length; i++) {
+      sum += xs[i];
+    }
+    return xs.map(x => x / sum);
+  },
+  [NormType.l2]: (xs: number[]) => {
+    let sum = 0;
+    for (let i = 0; i < xs.length; i++) {
+      sum += xs[i] ** 2;
+    }
+    return xs.map(x => Math.sqrt(x ** 2 / sum));
+  },
+};
+
+export const enum NormType {
+  max = 'max',
+  l1 = 'l1',
+  l2 = 'l2',
+}
+
 /**
  * Helper function for element-wise operations.
 */
diff --git a/src/umap.ts b/src/umap.ts
index 04cb1ec..bdd3c98 100644
--- a/src/umap.ts
+++ b/src/umap.ts
@@ -1,11 +1,8 @@
 /* Copyright 2019 Google Inc. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -63,6 +60,11 @@ export type DistanceFn = (x: Vector, y: Vector) => number;
 export type EpochCallback = (epoch: number) => boolean | void;
 export type Vector = number[];
 export type Vectors = Vector[];
+export const enum TargetMetric {
+  categorical = 'categorical',
+  l1 = 'l1',
+  l2 = 'l2',
+}
 
 const SMOOTH_K_TOLERANCE = 1e-5;
 const MIN_K_DIST_SCALE = 1e-3;
@@ -110,10 +112,33 @@ export interface UMAPParameters {
    * The effective scale of embedded points. In combination with ``min_dist``
    * this determines how clustered/clumped the embedded points are.
    */
-
   spread?: number;
 }
 
+export interface UMAPSupervisedParams {
+  /**
+   * The metric used to measure distance for a target array when using supervised
+   * dimension reduction. By default this is 'categorical', which will measure
+   * distance in terms of whether categories match or are different. Furthermore,
+   * if semi-supervised is required, target values of -1 will be treated as
+   * unlabelled under the 'categorical' metric. If the target array takes
+   * continuous values (e.g. for a regression problem) then a metric of 'l1'
+   * or 'l2' is probably more appropriate.
+   */
+  targetMetric?: TargetMetric;
+  /**
+   * Weighting factor between data topology and target topology. A value of
+   * 0.0 weights entirely on data, a value of 1.0 weights entirely on target.
+   * The default of 0.5 balances the weighting equally between data and target.
+   */
+  targetWeight?: number;
+  /**
+   * The number of nearest neighbors to use to construct the target simplicial
+   * set. Defaults to the `nearestNeighbors` parameter.
+   */
+  targetNNeighbors?: number;
+}
+
 /**
  * UMAP projection system, based on the python implementation from McInnes, L,
  * Healy, J, UMAP: Uniform Manifold Approximation and Projection for Dimension
@@ -144,6 +169,11 @@ export class UMAP {
   private random = Math.random;
   private spread = 1.0;
 
+  // Supervised projection params
+  private targetMetric = TargetMetric.categorical;
+  private targetWeight = 0.5;
+  private targetNNeighbors = this.nNeighbors;
+
   private distanceFn: DistanceFn = euclidean;
 
   // KNN state (can be precomputed and supplied via initializeFit)
@@ -152,9 +182,12 @@
   // Internal graph connectivity representation
   private graph!: matrix.SparseMatrix;
-  private data!: Vectors;
+  private X!: Vectors;
   private isInitialized = false;
 
+  // Supervised projection labels / targets
+  private Y?: number[];
+
   // Projected embedding
   private embedding: number[][] = [];
   private optimizationState = new OptimizationState();
@@ -164,8 +197,8 @@
     this.nComponents = params.nComponents || this.nComponents;
     this.nEpochs = params.nEpochs || this.nEpochs;
     this.nNeighbors = params.nNeighbors || this.nNeighbors;
-    this.spread = params.spread || this.spread;
     this.random = params.random || this.random;
+    this.spread = params.spread || this.spread;
   }
 
   /**
@@ -192,34 +225,48 @@
     return this.embedding;
   }
 
+  /**
+   * Initializes parameters needed for supervised projection.
+   */
+  setSupervisedProjection(Y: number[], params: UMAPSupervisedParams = {}) {
+    this.Y = Y;
+    this.targetMetric = params.targetMetric || this.targetMetric;
+    this.targetWeight = params.targetWeight || this.targetWeight;
+    this.targetNNeighbors = params.targetNNeighbors || this.targetNNeighbors;
+  }
+
+  /**
+   * Initializes umap with precomputed KNN indices and distances.
+   */
+  setPrecomputedKNN(knnIndices: number[][], knnDistances: number[][]) {
+    this.knnIndices = knnIndices;
+    this.knnDistances = knnDistances;
+  }
+
   /**
    * Initializes fit by computing KNN and a fuzzy simplicial set, as well as
    * initializing the projected embeddings. Sets the optimization state ahead
    * of optimization steps. Returns the number of epochs to be used for the
    * SGD optimization.
    */
-  initializeFit(
-    X: Vectors,
-    knnIndices?: number[][],
-    knnDistances?: number[][]
-  ): number {
+  initializeFit(X: Vectors): number {
     // We don't need to reinitialize if we've already initialized for this data.
-    if (this.data === X && this.isInitialized) {
+    if (this.X === X && this.isInitialized) {
       return this.getNEpochs();
     }
-    this.data = X;
+    this.X = X;
 
-    if (knnIndices && knnDistances) {
-      this.knnIndices = knnIndices;
-      this.knnDistances = knnDistances;
-    } else {
+    if (!this.knnIndices && !this.knnDistances) {
       const knnResults = this.nearestNeighbors(X);
       this.knnIndices = knnResults.knnIndices;
       this.knnDistances = knnResults.knnDistances;
     }
 
-    this.graph = this.fuzzySimplicialSet(X);
+    this.graph = this.fuzzySimplicialSet(X, this.nNeighbors);
+
+    // Check if supervised projection, then adjust the graph.
+    this.processGraphForSupervisedProjection();
 
     const {
       head,
@@ -236,6 +283,30 @@
     return this.getNEpochs();
   }
 
+  /**
+   * Checks if we're using supervised projection, then processes the graph
+   * accordingly.
+   */
+  private processGraphForSupervisedProjection() {
+    const { Y, X } = this;
+    if (Y) {
+      if (Y.length !== X.length) {
+        throw new Error('Length of X and y must be equal');
+      }
+
+      if (this.targetMetric === TargetMetric.categorical) {
+        const lt = this.targetWeight < 1.0;
+        const farDist = lt ? 2.5 * (1.0 / (1.0 - this.targetWeight)) : 1.0e12;
+        this.graph = this.categoricalSimplicialSetIntersection(
+          this.graph,
+          Y,
+          farDist
+        );
+      }
+      // TODO (andycoenen@): add non-categorical supervised embeddings.
+    }
+  }
+
   /**
    * Manually step through the optimization process one epoch at a time.
    */
@@ -298,10 +369,11 @@
    */
   private fuzzySimplicialSet(
     X: Vectors,
+    nNeighbors: number,
     localConnectivity = 1.0,
     setOpMixRatio = 1.0
   ) {
-    const { nNeighbors, knnIndices = [], knnDistances = [] } = this;
+    const { knnIndices = [], knnDistances = [] } = this;
 
     const { sigmas, rhos } = this.smoothKNNDistance(
       knnDistances,
@@ -320,7 +392,7 @@
     const sparseMatrix = new matrix.SparseMatrix(rows, cols, vals, size);
 
     const transpose = matrix.transpose(sparseMatrix);
-    const prodMatrix = matrix.dotMultiply(sparseMatrix, transpose);
+    const prodMatrix = matrix.pairwiseMultiply(sparseMatrix, transpose);
 
     const a = matrix.subtract(matrix.add(sparseMatrix, transpose), prodMatrix);
     const b = matrix.multiplyScalar(a, setOpMixRatio);
@@ -330,6 +402,28 @@
     return result;
   }
 
+  /**
+   * Combine a fuzzy simplicial set with another fuzzy simplicial set
+   * generated from categorical data using categorical distances. The target
+   * data is assumed to be categorical label data (a vector of labels),
+   * and this will update the fuzzy simplicial set to respect that label data.
+   */
+  private categoricalSimplicialSetIntersection(
+    simplicialSet: matrix.SparseMatrix,
+    target: number[],
+    farDist: number,
+    unknownDist = 1.0
+  ) {
+    let intersection = fastIntersection(
+      simplicialSet,
+      target,
+      unknownDist,
+      farDist
+    );
+    intersection = matrix.eliminateZeros(intersection);
+    return resetLocalConnectivity(intersection);
+  }
+
   /**
    * Compute a continuous version of the distance to the kth nearest
    * neighbor. That is, this is similar to knn-distance but allows continuous
@@ -850,3 +944,41 @@ export function findABParams(spread: number, minDist: number) {
   const [a, b] = parameterValues as number[];
   return { a, b };
 }
+
+/**
+ * Under the assumption of categorical distance for the intersecting
+ * simplicial set perform a fast intersection.
+ */
+export function fastIntersection(
+  graph: matrix.SparseMatrix,
+  target: number[],
+  unknownDist = 1.0,
+  farDist = 5.0
+) {
+  return graph.map((value, row, col) => {
+    if (target[row] === -1 || target[col] === -1) {
+      return value * Math.exp(-unknownDist);
+    } else if (target[row] !== target[col]) {
+      return value * Math.exp(-farDist);
+    } else {
+      return value;
+    }
+  });
+}
+
+/**
+ * Reset the local connectivity requirement -- each data sample should
+ * have complete confidence in at least one 1-simplex in the simplicial set.
+ * We can enforce this by locally rescaling confidences, and then remerging the
+ * different local simplicial sets together.
+ */
+export function resetLocalConnectivity(simplicialSet: matrix.SparseMatrix) {
+  simplicialSet = matrix.normalize(simplicialSet, matrix.NormType.max);
+  const transpose = matrix.transpose(simplicialSet);
+  const prodMatrix = matrix.pairwiseMultiply(transpose, simplicialSet);
+  simplicialSet = matrix.add(
+    simplicialSet,
+    matrix.subtract(transpose, prodMatrix)
+  );
+  return matrix.eliminateZeros(simplicialSet);
+}
diff --git a/test/matrix.test.ts b/test/matrix.test.ts
index 3e53c9a..539ad34 100644
--- a/test/matrix.test.ts
+++ b/test/matrix.test.ts
@@ -1,11 +1,8 @@
 /* Copyright 2019 Google Inc. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,10 +14,13 @@ import {
   SparseMatrix,
   transpose,
   identity,
-  dotMultiply,
+  pairwiseMultiply,
   add,
   subtract,
   multiplyScalar,
+  eliminateZeros,
+  normalize,
+  NormType,
 } from '../src/matrix';
 
 describe('sparse matrix', () => {
@@ -40,7 +40,7 @@
   test('sparse matrix has get / set methods', () => {
     const rows = [0, 0, 1, 1];
     const cols = [0, 1, 0, 1];
-    const vals = [1, 2];
+    const vals = [1, 2, 3, 4];
     const dims = [2, 2];
     const matrix = new SparseMatrix(rows, cols, vals, dims);
 
@@ -111,8 +111,8 @@
     expect(I.toArray()).toEqual([[1, 0], [0, 1]]);
   });
 
-  test('dot multiply method', () => {
-    const X = dotMultiply(A, B);
+  test('pairwise multiply method', () => {
+    const X = pairwiseMultiply(A, B);
     expect(X.toArray()).toEqual([[1, 4], [9, 16]]);
   });
 
@@ -130,4 +130,64 @@
     const X = multiplyScalar(A, 3);
     expect(X.toArray()).toEqual([[3, 6], [9, 12]]);
   });
+
+  test('eliminateZeros method', () => {
+    const defaultValue = 11;
+    const rows = [0, 1, 1];
+    const cols = [0, 0, 1];
+    const vals = [0, 1, 3];
+    const dims = [2, 2];
+    const matrix = new SparseMatrix(rows, cols, vals, dims);
+
+    expect(matrix.get(0, 0, defaultValue)).toEqual(0);
+    const eliminated = eliminateZeros(matrix);
+
+    expect(eliminated.getValues()).toEqual([1, 3]);
+    expect(eliminated.getRows()).toEqual([1, 1]);
+    expect(eliminated.getCols()).toEqual([0, 1]);
+
+    expect(eliminated.get(0, 0, defaultValue)).toEqual(defaultValue);
+  });
+});
+
+describe('normalize method', () => {
+  let A: SparseMatrix;
+
+  beforeEach(() => {
+    const rows = [0, 0, 0, 1, 1, 1, 2, 2, 2];
+    const cols = [0, 1, 2, 0, 1, 2, 0, 1, 2];
+    const vals = [1, 2, 3, 4, 5, 6, 7, 8, 9];
+    const dims = [3, 3];
+    A = new SparseMatrix(rows, cols, vals, dims);
+  });
+
+  test('max normalization method', () => {
+    const expected = [
+      [0.3333333333333333, 0.6666666666666666, 1.0],
+      [0.6666666666666666, 0.8333333333333334, 1.0],
+      [0.7777777777777778, 0.8888888888888888, 1.0],
+    ];
+    const n = normalize(A, NormType.max);
+    expect(n.toArray()).toEqual(expected);
+  });
+
+  test('l1 normalization method', () => {
+    const expected = [
+      [0.16666666666666666, 0.3333333333333333, 0.5],
+      [0.26666666666666666, 0.3333333333333333, 0.4],
+      [0.2916666666666667, 0.3333333333333333, 0.375],
+    ];
+    const n = normalize(A, NormType.l1);
+    expect(n.toArray()).toEqual(expected);
+  });
+
+  test('l2 normalization method (default)', () => {
+    const expected = [
+      [0.2672612419124244, 0.5345224838248488, 0.8017837257372732],
+      [0.4558423058385518, 0.5698028822981898, 0.6837634587578277],
+      [0.5025707110324167, 0.5743665268941904, 0.6461623427559643],
+    ];
+    const n = normalize(A);
+    expect(n.toArray()).toEqual(expected);
+  });
 });
diff --git a/test/test_data.ts b/test/test_data.ts
index 5278d59..a9270ac 100644
--- a/test/test_data.ts
+++ b/test/test_data.ts
@@ -116,6 +116,8 @@ export const testData: number[][] = [
   [0,0,1,15,13,0,0,0,0,0,1,16,16,5,0,0,0,0,7,16,16,0,0,0,0,0,13,16,13,0,0,0,0,7,16,16,13,0,0,0,0,1,11,16,13,0,0,0,0,0,2,16,16,0,0,0,0,0,1,14,16,3,0,0]
 ]
 
+export const testLabels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 9, 5, 5, 6, 5, 0, 9, 8, 9, 8, 4, 1, 7, 7, 3, 5, 1, 0, 0, 2, 2, 7, 8, 2, 0, 1, 2, 6, 3, 3, 7, 3, 3, 4, 6, 6, 6, 4, 9, 1, 5, 0, 9, 5, 2, 8, 2, 0, 0, 1, 7, 6, 3, 2, 1, 7, 4, 6, 3, 1, 3, 9, 1, 7, 6, 8, 4, 3, 1];
+
 export const testResults2D = [
   [-2.904975618700953, 3.683494083841041],
   [-0.879124321765863, -0.4426951405143409],
diff --git a/test/umap.test.ts b/test/umap.test.ts
index 2152640..2d30a34 100644
--- a/test/umap.test.ts
+++ b/test/umap.test.ts
@@ -1,11 +1,8 @@
 /* Copyright 2019 Google Inc. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -13,12 +10,24 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-import { UMAP, findABParams } from '../src/umap';
-import { testData, testResults2D, testResults3D } from './test_data';
+import { UMAP, findABParams, euclidean, TargetMetric } from '../src/umap';
+import * as utils from '../src/utils';
+import {
+  testData,
+  testLabels,
+  testResults2D,
+  testResults3D,
+} from './test_data';
 import Prando from 'prando';
 
 describe('UMAP', () => {
   let random: () => number;
+
+  // Expected "clustering" ratios, representing intra-cluster distance vs mean
+  // distance to other points.
+  const UNSUPERVISED_CLUSTER_RATIO = 0.15;
+  const SUPERVISED_CLUSTER_RATIO = 0.04;
+
   beforeEach(() => {
     const prng = new Prando(42);
     random = () => prng.next();
@@ -28,12 +37,14 @@
     const umap = new UMAP({ random, nComponents: 2 });
     const embedding = umap.fit(testData);
     expect(embedding).toEqual(testResults2D);
+    checkClusters(embedding, testLabels, UNSUPERVISED_CLUSTER_RATIO);
   });
 
   test('UMAP fit 3d synchronous method', () => {
     const umap = new UMAP({ random, nComponents: 3 });
     const embedding = umap.fit(testData);
     expect(embedding).toEqual(testResults3D);
+    checkClusters(embedding, testLabels, UNSUPERVISED_CLUSTER_RATIO);
   });
 
   test('UMAP fitAsync method', async () => {
@@ -63,7 +74,7 @@
     const nEpochs = 200;
     const umap = new UMAP({ random, nEpochs });
     let nEpochsComputed = 0;
-    const embedding = await umap.fitAsync(testData, () => {
+    await umap.fitAsync(testData, () => {
       nEpochsComputed += 1;
     });
     expect(nEpochsComputed).toEqual(nEpochs);
@@ -85,13 +96,34 @@
     const { knnIndices, knnDistances } = knnUMAP['nearestNeighbors'](testData);
 
     const umap = new UMAP({ random });
+    umap.setPrecomputedKNN(knnIndices, knnDistances);
     spyOn(umap, 'nearestNeighbors');
-    umap.initializeFit(testData, knnIndices, knnDistances);
     umap.fit(testData);
 
     expect(umap['nearestNeighbors']).toBeCalledTimes(0);
   });
 
+  test('supervised projection', () => {
+    const umap = new UMAP({ random, nComponents: 2 });
+    umap.setSupervisedProjection(testLabels);
+    const embedding = umap.fit(testData);
+
+    expect(embedding.length).toEqual(testResults2D.length);
+    checkClusters(embedding, testLabels, SUPERVISED_CLUSTER_RATIO);
+  });
+
+  test('non-categorical supervised projection is not implemented', () => {
+    const umap = new UMAP({ random, nComponents: 2 });
+
+    // Unimplemented target metric.
+    const targetMetric = TargetMetric.l1;
+    umap.setSupervisedProjection(testLabels, { targetMetric });
+    const embedding = umap.fit(testData);
+
+    // Supervision with unimplemented target metric is a noop.
+    expect(embedding).toEqual(testResults2D);
+  });
+
   test('finds AB params using levenberg-marquardt', () => {
     // The default parameters from the python implementation
     const minDist = 0.1;
@@ -108,4 +140,48 @@
     expect(diff(params.a, a)).toBeLessThanOrEqual(epsilon);
     expect(diff(params.b, b)).toBeLessThanOrEqual(epsilon);
   });
+
+  const computeMeanDistances = (vectors: number[][]) => {
+    return vectors.map(vector => {
+      return utils.mean(
+        vectors.map(other => {
+          return euclidean(vector, other);
+        })
+      );
+    });
+  };
+
+  /**
+   * Check the ratio between distances within a cluster and for all points to
+   * indicate "clustering"
+   */
+  const checkClusters = (
+    embeddings: number[][],
+    labels: number[],
+    expectedClusterRatio: number
+  ) => {
+    const distances = computeMeanDistances(embeddings);
+    const overallMeanDistance = utils.mean(distances);
+
+    const embeddingsByLabel = new Map();
+    for (let i = 0; i < labels.length; i++) {
+      const label = labels[i];
+      const embedding = embeddings[i];
+      const group = embeddingsByLabel.get(label) || [];
+      group.push(embedding);
+      embeddingsByLabel.set(label, group);
+    }
+
+    let totalIntraclusterDistance = 0;
+    for (let label of embeddingsByLabel.keys()) {
+      const group = embeddingsByLabel.get(label)!;
+      const distances = computeMeanDistances(group);
+      const meanDistance = utils.mean(distances);
+      totalIntraclusterDistance += meanDistance * group.length;
+    }
+    const meanInterclusterDistance =
+      totalIntraclusterDistance / embeddings.length;
+    const clusterRatio = meanInterclusterDistance / overallMeanDistance;
+    expect(clusterRatio).toBeLessThan(expectedClusterRatio);
+  };
 });
diff --git a/tsconfig.json b/tsconfig.json
index de38011..85d8dfe 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -14,6 +14,7 @@
     "noImplicitAny": false,
     "removeComments": true,
     "allowUnreachableCode": true,
+    "downlevelIteration": true,
     "lib": ["dom", "es2015", "es2016", "es2017"]
   },
   "compileOnSave": false,
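Usage note (not part of the diff itself): a minimal sketch of how the supervised projection API introduced above would be called, following the README example and the new tests; `data` and `labels` below are placeholder inputs.

```typescript
import { UMAP } from 'umap-js';

// Placeholder inputs for illustration: `data` is an array of numeric vectors,
// `labels` holds one label per row (-1 marks an unlabelled row under the
// default 'categorical' target metric).
declare const data: number[][];
declare const labels: number[];

const umap = new UMAP({ nComponents: 2 });

// targetWeight balances data topology against label topology:
// 0.0 uses only the data, 1.0 uses only the labels, 0.5 is the default.
umap.setSupervisedProjection(labels, { targetWeight: 0.5 });

const embedding = umap.fit(data);
```

When nearest-neighbor lookups have already been computed, `umap.setPrecomputedKNN(knnIndices, knnDistances)` can be called before `fit` in the same way, replacing the old optional `initializeFit` arguments.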