-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathDirectLearning.hpp
181 lines (148 loc) · 5.84 KB
/
DirectLearning.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
/**
* @file DirectLearning.hpp
* @brief Declaration of the classes for direct learning algorithms.
* @author Ankit Srivastava <asrivast@gatech.edu>
*
* Copyright 2020 Georgia Institute of Technology
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DIRECTLEARNING_HPP_
#define DIRECTLEARNING_HPP_
#include "LocalLearning.hpp"
/**
 * @brief Abstract base class for causal discovery by directly learning PC sets.
 *
 * The class is abstract because getCandidatePC_impl() is declared pure
 * virtual; every concrete algorithm deriving from it must override that
 * member. Apart from the inline no-op body of forwardBackward(), no member
 * is defined inside this class body.
 * NOTE(review): the out-of-class definitions presumably live in
 * detail/DirectLearning.hpp, which this header includes at its end -- confirm.
 *
 * @tparam Data Type of the object which is used for querying the data.
 * @tparam Var Type of variable indices (expected to be an integer type).
 * @tparam Set Type of set container.
 */
template <typename Data, typename Var, typename Set>
class DirectLearning : public LocalLearning<Data, Var, Set> {
public:
  /**
   * @brief Constructs the learner from a communicator, a data object and two
   *        by-value settings (a double and a Var).
   *
   * NOTE(review): the meaning of the two by-value arguments is not visible in
   * this header -- check the LocalLearning constructor before relying on it.
   */
  DirectLearning(const mxx::comm&, const Data&, const double, const Var);

  /// Virtual destructor, so derived objects can be destroyed through a base pointer.
  virtual ~DirectLearning();

protected:
  /**
   * @brief Takes a variable and a mutable reference to a set; returns a set.
   *
   * NOTE(review): judging by the name, this prunes false positives from a
   * candidate PC set -- confirm against the definition.
   */
  Set removeFalsePC(const Var, Set&) const;

  /**
   * @brief Algorithm-specific hook that every concrete subclass must implement.
   *
   * Receives a variable and an rvalue set (ownership is handed over) and
   * returns a set.
   */
  virtual Set getCandidatePC_impl(const Var, Set&&) const = 0;

  /**
   * @brief Helper operating on a vector of (double, Var) pairs, passed by
   *        mutable reference, together with two read-only sets.
   */
  void updateMaxPValues(const Var,
                        std::vector<std::pair<double, Var>>&,
                        const Set&,
                        const Set&) const;

  /**
   * @brief Helper operating on a vector of (Var, Var, double) tuples, passed
   *        by mutable reference, together with two read-only Var-to-Set maps.
   */
  void updateMyPValues(std::vector<std::tuple<Var, Var, double>>&,
                       const std::unordered_map<Var, Set>&,
                       const std::unordered_map<Var, Set>&) const;

  /**
   * @brief Consumes a read-only vector of (Var, Var, double) tuples and
   *        returns a set of such tuples.
   *
   * @tparam Compare Type of the comparison object passed as second argument.
   *
   * The Var-to-Set map argument is taken by mutable reference.
   */
  template <typename Compare>
  std::set<std::tuple<Var, Var, double>>
  forwardPhase(const std::vector<std::tuple<Var, Var, double>>&,
               const Compare&,
               const bool,
               std::unordered_map<Var, Set>&) const;

  /**
   * @brief Takes a Var-to-Set map by mutable reference and returns a set of
   *        (Var, Var) pairs.
   */
  std::set<std::pair<Var, Var>>
  backwardPhase(std::unordered_map<Var, Set>&) const;

  /**
   * @brief Overridable hook whose default implementation does nothing.
   *
   * Subclasses that do not override it inherit this empty body. Within this
   * header, MMPC and SemiInterleavedHITON declare overrides.
   */
  virtual void forwardBackward(std::vector<std::tuple<Var, Var, double>>&&,
                               std::unordered_map<Var, Set>&,
                               std::set<std::pair<Var, Var>>&,
                               const double) const { }

private:
  // Members marked `override` below replace virtual functions inherited from
  // the base class hierarchy; getMBSuperset() is a plain private helper.
  Set getCandidatePC(const Var, Set&&) const override;

  Set getMBSuperset(const Var) const;

  Set getCandidateMB(const Var, Set&&) const override;

  BayesianNetwork<Var> getSkeleton_parallel(const bool, const double) const override;

  std::pair<bool, double> checkCollider(const Var, const Var, const Var) const override;

protected:
  // Timers declared through the TIMER_DECLARE macro, which is defined outside
  // this header. NOTE(review): the `mutable` argument presumably lets const
  // member functions update them -- confirm against the macro definition.
  TIMER_DECLARE(m_tForward, mutable);
  TIMER_DECLARE(m_tBackward, mutable);
  TIMER_DECLARE(m_tDist, mutable);
  TIMER_DECLARE(m_tSymmetry, mutable);
  TIMER_DECLARE(m_tSync, mutable);
  TIMER_DECLARE(m_tNeighbors, mutable);

private:
  // Var-to-Set map; `mutable` so that const member functions may modify it.
  // NOTE(review): by its name, it caches candidate PC sets per variable.
  mutable std::unordered_map<Var, Set> m_cachedCandidatePC;
}; // class DirectLearning
/**
 * @brief Max-Min algorithm for PC discovery, as described by
 *        Tsamardinos et al.
 *
 * Specializes DirectLearning by overriding getCandidatePC_impl() and
 * forwardBackward().
 *
 * @tparam Data Type of the object which is used for querying the data.
 * @tparam Var Type of variable indices (expected to be an integer type).
 * @tparam Set Type of set container.
 */
