Skip to content

Commit

Permalink
Encoding technique options. New fruit dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
izzat5233 committed Dec 24, 2023
1 parent f046fed commit 5550449
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 105 deletions.
61 changes: 45 additions & 16 deletions web/content/control/data.html
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,51 @@
}
</script>
<div class="row">
<div class="col">
<div class="px-2">
<p><span class="fw-bold fs-5">Encoding Technique</span><br><span class="fw-light">
Data is always encoded if categorical columns are detected.</span>
</p>
<div class="px-2 d-flex flex-column gap-2">
<div onchange="handleDataEncodingOptionChange()">
<div class="form-check">
<input class="form-check-input" type="radio" name="dataEncodingRadio"
id="dataEncodingRadioLabelOption" checked>
<label class="form-check-label" for="dataEncodingRadioLabelOption">
Label Encoding <span class="fw-light">(0, 1, 2,...)</span>
</label>
</div>
<div class="form-check">
<input class="form-check-input" type="radio" name="dataEncodingRadio"
id="dataEncodingRadioOneHotOption">
<label class="form-check-label" for="dataEncodingRadioOneHotOption">
One-Hot Encoding <span class="fw-light">(001, 010, 100, ...)</span>
</label>
</div>
</div>
<div class="form-check" onchange="previewTrainTable(); previewTestTable()">
<input class="form-check-input" type="checkbox" id="dataOriginalDataCheck">
<label class="form-check-label" for="dataOriginalDataCheck">
Show original data in preview
</label>
</div>
</div>
</div>
<script>
function handleDataEncodingOptionChange() {
const isLabelEncoding = document.getElementById('dataEncodingRadioLabelOption').checked;
const value = isLabelEncoding ? 'label' : 'oneHot';
trainInputTableObject.reset(value);
trainOutputTableObject.reset(value);
testInputTableObject.reset(value);
testOutputTableObject.reset(value);
handleTrainInputClear();
handleTrainOutputClear();
handleTestInputClear();
handleTestOutputClear();
}
</script>
</div>
<div class="col">
<div class="btn-group mt-2" role="group">
<button type="button" class="btn btn-primary dropdown-toggle" data-bs-toggle="dropdown"
Expand Down Expand Up @@ -59,22 +104,6 @@
}
</script>
</div>
<div class="col">
<div class="px-2">
<p>Data is always one-hot encoded if categorical columns are found.</p>
<div class="form-check">
<input class="form-check-input" type="checkbox" id="dataOriginalDataCheck"
onchange="handleDataOptionsChange()">
<label class="form-check-label" for="dataOriginalDataCheck">Show Original Data</label>
</div>
</div>
<script>
function handleDataOptionsChange() {
previewTrainTable();
previewTestTable();
}
</script>
</div>
</div>
<hr>
<div class="row">
Expand Down
18 changes: 4 additions & 14 deletions web/content/control/test.html
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
</div>
<script>
function getDecodedPredictions() {
return testOutputTableObject.decodeBasedOnMaxProbability(predictionsEncodedData);
return testOutputTableObject.decodeData(predictionsEncodedData);
}

function handleClearPredictions() {
Expand All @@ -69,7 +69,7 @@
downloadJson({
weights: toArrArrArr(network.getWeights()),
biases: toArrArr(network.getBiases())
}, "network_data.csv");
}, "network_data.json");
}

