-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from fszewczyk/image
Image class
- Loading branch information
Showing
15 changed files
with
8,615 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
build/ | ||
datasets/ | ||
|
||
docs/html | ||
docs/latex | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#include <filesystem> | ||
#include <iostream> | ||
|
||
#include "../include/ShkyeraGrad.hpp" | ||
|
||
namespace fs = std::filesystem; | ||
using namespace shkyera; | ||
|
||
Dataset<Vec32, Vec32> load(std::string directory) { | ||
Dataset<Vec32, Vec32> dataset; | ||
|
||
std::cerr << "Loading [" << std::flush; | ||
for (size_t digit = 0; digit < 10; ++digit) { | ||
std::cerr << "▮" << std::flush; | ||
int added = 0; | ||
for (const auto &entry : fs::directory_iterator(directory + std::to_string(digit))) { | ||
Image image(entry.path().string()); | ||
auto target = Vec32::oneHotEncode(digit, 10); | ||
|
||
dataset.addSample(image.flatten<Type::float32>() / 255.0f, target); | ||
} | ||
} | ||
std::cerr << "]" << std::endl; | ||
|
||
return dataset; | ||
} | ||
|
||
int main() { | ||
Dataset<Vec32, Vec32> trainData = load("datasets/mnist/train/"); | ||
std::cerr << "Loaded training data." << std::endl; | ||
|
||
DataLoader trainLoader(trainData, 16, true); | ||
|
||
// clang-format off | ||
auto mlp = SequentialBuilder32::begin() | ||
.add(Linear32::create(784, 100)) | ||
.add(ReLU32::create()) | ||
.add(Linear32::create(100, 50)) | ||
.add(Sigmoid32::create()) | ||
.add(Linear32::create(50, 10)) | ||
.add(Softmax32::create()) | ||
.build(); | ||
// clang-format on | ||
|
||
auto optimizer = Adam32(mlp->parameters(), 0.01, 0.99); | ||
auto lossFunction = Loss::CrossEntropy<Type::float32>; | ||
|
||
for (size_t epoch = 0; epoch < 50; epoch++) { | ||
float epochLoss = 0; | ||
double epochAccuracy = 0; | ||
|
||
for (const auto [x, y] : trainLoader) { | ||
optimizer.reset(); | ||
|
||
auto pred = mlp->forward(x); | ||
|
||
double accuracy = 0; | ||
for (size_t i = 0; i < pred.size(); ++i) { | ||
size_t predictedDigit = pred[i].argMax(); | ||
size_t trueDigit = y[i].argMax(); | ||
|
||
if (predictedDigit == trueDigit) | ||
accuracy += 1; | ||
} | ||
|
||
accuracy /= pred.size(); | ||
epochAccuracy += accuracy; | ||
|
||
auto loss = Loss::compute(lossFunction, pred, y); | ||
epochLoss = epochLoss + loss->getValue(); | ||
|
||
optimizer.step(); | ||
|
||
std::cerr << "Loss: " << loss->getValue() << " Accuracy: " << accuracy << std::endl; | ||
} | ||
std::cerr << "Epoch: " << epoch + 1 << " Loss: " << epochLoss / trainLoader.getTotalBatches() | ||
<< " Accuracy: " << epochAccuracy / trainLoader.getTotalBatches() << std::endl; | ||
} | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/** | ||
* Copyright © 2023 Franciszek Szewczyk. None of the rights reserved. | ||
* This code is released under the Beerware License. If you find this code useful or you appreciate the work, you are | ||
* encouraged to buy the author a beer in return. | ||
* Contact the author at szewczyk.franciszek02@gmail.com for inquiries and support. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#define STB_IMAGE_IMPLEMENTATION | ||
#include "../external/stb_image.h" | ||
|
||
#include "Vector.hpp" | ||
|
||
namespace shkyera { | ||
|
||
class Image { | ||
private: | ||
std::vector<uint8_t> _data; | ||
|
||
public: | ||
Image() = default; | ||
Image(std::string filename, bool grayscale = true); | ||
|
||
template <typename T> Vector<T> flatten(size_t takeEvery = 1) const; | ||
}; | ||
|
||
Image::Image(std::string filename, bool grayscale) { | ||
int width, height, channels; | ||
uint8_t *imageData = nullptr; | ||
|
||
if (grayscale) | ||
imageData = stbi_load(filename.c_str(), &width, &height, &channels, 1); | ||
else | ||
imageData = stbi_load(filename.c_str(), &width, &height, &channels, 3); | ||
|
||
if (!imageData) { | ||
std::cerr << "Error loading image: " << filename << std::endl; | ||
return; | ||
} | ||
|
||
if (grayscale) | ||
_data.assign(imageData, imageData + (width * height)); | ||
else | ||
_data.assign(imageData, imageData + (width * height * 3)); | ||
|
||
stbi_image_free(imageData); | ||
} | ||
|
||
template <typename T> Vector<T> Image::flatten(size_t takeEvery) const { | ||
std::vector<T> converted; | ||
converted.reserve(_data.size()); | ||
for (size_t i = 0; i < _data.size(); i += takeEvery) | ||
converted.push_back(static_cast<T>(_data[i])); | ||
return Vector<T>::of(converted); | ||
} | ||
|
||
} // namespace shkyera |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.