Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
split.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5#ifndef SPLIT_H
6#define SPLIT_H
7
9// Split Node Overloads
10namespace Split{
// Applies a learned threshold to a feature, returning a boolean mask.
// This is the unconstrained primary declaration; the overloads that follow
// are constrained on T::Scalar and supply the per-type definitions.
// No definition of this primary template is visible in this header.
11 template<typename T>
12 ArrayXb threshold_mask(const T& x, const float& threshold);
14 template<typename T> requires same_as<typename T::Scalar, bool>
15 ArrayXb threshold_mask(const T& x, const float& threshold) {
16 return x;
17 }
18 template<typename T> requires same_as<typename T::Scalar, bJet>
19 ArrayXb threshold_mask(const T& x, const float& threshold) {
20 ArrayXb ret(x.size());
21 for (int i = 0; i< x.size(); ++i)
22 ret(i) = x(i).a;
23 return ret;
24 }
25 template<typename T> requires same_as<typename T::Scalar, float>
26 ArrayXb threshold_mask(const T& x, const float& threshold) {
27 return (x >= threshold);
28 }
29 template<typename T> requires same_as<typename T::Scalar, fJet>
30 ArrayXb threshold_mask(const T& x, const float& threshold) {
31 ArrayXb ret(x.size());
32 std::transform(
33 x.begin(), x.end(), ret.begin(),
34 [&](const auto& e){return e >= threshold;}
35 );
36 return ret;
37 }
38 template<typename T> requires same_as<typename T::Scalar, int>
39 ArrayXb threshold_mask(const T& x, const float& threshold) {
40 return (x == threshold);
41 }
42
43 template<typename T> requires same_as<typename T::Scalar, iJet>
44 ArrayXb threshold_mask(const T& x, const float& threshold) {
45 // return (x == threshold);
46 ArrayXb ret(x.size());
47 std::transform(
48 x.begin(), x.end(), ret.begin(),
49 [&](const auto& e){return e == threshold;}
50 );
51 return ret;
52 }
// Scoring helpers and search entry points; declared here, defined elsewhere.
// NOTE(review): judging by the name, gini_impurity_index presumably computes
// the Gini impurity of `classes` over the class labels `uc` -- confirm against
// the definition.
53 float gini_impurity_index(const ArrayXf& classes, const vector<float>& uc);
// Score of splitting the targets into `lsplit` / `rsplit`. best_threshold()
// keeps the candidate with the SMALLEST value returned here.
// `unique_classes` is taken by value.
54 float gain(const ArrayXf& lsplit, const ArrayXf& rsplit, bool classification,
55 vector<float> unique_classes);
56
// Candidate thresholds for a feature; best_threshold() tries each of them.
57 template<typename T> vector<float> get_thresholds(const T& x);
// Returns a (feature name, threshold) pair; used by the SplitBest operator
// when it is allowed to re-select its split feature.
58 tuple<string,float> get_best_variable_and_threshold(const Dataset& d, TreeNode& tn);
59
60 template<typename T>
61 tuple<float,float> best_threshold(const T& x, const ArrayXf& y, bool classification)
62 {
63 /* for each unique value in x, calculate the reduction in the
64 * heuristic brought about by
65 * splitting between that value and the next.
66 * set threshold according to the biggest reduction.
67 *
68 * returns: the threshold and the score.
69 */
70 // get all possible split masks based on variant type
71
72 vector<float> all_thresholds = get_thresholds(x);
73
75 float best_thresh, best_score = MAX_FLT;
76 int i = 0 ;
77
78 vector<float> unique_classes;
79 if (classification)
80 unique_classes = unique(y);
81
82 // all_thresholds contains the unique values to be used as thresholds
83 // with a >= operator
84 for (const auto thresh: all_thresholds)
85 {
86
87 ArrayXb mask = threshold_mask(x, thresh);
88 vector<size_t> L_idx, R_idx;
89 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
90
91 // split data
92 const ArrayXf& lhs = y(L_idx);
93 const ArrayXf& rhs = y(R_idx);
94
95 if (lhs.size() == 0 || rhs.size() == 0)
96 continue;
97
98 //TODO: templatize gain for classification/regression
99 float score = gain(lhs, rhs, classification, unique_classes);
100 /* fmt::print("threshold={}; lhs={};rhs={}; score = {}\n",thresh,lhs,rhs,score); */
101 if (score < best_score || i == 0)
102 {
103 best_score = score;
104 best_thresh = thresh;
105 }
106 ++i;
107 }
108
109 best_thresh = std::isinf(best_thresh)?
110 0 : std::isnan(best_thresh)?
111 0 : best_thresh;
112
113 return make_tuple(best_thresh, best_score);
114
115 }
116
117 template<typename T>
118 void get_best_threshold_by_type(const Dataset& d, auto& results)
119 {
120 DataType DT = DataTypeEnum<T>::value;
121 /* fmt::print("get_best_threshold_by_type [T = {}]\n",DT); */
122
123 vector<string> keys;
124 float best_score = MAX_FLT;
125 string feature="";
126 float threshold=0.0;
127 int i = 0;
128
129 if (d.features_of_type.find(DT) != d.features_of_type.end())
130 keys = d.features_of_type.at(DT);
131 else
132 {
133 /* fmt::print("didn't find features of type {} in data\n",DT); */
134 return; // std::make_tuple(feature, threshold, best_score);
135 }
136
137 for (const auto& key : keys)
138 {
139 float tmp_thresh, score;
140
141 tie(tmp_thresh, score) = best_threshold(std::get<T>(d[key]), d.y, d.classification);
142 // fmt::print("best threshold for {} = {:.3f}, score = {:.3f}\n",key,tmp_thresh,score);
143 if (score < best_score | i == 0)
144 {
145 best_score = score;
146 feature = key;
147 threshold = tmp_thresh;
148 }
149 ++i;
150 }
151 auto tmp = std::make_tuple(feature, threshold, best_score);
152 results.push_back(std::make_tuple(feature, threshold, best_score));
153 }
154
155 template<typename Ts, std::size_t... Is>
156 auto get_best_thresholds(const Dataset&d, std::index_sequence<Is...>)
157 {
158 /* fmt::print("get_best_thresholds\n"); */
159 using entry = tuple<string, float, float>;
160 auto compare = [](const entry& a, const entry& b){
161 return (std::get<2>(a) < std::get<2>(b));
162 };
163
164 vector<entry> results;
165 /* fmt::print("get_best_thresholds::results size:{}\n",results.size()); */
167 /* fmt::print("getting best\n"); */
168 auto best = std::ranges::min_element(results, compare);
169 /* fmt::print("best: {}\n",(*best)); */
170 return (*best);
171 }
172
174 template<typename T>
175 T stitch(array<T,2>& child_outputs, const ArrayXb& mask)
176 {
177 T result(mask.size());
178
179 vector<size_t> L_idx, R_idx;
180 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
181 result(L_idx) = child_outputs.at(0);
182 result(R_idx) = child_outputs.at(1);
183 return result;
184
185 }
186} // namespace Split
187
189// Split operator overload
190template<NodeType NT, typename S, bool Fit>
191struct Operator<NT, S, Fit, enable_if_t<is_in_v<NT, NodeType::SplitOn, NodeType::SplitBest>>>
192{
193 using ArgTypes = typename S::ArgTypes;
194 using FirstArg = typename S::FirstArg;
195 using RetType = typename S::RetType;
196 using W = typename S::WeightType;
197 static constexpr size_t ArgCount = S::ArgCount;
198 // get arg types from tuple by index
199 template <std::size_t N>
200 using NthType = typename S::template NthType<N>;
201
202 /* static constexpr auto F = [](const auto& ...args){ Function<NT> f{}; return f(args...); }; */
203 static constexpr Function<NT> F{};
204
205 array<RetType,2> get_kids(const array<Dataset, 2>& d, TreeNode& tn, const W** weights=nullptr) const
206 {
207 using arg_type = NthType<1>;
208 array<arg_type,2> child_outputs;
209
210 TreeNode* sib = tn.first_child;
211 if constexpr (NT==NodeType::SplitOn)
212 sib = sib->next_sibling;
213
214 for (int i = 0; i < 2; ++i)
215 {
216 if (d.at(i).get_n_samples() > 0)
217 {
218 if constexpr (Fit)
219 child_outputs.at(i) = sib->fit<arg_type>(d.at(i));
220 else
221 child_outputs.at(i) = sib->predict<arg_type>(d.at(i), weights);
222 }
223 sib = sib->next_sibling;
224 }
225 return child_outputs;
226 };
227
228 RetType fit(const Dataset& d, TreeNode& tn) const {
229 auto& threshold = tn.data.W;
230
231 // set feature and threshold
232 if constexpr (NT == NodeType::SplitOn)
233 {
234 // split on first child
235 FirstArg split_feature = tn.first_child->fit<FirstArg>(d);
236 // get the best splitting threshold
237 tie(threshold, ignore) = Split::best_threshold(split_feature, d.y, d.classification);
238 }
239 else // splitbest
240 {
241 // avoid updating the split feature
242 if (tn.data.get_keep_split_feature() && tn.data.get_feature()!="")
243 {
244 // TODO: I think the if-else clausules could be simplified
245
246 auto values = d[tn.data.get_feature()];
247
248 // Threshold will be optimized regardless.
249 if (std::holds_alternative<ArrayXf>(values))
250 tie(threshold, ignore) = Split::best_threshold(std::get<ArrayXf>(values), d.y, d.classification);
251 else if (std::holds_alternative<ArrayXi>(values))
252 tie(threshold, ignore) = Split::best_threshold(std::get<ArrayXi>(values), d.y, d.classification);
253 else if (std::holds_alternative<ArrayXb>(values))
254 tie(threshold, ignore) = Split::best_threshold(std::get<ArrayXb>(values), d.y, d.classification);
255 }
256 else // keep_split_feature == false
257 {
258 string feature = "";
259 tie(feature, threshold) = Split::get_best_variable_and_threshold(d, tn);
260 tn.data.set_feature(feature);
261 tn.data.set_feature_type(d.get_feature_type(feature));
262 }
263 }
264
265 return predict(d, tn);
266 }
267
268 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
269 {
270 const auto& threshold = tn.data.W;
271 const auto& feature = tn.data.get_feature();
272
273 // split the data
274 ArrayXb mask;
275 if (feature == "")
276 {
277 mask.resize(d.get_n_samples());
278 mask.fill(true);
279 }
280 else if constexpr (NT==NodeType::SplitBest)
281 mask = Split::threshold_mask(d[feature], threshold);
282 else {
283 auto split_feature = tn.first_child->predict<FirstArg>(d, weights);
284 mask = Split::threshold_mask(split_feature, threshold);
285 }
286
287 array<Dataset, 2> data_splits = d.split(mask);
288
289 auto child_outputs = get_kids(data_splits, tn, weights);
290
291 // stitch together outputs
292 // fmt::print("stitching outputs\n");
293 auto out = Split::stitch(child_outputs, mask);
294 /* auto out = mask.select(child_outputs.at(0), child_outputs.at(1)); */
295 /* cout << "returning " << std::get<RetType>(out) << endl; */
296
297 return out;
298 }
299 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
300 if constexpr (Fit)
301 return fit(d,tn);
302 else
303 return predict(d,tn,weights);
304 }
305};
306
307
308#endif
holds variable type data.
Definition data.h:51
bool classification
whether this is a classification problem
Definition data.h:83
DataType get_feature_type(const string &name) const
Definition data.h:235
int get_n_samples() const
Definition data.h:222
std::unordered_map< DataType, vector< string > > features_of_type
map from data types to features having that type.
Definition data.h:71
ArrayXf y
length N array, the target label
Definition data.h:80
std::array< Dataset, 2 > split(const ArrayXb &mask) const
Definition data.cpp:186
STL class.
STL class.
class tree_node_< Node > TreeNode
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
static float MAX_FLT
Definition init.h:61
Definition split.h:10
ArrayXb threshold_mask(const T &x, const float &threshold)
Applies a learned threshold to a feature, returning a mask.
Definition split.h:15
auto get_best_thresholds(const Dataset &d, std::index_sequence< Is... >)
Definition split.h:156
vector< float > get_thresholds(const T &x)
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
tuple< float, float > best_threshold(const T &x, const ArrayXf &y, bool classification)
Definition split.h:61
void get_best_threshold_by_type(const Dataset &d, auto &results)
Definition split.h:118
float gini_impurity_index(const ArrayXf &classes, const vector< float > &uc)
T stitch(array< T, 2 > &child_outputs, const ArrayXb &mask)
Stitches together outputs from left or right child based on threshold.
Definition split.h:175
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
DataType
data types.
Definition types.h:143
static constexpr bool is_in_v
Definition nodetype.h:269
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:299
array< RetType, 2 > get_kids(const array< Dataset, 2 > &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:205
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:268