template <typename Data, typename Var, typename Set>
class MMPC : public DirectLearning<Data, Var, Set> {
public:
  /**
   * @brief Constructs the learner from a communicator and a data object.
   *
   * The third argument defaults to 0.05 and the fourth to the largest value
   * representable by Var.
   * NOTE(review): the meaning of those two arguments is not visible in this
   * header -- confirm against the base class constructors.
   */
  MMPC(const mxx::comm&,
       const Data&,
       const double = 0.05,
       const Var = std::numeric_limits<Var>::max());

private:
  /// Implements the pure virtual hook declared by DirectLearning.
  Set getCandidatePC_impl(const Var, Set&&) const override;

  /// Replaces the empty default body provided by DirectLearning.
  void forwardBackward(std::vector<std::tuple<Var, Var, double>>&&,
                       std::unordered_map<Var, Set>&,
                       std::set<std::pair<Var, Var>>&,
                       const double) const override;
}; // class MMPC
/**
 * @brief HITON algorithm for PC discovery, as described by Aliferis et al.
 *
 * Specializes DirectLearning by overriding getCandidatePC_impl() and
 * getSkeleton_parallel().
 *
 * @tparam Data Type of the object which is used for querying the data.
 * @tparam Var Type of variable indices (expected to be an integer type).
 * @tparam Set Type of set container.
 */
template <typename Data, typename Var, typename Set>
class HITON : public DirectLearning<Data, Var, Set> {
public:
  /**
   * @brief Constructs the learner from a communicator and a data object.
   *
   * The third argument defaults to 0.05 and the fourth to the largest value
   * representable by Var.
   * NOTE(review): the meaning of those two arguments is not visible in this
   * header -- confirm against the base class constructors.
   */
  HITON(const mxx::comm&,
        const Data&,
        const double = 0.05,
        const Var = std::numeric_limits<Var>::max());

private:
  /// Implements the pure virtual hook declared by DirectLearning.
  Set getCandidatePC_impl(const Var, Set&&) const override;

  /// Replaces the getSkeleton_parallel() override declared in DirectLearning.
  BayesianNetwork<Var> getSkeleton_parallel(const bool, const double) const override;
}; // class HITON
/**
 * @brief Semi-interleaved HITON algorithm for PC discovery, as described by
 *        Aliferis et al.
 *
 * Specializes DirectLearning by overriding getCandidatePC_impl() and
 * forwardBackward().
 *
 * @tparam Data Type of the object which is used for querying the data.
 * @tparam Var Type of variable indices (expected to be an integer type).
 * @tparam Set Type of set container.
 */
template <typename Data, typename Var, typename Set>
class SemiInterleavedHITON : public DirectLearning<Data, Var, Set> {
public:
  /**
   * @brief Constructs the learner from a communicator and a data object.
   *
   * The third argument defaults to 0.05 and the fourth to the largest value
   * representable by Var.
   * NOTE(review): the meaning of those two arguments is not visible in this
   * header -- confirm against the base class constructors.
   */
  SemiInterleavedHITON(const mxx::comm&,
                       const Data&,
                       const double = 0.05,
                       const Var = std::numeric_limits<Var>::max());

private:
  /// Implements the pure virtual hook declared by DirectLearning.
  Set getCandidatePC_impl(const Var, Set&&) const override;

  /// Replaces the empty default body provided by DirectLearning.
  void forwardBackward(std::vector<std::tuple<Var, Var, double>>&&,
                       std::unordered_map<Var, Set>&,
                       std::set<std::pair<Var, Var>>&,
                       const double) const override;
}; // class SemiInterleavedHITON
/**
 * @brief GetPC algorithm for PC discovery, as described by Pena et al.
 *
 * Specializes DirectLearning by overriding getCandidatePC_impl() and
 * getSkeleton_parallel().
 *
 * @tparam Data Type of the object which is used for querying the data.
 * @tparam Var Type of variable indices (expected to be an integer type).
 * @tparam Set Type of set container.
 */
template <typename Data, typename Var, typename Set>
class GetPC : public DirectLearning<Data, Var, Set> {
public:
  /**
   * @brief Constructs the learner from a communicator and a data object.
   *
   * The third argument defaults to 0.05 and the fourth to the largest value
   * representable by Var.
   * NOTE(review): the meaning of those two arguments is not visible in this
   * header -- confirm against the base class constructors.
   */
  GetPC(const mxx::comm&,
        const Data&,
        const double = 0.05,
        const Var = std::numeric_limits<Var>::max());

private:
  /// Implements the pure virtual hook declared by DirectLearning.
  Set getCandidatePC_impl(const Var, Set&&) const override;

  /// Replaces the getSkeleton_parallel() override declared in DirectLearning.
  BayesianNetwork<Var> getSkeleton_parallel(const bool, const double) const override;
}; // class GetPC
#include "detail/DirectLearning.hpp"
#endif // DIRECTLEARNING_HPP_