-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist.c
212 lines (176 loc) · 5.05 KB
/
mnist.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <math.h>
#include "mnist.h"
// Release the raw pixel rows returned by read_mnist_images together with the
// row-pointer table. Only the first `count` rows are released because the
// reader does not report how many rows it actually allocated.
static void releaseRawImages(uint8_t **images, int count)
{
    if (images == NULL)
    {
        return;
    }
    for (int i = 0; i < count; i++)
    {
        free(images[i]);
    }
    free(images);
}

// Release the first `built` fully allocated items and the item array itself.
static void releasePartialItems(DataItem *items, int built)
{
    if (items == NULL)
    {
        return;
    }
    for (int i = 0; i < built; i++)
    {
        free(items[i].image);
        free(items[i].label);
    }
    free(items);
}

// Build an array of `setSize` DataItem records from a pair of MNIST files.
//
// Each item receives:
//   image - `imageSize` doubles: the raw pixel bytes divided by 255, so every
//           value lies in [0, 1].
//   label - `labelSize` doubles, all zero except that a non-zero digit d
//           stores the value d at index d - 1. Digit 0 (and any digit larger
//           than `labelSize`) has no slot and is left as the all-zero vector.
//
// Returns NULL when a size argument is not positive, when either file cannot
// be loaded, or when an allocation fails. On success the caller owns the
// returned array and every items[i].image / items[i].label buffer.
//
// NOTE(review): `setSize` and `imageSize` cannot be validated against the
// counts stored in the file headers because the readers do not expose them;
// the caller must not pass values larger than what the files contain. Raw
// rows beyond `setSize` cannot be released here for the same reason.
DataItem *createDataItem(const char *imageFileame,
                         const char *labelFileame,
                         int setSize,
                         int imageSize,
                         int labelSize)
{
    uint8_t **images = NULL;
    uint8_t *labels = NULL;
    DataItem *items = NULL;
    int built = 0; // number of items whose two buffers are both allocated

    if (setSize <= 0 || imageSize <= 0 || labelSize <= 0)
    {
        return NULL;
    }

    // Load images and labels in uint8
    images = read_mnist_images(imageFileame);
    labels = read_mnist_labels(labelFileame);
    if (images == NULL || labels == NULL)
    {
        goto fail;
    }

    items = malloc((size_t)setSize * sizeof *items);
    if (items == NULL)
    {
        goto fail;
    }

    for (int i = 0; i < setSize; i++)
    {
        double *pixels = malloc((size_t)imageSize * sizeof *pixels);
        double *encoded = malloc((size_t)labelSize * sizeof *encoded);
        if (pixels == NULL || encoded == NULL)
        {
            free(pixels);
            free(encoded);
            goto fail;
        }
        items[i].image = pixels;
        items[i].label = encoded;
        built = i + 1;

        // Convert the pixel bytes to doubles in [0, 1]
        for (int j = 0; j < imageSize; j++)
        {
            pixels[j] = ((double)images[i][j]) / 255.0;
        }

        // Start from an all-zero label vector
        for (int j = 0; j < labelSize; j++)
        {
            encoded[j] = 0;
        }

        // Store the digit value at slot digit - 1; skip digits without a slot
        // instead of writing outside the buffer.
        const int digit = labels[i];
        if (digit >= 1 && digit <= labelSize)
        {
            encoded[digit - 1] = (double)digit;
        }
    }

    releaseRawImages(images, setSize);
    free(labels);
    return items;

fail:
    releasePartialItems(items, built);
    releaseRawImages(images, setSize);
    free(labels);
    return NULL;
}
// Release an array of `setSize` DataItem records: the `image` and `label`
// buffer of every item, then the array itself. Passing NULL is a no-op.
//
// NOTE(review): this assumes the items own their `image` and `label` buffers,
// as they do when built by createDataItem in this file — confirm no caller
// frees those buffers itself before calling this function.
void freeDataItem(DataItem *items, int setSize)
{
    if (items == NULL)
    {
        return;
    }
    for (int i = 0; i < setSize; i++)
    {
        free(items[i].image);
        free(items[i].label);
    }
    free(items);
}
// Read 4 bytes from `file` and assemble them, most significant byte first,
// into a 32-bit unsigned integer.
//
// Returns 0 when fewer than 4 bytes could be read, so a truncated file never
// yields a value built from uninitialised memory.
uint32_t read_uint32(FILE *file)
{
    uint8_t buffer[4] = {0};

    if (fread(buffer, 1, sizeof buffer, file) != sizeof buffer)
    {
        return 0;
    }

    // Widen each byte before shifting: a uint8_t promotes to signed int, and
    // shifting a value >= 0x80 left by 24 bits would overflow it.
    return ((uint32_t)buffer[0] << 24) |
           ((uint32_t)buffer[1] << 16) |
           ((uint32_t)buffer[2] << 8) |
           (uint32_t)buffer[3];
}
// Report whether `file` is a non-NULL stream handle. When it is NULL, a
// "could not open" message naming `filename` is written to stderr.
bool isValidFile(FILE *file, const char *filename)
{
    if (file != NULL)
    {
        return true;
    }

    fprintf(stderr, "Could not open file %s\n", filename);
    return false;
}
// Consume the next 32-bit value from `file` and compare it with
// `validMagicNumber`.
//
// On a mismatch an error naming `filename` is written to stderr and `file`
// is closed before returning false, so the caller must not use or close the
// stream again. On a match the stream is left open and true is returned.
bool isValidMagicNumber(FILE *file, const char *filename, uint32_t validMagicNumber)
{
    const uint32_t found = read_uint32(file);
    const bool matches = (found == validMagicNumber);

    if (!matches)
    {
        fprintf(stderr, "Invalid magic number in %s\n", filename);
        fclose(file);
    }
    return matches;
}
// Read an MNIST image file into memory.
//
// The header is a magic number (2051) followed by the image count, the row
// count and the column count, each read with read_uint32. Every image is then
// read as rows * cols raw bytes into its own heap buffer.
//
// Returns an array of `count` pointers, one per image, or NULL when the file
// cannot be opened, the magic number is wrong, the header holds a zero or
// overflowing dimension, an allocation fails, or the pixel data is truncated.
// The caller owns the returned table and every row in it.
uint8_t **read_mnist_images(const char *filename)
{
    enum { MNIST_IMAGE_MAGIC = 2051 };

    // Open file and validate it; there is nothing to close when fopen failed
    FILE *file = fopen(filename, "rb");
    if (isValidFile(file, filename) == false)
    {
        return NULL;
    }

    // isValidMagicNumber closes the file itself on a mismatch
    if (isValidMagicNumber(file, filename, MNIST_IMAGE_MAGIC) == false)
    {
        return NULL;
    }

    // Read the number of images, rows, and columns
    const uint32_t number_of_images = read_uint32(file);
    const uint32_t rows = read_uint32(file);
    const uint32_t cols = read_uint32(file);

    // Reject empty dimensions and sizes that would overflow size_t
    if (number_of_images == 0 || rows == 0 || cols == 0 ||
        rows > SIZE_MAX / cols ||
        number_of_images > SIZE_MAX / sizeof(uint8_t *))
    {
        fprintf(stderr, "Invalid image header in %s\n", filename);
        fclose(file);
        return NULL;
    }
    const size_t image_size = (size_t)rows * cols;

    // Allocate the row-pointer table
    uint8_t **images = malloc((size_t)number_of_images * sizeof *images);
    if (images == NULL)
    {
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(file);
        return NULL;
    }

    for (uint32_t i = 0; i < number_of_images; i++)
    {
        images[i] = malloc(image_size);
        if (images[i] == NULL ||
            fread(images[i], 1, image_size, file) != image_size)
        {
            fprintf(stderr, "Could not read image %lu from %s\n",
                    (unsigned long)i, filename);
            // Undo everything allocated so far, including the current row
            for (uint32_t k = 0; k <= i; k++)
            {
                free(images[k]);
            }
            free(images);
            fclose(file);
            return NULL;
        }
    }

    fclose(file);
    return images;
}
// Read an MNIST label file into memory.
//
// The header is a magic number (2049) followed by the label count, each read
// with read_uint32; the labels follow as one byte each.
//
// Returns a heap buffer holding one byte per label, or NULL when the file
// cannot be opened, the magic number is wrong, the count is zero, the
// allocation fails, or the label data is truncated. The caller owns the
// returned buffer.
uint8_t *read_mnist_labels(const char *filename)
{
    enum { MNIST_LABEL_MAGIC = 2049 };

    // Open file and validate it; there is nothing to close when fopen failed
    FILE *file = fopen(filename, "rb");
    if (isValidFile(file, filename) == false)
    {
        return NULL;
    }

    // isValidMagicNumber closes the file itself on a mismatch
    if (isValidMagicNumber(file, filename, MNIST_LABEL_MAGIC) == false)
    {
        return NULL;
    }

    // Read the number of labels
    const uint32_t number_of_labels = read_uint32(file);
    if (number_of_labels == 0)
    {
        fprintf(stderr, "Invalid label header in %s\n", filename);
        fclose(file);
        return NULL;
    }

    // Allocate memory for the labels and read them in one go
    uint8_t *labels = malloc((size_t)number_of_labels);
    if (labels == NULL)
    {
        fprintf(stderr, "Out of memory reading %s\n", filename);
        fclose(file);
        return NULL;
    }
    if (fread(labels, 1, number_of_labels, file) != number_of_labels)
    {
        fprintf(stderr, "Could not read labels from %s\n", filename);
        free(labels);
        fclose(file);
        return NULL;
    }

    fclose(file);
    return labels;
}
// Print one raw image and its label to stdout.
//
// The image at `imageIndex` is treated as a square whose side is the integer
// square root of `imageSize`. The label line comes first, then one text line
// per pixel row with every byte printed right-aligned in a 3-character field.
void printMnistImages(uint8_t **imagesArray, uint8_t *labelsArray, int imageIndex, int imageSize)
{
    const int side = (int)sqrt((double)imageSize);

    printf("Label of image is %d.\n", labelsArray[imageIndex]);

    for (int row = 0; row < side; row++)
    {
        // Start of the current pixel row inside the flat image buffer
        const uint8_t *cells = &imagesArray[imageIndex][row * side];

        for (int col = 0; col < side; col++)
        {
            printf("%3d ", cells[col]);
        }
        putchar('\n');
    }
}
// Print one converted DataItem (label and image) to stdout.
//
// The label is taken from the first non-zero entry among the first 10 slots
// of the item's label vector. createDataItem in this file stores digit d at
// slot d - 1, so digit 0 has no non-zero slot; an all-zero vector is
// therefore printed as label 0 rather than reading an unset variable.
//
// The image is treated as a square whose side is the integer square root of
// `imageSize`; each pixel is printed rounded to zero decimals.
void printMnistItem(DataItem *items, int imageIndex, int imageSize)
{
    enum { LABEL_SLOTS = 10 };

    const int side = (int)sqrt((double)imageSize);
    double label = 0.0; // value reported when no slot is non-zero

    // Find nonzero label
    for (int i = 0; i < LABEL_SLOTS; i++)
    {
        if (items[imageIndex].label[i] != 0)
        {
            label = items[imageIndex].label[i];
            break;
        }
    }

    // Print label
    printf("Label of image is %.0f.\n", label);

    // Print array with alignment
    for (int i = 0; i < side; i++)
    {
        for (int j = 0; j < side; j++)
        {
            printf("%1.0f ", items[imageIndex].image[i * side + j]);
        }
        printf("\n");
    }
}