Skip to content

Commit

Permalink
Merge pull request #2 from PAIR-code/supervised
Browse files Browse the repository at this point in the history
Add supervised (categorical) projection
  • Loading branch information
cannoneyed authored Mar 6, 2019
2 parents 4b64548 + 7569afc commit 9e186c1
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 36 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
language: node_js
node_js:
- '8'
- '10'
script:
- yarn test
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ for (let i = 0; i < nEpochs; i++) {
const embedding = umap.getEmbedding();
```

#### Supervised projection using labels

```typescript
import { UMAP } from 'umap-js';

const umap = new UMAP();
umap.setSupervisedProjection(labels);
const embedding = umap.fit(data);
```

#### Parameters

The UMAP constructor can accept a number of parameters via a `UMAPParameters` object:
Expand Down
92 changes: 91 additions & 1 deletion src/matrix.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ export class SparseMatrix {
}
}

getDims(): number[] {
return [this.nRows, this.nCols];
}

getRows(): number[] {
return [...this.rows];
}
Expand Down Expand Up @@ -157,7 +161,10 @@ export function identity(size: number[]): SparseMatrix {
/**
* Element-wise multiplication of two matrices
*/
export function dotMultiply(a: SparseMatrix, b: SparseMatrix): SparseMatrix {
export function pairwiseMultiply(
a: SparseMatrix,
b: SparseMatrix
): SparseMatrix {
return elementWise(a, b, (x, y) => x * y);
}

Expand All @@ -184,6 +191,89 @@ export function multiplyScalar(a: SparseMatrix, scalar: number): SparseMatrix {
});
}

/**
* Returns a new matrix with zero entries removed.
*/
export function eliminateZeros(m: SparseMatrix) {
const zeroIndices = new Set();
const values = m.getValues();
const rows = m.getRows();
const cols = m.getCols();
for (let i = 0; i < values.length; i++) {
if (values[i] === 0) {
zeroIndices.add(i);
}
}
const removeByZeroIndex = (_, index: number) => !zeroIndices.has(index);
const nextValues = values.filter(removeByZeroIndex);
const nextRows = rows.filter(removeByZeroIndex);
const nextCols = cols.filter(removeByZeroIndex);

return new SparseMatrix(nextRows, nextCols, nextValues, m.getDims());
}

/**
* Normalization of a sparse matrix.
*/
export function normalize(m: SparseMatrix, normType = NormType.l2) {
const normFn = normFns[normType];

const colsByRow = new Map<number, number[]>();
m.forEach((_, row, col) => {
const cols = colsByRow.get(row) || [];
cols.push(col);
colsByRow.set(row, cols);
});

const nextMatrix = new SparseMatrix([], [], [], m.getDims());

for (let row of colsByRow.keys()) {
const cols = colsByRow.get(row)!.sort();

const vals = cols.map(col => m.get(row, col));
const norm = normFn(vals);
for (let i = 0; i < norm.length; i++) {
nextMatrix.set(row, cols[i], norm[i]);
}
}

return nextMatrix;
}

/**
* Vector normalization functions
*/
type NormFns = { [key in NormType]: (v: number[]) => number[] };
const normFns: NormFns = {
[NormType.max]: (xs: number[]) => {
let max = -Infinity;
for (let i = 0; i < xs.length; i++) {
max = xs[i] > max ? xs[i] : max;
}
return xs.map(x => x / max);
},
[NormType.l1]: (xs: number[]) => {
let sum = 0;
for (let i = 0; i < xs.length; i++) {
sum += xs[i];
}
return xs.map(x => x / sum);
},
[NormType.l2]: (xs: number[]) => {
let sum = 0;
for (let i = 0; i < xs.length; i++) {
sum += xs[i] ** 2;
}
return xs.map(x => Math.sqrt(x ** 2 / sum));
},
};

export const enum NormType {
max = 'max',
l1 = 'l1',
l2 = 'l2',
}

/**
* Helper function for element-wise operations.
*/
Expand Down
Loading

0 comments on commit 9e186c1

Please sign in to comment.