-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMLP.h
36 lines (32 loc) · 1.03 KB
/
MLP.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#ifndef MLP_H
#define MLP_H
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include "mnist.h"
/**
 * State of a multilayer perceptron with one hidden layer
 * (input -> hidden -> output).
 *
 * The struct carries the tag `struct MLP` in addition to the `MLP` typedef,
 * so other headers can forward-declare it (`struct MLP;`) and pass
 * `struct MLP *` without including this header. Code that spells the type
 * `MLP` is unaffected, and the member order (hence layout) is unchanged.
 *
 * NOTE(review): the pointer members presumably reference buffers allocated
 * by initialize_network() and released by free_network() — TODO confirm
 * against the implementation.
 */
typedef struct MLP
{
    /* Layer widths (number of units in each layer). */
    int input_size;
    int hidden_size;
    int output_size;

    /* Layer activations from the most recent forward pass; presumably
     * hidden[hidden_size] and output[output_size] — TODO confirm. */
    double *hidden;
    double *output;

    /* Per-unit biases; presumably hidden_bias[hidden_size] and
     * output_bias[output_size] — TODO confirm. */
    double *hidden_bias;
    double *output_bias;

    /* Weight matrices stored as arrays of row pointers. Assumed shapes
     * (TODO confirm against the implementation):
     *   input_hidden_weights  : [input_size][hidden_size]
     *   output_hidden_weights : [hidden_size][output_size]
     * NOTE(review): the name output_hidden_weights reads reversed relative
     * to input_hidden_weights; it is kept as-is for source compatibility. */
    double **input_hidden_weights;
    double **output_hidden_weights;

    /* Per-unit error terms, presumably written by backpropagation();
     * assumed hidden_delta[hidden_size], output_delta[output_size]. */
    double *hidden_delta;
    double *output_delta;
} MLP;
/* ========================= Public API ========================= */

/* ---- Lifecycle ---------------------------------------------------------- */

/**
 * Build a network with the given layer widths.
 *
 * @param input_size   number of input units
 * @param hidden_size  number of hidden units
 * @param output_size  number of output units
 * @return pointer to the new network.
 *         NOTE(review): whether NULL is returned on allocation failure is not
 *         visible from this header — check the implementation.
 */
MLP *initialize_network(int input_size, int hidden_size, int output_size);

/**
 * Dispose of a network.
 * NOTE(review): presumably frees every buffer created by
 * initialize_network() as well as the struct itself — TODO confirm.
 *
 * @param network  network to release
 */
void free_network(MLP *network);

/* ---- Per-sample steps --------------------------------------------------- */

/**
 * Run one input vector through the network.
 * Assumes `input` holds network->input_size values and that results are
 * written into the network's activation buffers — TODO confirm.
 *
 * @param network  network to evaluate
 * @param input    input vector
 */
void forward_propagation(MLP *network, double *input);

/**
 * Propagate the error for one (input, target) pair back through the network.
 *
 * @param network        network being trained
 * @param input          input vector used for the preceding forward pass
 * @param target         expected output vector (assumed output_size long)
 * @param learning_rate  step size for the update
 */
void backpropagation(MLP *network,
                     double *input,
                     double *target,
                     double learning_rate);

/**
 * Apply a weight/bias update scaled by `learning_rate`.
 * NOTE(review): presumably consumes the delta buffers filled by
 * backpropagation() — TODO confirm.
 *
 * @param network        network being trained
 * @param input          input vector for the current sample
 * @param learning_rate  step size for the update
 */
void update_weights(MLP *network, double *input, double learning_rate);

/**
 * Compare `size` outputs against `size` targets and return a scalar error.
 * The exact metric is not visible from this header.
 *
 * @param output  produced values
 * @param target  expected values
 * @param size    number of elements in both arrays
 * @return scalar error value
 */
double calculate_error(double *output, double *target, int size);

/* ---- Dataset-level operations ------------------------------------------- */

/**
 * Train the network on a dataset for a number of passes.
 *
 * @param network        network to train
 * @param train_data     array of `train_size` samples (DataItem from mnist.h)
 * @param train_size     number of samples
 * @param epochs         number of passes over the data
 * @param learning_rate  step size for updates
 */
void train(MLP *network,
           DataItem *train_data,
           int train_size,
           int epochs,
           double learning_rate);

/**
 * Score the network on a dataset.
 *
 * @param network    network to evaluate
 * @param test_data  array of `test_size` samples (DataItem from mnist.h)
 * @param test_size  number of samples
 * @return a score for the dataset; presumably accuracy — TODO confirm.
 */
double evaluate(MLP *network, DataItem *test_data, int test_size);
#endif // MLP_H