-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDecisionTree.hpp
38 lines (35 loc) · 1.56 KB
/
DecisionTree.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#pragma once
#include "Node.hpp"
#include "error.hpp"
#include "utils.hpp"
#include <cmath>
#include <iostream>
#include <map>
#include <stack>
#include <string>
#include <utility>
#include <vector>
// Shorthand for a newline character; presumably used instead of std::endl so
// stream output is not flushed on every line.
// NOTE(review): a macro in a header leaks into every translation unit that
// includes it; consider a `constexpr char` once it is confirmed no other
// project header also defines `endLine`.
#define endLine '\n'
// Decision tree classifier over string-valued features and string labels.
// Only declarations live here; member definitions are out of line.
class DecisionTree{
private:
    // Feature / class counts. Default-initialized so they hold a defined value
    // even before fit() runs (presumably fit() assigns the real counts).
    int n_features = 0;
    int n_classes = 0;

    // Split criterion name chosen at construction (presumably "entropy" or
    // "gini", given the entropy_* / gini_* helpers) and the target column name.
    std::string criterion;
    std::string target_name;

    // Root of the learned tree; null until a tree has been built.
    // NOTE(review): this is a raw pointer and the class declares its own
    // destructor, so the implicitly generated copy constructor / copy
    // assignment copy the pointer shallowly. If ~DecisionTree() deletes the
    // tree, copying a DecisionTree leads to a double delete. Consider deleting
    // the copy operations or holding the root in std::unique_ptr<Node>.
    Node* tree = nullptr;

    // Feature (column) names, as set via names().
    std::vector<std::string> feature_names;

    // Pending feature nodes, kept so that an edge can be merged with the
    // feature node it leads to while the tree is being built.
    std::stack<Node*> features;

    // Cross-tabulates one attribute against the target labels.
    // Presumably maps attribute value -> (class label -> count); confirm in the .cpp.
    std::map<std::string, std::map<std::string, int>> class_association(std::vector<std::string> sample, std::vector<std::string> classes);

    // Distribution of the target data (presumably class label -> count).
    std::map<std::string, int> c_distribution(std::vector<std::string> classes);

    // Impurity of a whole set, given its class counts and total sample count.
    double entropy_S(std::map<std::string, int>values, int samples);
    // Impurity remaining after splitting on an attribute, given its
    // value -> class -> count table and the total sample count.
    double entropy_a(std::map<std::string, std::map<std::string, int>>values, int samples);
    // Gini counterparts of entropy_S / entropy_a.
    double gini_S(std::map<std::string, int>values, int samples);
    double gini_a(std::map<std::string, std::map<std::string, int>>values, int samples);

    // Gain of a split from the set impurity and the post-split attribute
    // impurity (presumably `set - attribute`; confirm in the .cpp).
    double gain(double set, double attribute);

public:
    // criterion: name of the impurity measure used to choose splits.
    DecisionTree(std::string criterion);
    ~DecisionTree();

    // Sets the feature names and the target name used when building and
    // labelling the tree.
    void names(std::vector<std::string>feature_names, std::string target_name);

    // Returns the stored feature names.
    std::vector<std::string> get_features();

    // Builds the tree from training features X and labels y.
    // NOTE(review): predict() below treats one inner vector as one sample;
    // confirm that fit() uses the same row-major layout.
    void fit(std::vector<std::vector<std::string>>X, std::vector<std::string>y);

    // Predicts the label for a single sample.
    std::string predict(std::vector<std::string>X);

    // Predicts one label per sample (one inner vector per sample).
    std::vector<std::string> predict(std::vector<std::vector<std::string>>X);

    // Returns the root node of the learned tree; the tree remains owned by
    // this DecisionTree (presumably), so do not delete the returned pointer.
    Node* get_tree();
};