function previewPredictionsTable() {
Expand Down Expand Up @@ -98,18 +98,8 @@
const b2 = appendTableCells(tbody, output.data);
styleTableSelection(h2, b2, 'table-secondary');

let predictions;
if (encoded) {
predictions = {
headers: testOutputTableObject.getEncodedHeaders(),
data: sampleArrArr(predictionsEncodedData),
}
} else {
predictions = {
headers: testOutputTableObject.headers,
data: sampleArrArr(getDecodedPredictions()),
}
}
let predictions = testOutputTableObject.prepareSimilarData(encoded, predictionsEncodedData);
predictions.data = sampleArrArr(predictions.data);
const h3 = appendTableHeaders(thead, predictions.headers);
const b3 = appendTableCells(tbody, predictions.data);
styleTableSelection(h3, b3, 'table-primary');
Expand Down
15 changes: 5 additions & 10 deletions web/static/datasets/fruits/test_in.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
COLOR,SWEETNESS
RED,8
ORANGE,9
RED,6
YELLOW,5
YELLOW,4
ORANGE,7
RED,7
ORANGE,8
GREEN,4
Color,Sweetness
Red,6
Green,7
Orange,5
Brown,6
15 changes: 5 additions & 10 deletions web/static/datasets/fruits/test_out.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
FRUIT
APPLE
ORANGE
APPLE
APPLE
BANANA
ORANGE
APPLE
ORANGE
APPLE
Fruit
Apple
Apple
Orange
Banana
20 changes: 10 additions & 10 deletions web/static/datasets/fruits/train_in.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
COLOR,SWEETNESS
RED,9
RED,8
RED,7
YELLOW,6
GREEN,5
ORANGE,6
YELLOW,3
ORANGE,8
ORANGE,7
Color,Sweetness
Red,7
Red,8
Green,8
Orange,6
Orange,7
Green,2
Yellow,5
Yellow,8
Brown,9
20 changes: 10 additions & 10 deletions web/static/datasets/fruits/train_out.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
FRUIT
APPLE
APPLE
APPLE
APPLE
APPLE
ORANGE
BANANA
ORANGE
ORANGE
Fruit
Apple
Apple
Apple
Orange
Orange
Orange
Banana
Banana
Banana
117 changes: 82 additions & 35 deletions web/static/js/TableObject.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
class TableObject {
constructor() {
constructor(encodingType = 'label') {
this.reset(encodingType);
}

reset(encodingType) {
this._encodingType = encodingType;
this._headers = [];
this._data = [];
this._categoricalColumns = [];
this._encodedColumns = {};
}

clear() {
this.reset(this._encodingType);
}

set headers(headers) {
if (!Array.isArray(headers)) {
throw new Error("Headers must be of type array");
Expand Down Expand Up @@ -52,11 +61,6 @@ class TableObject {
this.data = data;
}

clear() {
this._headers = [];
this._data = [];
}

#updateCategoricalColumns() {
if (this._data.length > 0 && this._headers.length > 0) {
this._categoricalColumns = this.#detectCategoricalColumns();
Expand All @@ -81,23 +85,37 @@ class TableObject {

this._categoricalColumns.forEach(column => {
let uniqueValuesMap = new Set(this.data.map(row => row[this.headers.indexOf(column)]));
let encodedColumnData = this.data.map(row => {
let encodedRow = [];
uniqueValuesMap.forEach(uniqueValue => {
encodedRow.push(row[this.headers.indexOf(column)] === uniqueValue ? 1 : 0);

if (this._encodingType === 'label') {
let labelMap = Array.from(uniqueValuesMap).reduce((acc, val, index) => {
acc[val] = index;
return acc;
}, {});

encodedColumns[column] = {
uniqueValues: Array.from(uniqueValuesMap),
labelMap: labelMap
};
} else {
let encodedColumnData = this.data.map(row => {
let encodedRow = [];
uniqueValuesMap.forEach(uniqueValue => {
encodedRow.push(row[this.headers.indexOf(column)] === uniqueValue ? 1 : 0);
});
return encodedRow;
});
return encodedRow;
});
encodedColumns[column] = {
uniqueValues: Array.from(uniqueValuesMap),
data: encodedColumnData
};
encodedColumns[column] = {
uniqueValues: Array.from(uniqueValuesMap),
data: encodedColumnData
};
}
});

return encodedColumns;
}

getEncodedHeaders() {
if (this._encodingType === 'label') return this._headers;
let encodedHeaders = [];

this.headers.forEach(header => {
Expand All @@ -119,14 +137,19 @@ class TableObject {

this.headers.forEach((header, headerIndex) => {
if (this._categoricalColumns.includes(header)) {
// Find the index of the value in the row for the categorical column
const valueIndex = this._encodedColumns[header].uniqueValues.indexOf(row[headerIndex]);
// Create a one-hot encoded array for this value
const oneHotArray = Array(this._encodedColumns[header].uniqueValues.length).fill(0);
if (valueIndex >= 0) {
oneHotArray[valueIndex] = 1;
if (this._encodingType === 'label') {
// Label encoding
const labelValue = this._encodedColumns[header].labelMap[row[headerIndex]];
encodedRow.push(labelValue);
} else {
// One-hot encoding
const valueIndex = this._encodedColumns[header].uniqueValues.indexOf(row[headerIndex]);
const oneHotArray = Array(this._encodedColumns[header].uniqueValues.length).fill(0);
if (valueIndex >= 0) {
oneHotArray[valueIndex] = 1;
}
encodedRow = encodedRow.concat(oneHotArray);
}
encodedRow = encodedRow.concat(oneHotArray);
} else {
encodedRow.push(row[headerIndex]);
}
Expand All @@ -136,32 +159,40 @@ class TableObject {
});
}

decodeBasedOnMaxProbability(encodedData) {
decodeData(encodedData) {
if (!Array.isArray(encodedData)) {
throw new Error("Encoded data must be of type array");
throw new Error("Encoded data must be an array");
}

return encodedData.map(encodedRow => {
let decodedRow = [];
let encodedIndex = 0;

this.headers.forEach(header => {
this.headers.forEach((header, headerIndex) => {
if (this._categoricalColumns.includes(header)) {
// Decode the categorical column based on maximum probability
const numberOfCategories = this._encodedColumns[header].uniqueValues.length;
const encodedSegment = encodedRow.slice(encodedIndex, encodedIndex + numberOfCategories);
const maxProbIndex = encodedSegment.indexOf(Math.max(...encodedSegment));
const decodedValue = maxProbIndex >= 0 ? this._encodedColumns[header].uniqueValues[maxProbIndex] : null;

decodedRow.push(decodedValue);
encodedIndex += numberOfCategories;
if (this._encodingType === 'label') {
// Decode label encoding with rounding and clipping
const numberOfLabels = this._encodedColumns[header].uniqueValues.length;
let labelValueIndex = Math.round(encodedRow[encodedIndex]);
labelValueIndex = Math.max(0, Math.min(labelValueIndex, numberOfLabels - 1));
const decodedValue = this._encodedColumns[header].uniqueValues[labelValueIndex];
decodedRow.push(decodedValue);
encodedIndex++;
} else {
// Decode one-hot encoding
const numberOfCategories = this._encodedColumns[header].uniqueValues.length;
const encodedSegment = encodedRow.slice(encodedIndex, encodedIndex + numberOfCategories);
const maxProbIndex = encodedSegment.indexOf(Math.max(...encodedSegment));
const decodedValue = this._encodedColumns[header].uniqueValues[maxProbIndex];
decodedRow.push(decodedValue);
encodedIndex += numberOfCategories;
}
} else {
// For non-categorical columns, the value is directly taken from the encoded data
decodedRow.push(encodedRow[encodedIndex]);
encodedIndex++;
}
});

return decodedRow;
});
}
Expand All @@ -181,4 +212,20 @@ class TableObject {
data: sampleArrArr(data),
}
}

prepareSimilarData(encoded, encodedData) {
let headers;
let data;
if (encoded) {
headers = this.getEncodedHeaders();
data = encodedData;
} else {
headers = this._headers;
data = this.decodeData(encodedData)
}
return {
headers: headers,
data: data
}
}
}

0 comments on commit 5550449

Please sign in to comment.