Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
split.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5#ifndef SPLIT_H
6#define SPLIT_H
7
9// Split Node Overloads
10namespace Split{
/// Applies a learned threshold to a feature, returning a mask.
/// This is the unconstrained declaration; the definitions visible below it
/// are constrained on `T::Scalar` (bool, bJet, float, fJet, int, iJet).
11 template<typename T>
12 ArrayXb threshold_mask(const T& x, const float& threshold);
14 template<typename T> requires same_as<typename T::Scalar, bool>
15 ArrayXb threshold_mask(const T& x, const float& threshold) {
16 return x;
17 }
18 template<typename T> requires same_as<typename T::Scalar, bJet>
19 ArrayXb threshold_mask(const T& x, const float& threshold) {
20 ArrayXb ret(x.size());
21 for (int i = 0; i< x.size(); ++i)
22 ret(i) = x(i).a;
23 return ret;
24 }
25 template<typename T> requires same_as<typename T::Scalar, float>
26 ArrayXb threshold_mask(const T& x, const float& threshold) {
27 return (x > threshold);
28 }
29 template<typename T> requires same_as<typename T::Scalar, fJet>
30 ArrayXb threshold_mask(const T& x, const float& threshold) {
31 ArrayXb ret(x.size());
32 std::transform(
33 x.begin(), x.end(), ret.begin(),
34 [&](const auto& e){return e > threshold;}
35 );
36 return ret;
37 }
38 template<typename T> requires same_as<typename T::Scalar, int>
39 ArrayXb threshold_mask(const T& x, const float& threshold) {
40 return (x == threshold);
41 }
42
43 template<typename T> requires same_as<typename T::Scalar, iJet>
44 ArrayXb threshold_mask(const T& x, const float& threshold) {
45 // return (x == threshold);
46 ArrayXb ret(x.size());
47 std::transform(
48 x.begin(), x.end(), ret.begin(),
49 [&](const auto& e){return e == threshold;}
50 );
51 return ret;
52 }
// Scoring helpers. Only the declarations are visible in this header listing.
// NOTE(review): presumably the Gini impurity of the labels in `classes`, with
// `uc` holding the distinct class values -- confirm against the definition.
53 float gini_impurity_index(const ArrayXf& classes, const vector<float>& uc);
// Scores a candidate split from the target values on each side of it.
// best_threshold() keeps the candidate with the smallest value returned here
// and passes the unique values of y as `unique_classes` when classifying
// (an empty vector otherwise).
54 float gain(const ArrayXf& lsplit, const ArrayXf& rsplit, bool classification,
55 vector<float> unique_classes);
56
// Candidate thresholds for feature `x`; best_threshold() scores each one.
57 template<typename T> vector<float> get_thresholds(const T& x);
// NOTE(review): neither a definition nor a caller appears in this listing;
// presumably returns the chosen feature name and its threshold -- confirm.
58 tuple<string,float> get_best_variable_and_threshold(const Dataset& d, TreeNode& tn);
59
60 template<typename T>
61 tuple<float,float> best_threshold(const T& x, const ArrayXf& y, bool classification)
62 {
63 /* for each unique value in x, calculate the reduction in the
64 * heuristic brought about by
65 * splitting between that value and the next.
66 * set threshold according to the biggest reduction.
67 *
68 * returns: the threshold and the score.
69 */
70 // get all possible split masks based on variant type
71
72 vector<float> all_thresholds = get_thresholds(x);
73
75 float best_thresh, best_score = MAX_FLT;
76 int i = 0 ;
77 vector<float> unique_classes;
78 if (classification)
79 unique_classes = unique(y);
80
81 for (const auto thresh: all_thresholds)
82 {
83
85 vector<size_t> L_idx, R_idx;
86 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
87
88 // split data
89 const ArrayXf& lhs = y(L_idx);
90 const ArrayXf& rhs = y(R_idx);
91
92 if (lhs.size() == 0 || rhs.size() == 0)
93 continue;
94
95 //TODO: templatize gain for classification/regression
96 float score = gain(lhs, rhs, classification, unique_classes);
97 /* fmt::print("threshold={}; lhs={};rhs={}; score = {}\n",thresh,lhs,rhs,score); */
98 if (score < best_score || i == 0)
99 {
100 best_score = score;
102 }
103 ++i;
104 }
105
106 best_thresh = std::isinf(best_thresh)?
107 0 : std::isnan(best_thresh)?
108 0 : best_thresh;
109
110 return make_tuple(best_thresh, best_score);
111
112 }
113
114 template<typename T>
115 void get_best_threshold_by_type(const Dataset& d, auto& results)
116 {
117 DataType DT = DataTypeEnum<T>::value;
118 /* fmt::print("get_best_threshold_by_type [T = {}]\n",DT); */
119
120 vector<string> keys;
121 float best_score = MAX_FLT;
122 string feature="";
123 float threshold=0.0;
124 int i = 0;
125
126 if (d.features_of_type.find(DT) != d.features_of_type.end())
127 keys = d.features_of_type.at(DT);
128 else
129 {
130 /* fmt::print("didn't find features of type {} in data\n",DT); */
131 return; // std::make_tuple(feature, threshold, best_score);
132 }
133
134 for (const auto& key : keys)
135 {
136 float tmp_thresh, score;
137
138 tie(tmp_thresh, score) = best_threshold(std::get<T>(d[key]), d.y, d.classification);
139 // fmt::print("best threshold for {} = {:.3f}, score = {:.3f}\n",key,tmp_thresh,score);
140 if (score < best_score | i == 0)
141 {
142 best_score = score;
143 feature = key;
145 }
146 ++i;
147 }
148 auto tmp = std::make_tuple(feature, threshold, best_score);
149 results.push_back(std::make_tuple(feature, threshold, best_score));
150 }
151
152 template<typename Ts, std::size_t... Is>
153 auto get_best_thresholds(const Dataset&d, std::index_sequence<Is...>)
154 {
155 /* fmt::print("get_best_thresholds\n"); */
156 using entry = tuple<string, float, float>;
157 auto compare = [](const entry& a, const entry& b){
158 return (std::get<2>(a) < std::get<2>(b));
159 };
160
162 /* fmt::print("get_best_thresholds::results size:{}\n",results.size()); */
164 /* fmt::print("getting best\n"); */
165 auto best = std::ranges::min_element(results, compare);
166 /* fmt::print("best: {}\n",(*best)); */
167 return (*best);
168 }
169
171 template<typename T>
173 {
174 T result(mask.size());
175
176 vector<size_t> L_idx, R_idx;
177 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
178 result(L_idx) = child_outputs.at(0);
179 result(R_idx) = child_outputs.at(1);
180 return result;
181
182 }
183} // namespace Split
184
186// Split operator overload
// Operator specialization selected when NT is NodeType::SplitOn or
// NodeType::SplitBest. It learns/applies a threshold stored on the tree node
// (tn.data.W), evaluates the children on the two data splits and returns the
// combined output.
187template<NodeType NT, typename S, bool Fit>
188struct Operator<NT, S, Fit, enable_if_t<is_in_v<NT, NodeType::SplitOn, NodeType::SplitBest>>>
189{
// Type information forwarded from the signature type S.
190 using ArgTypes = typename S::ArgTypes;
191 using FirstArg = typename S::FirstArg;
192 using RetType = typename S::RetType;
193 using W = typename S::WeightType;
194 static constexpr size_t ArgCount = S::ArgCount;
195 // get arg types from tuple by index
196 template <std::size_t N>
197 using NthType = typename S::NthType<N>;
198
199 /* static constexpr auto F = [](const auto& ...args){ Function<NT> f{}; return f(args...); }; */
200 static constexpr Function<NT> F{};
201
// Evaluates the two branch children, child i on data split d.at(i).
// Calls fit() on the child when Fit is true, predict() otherwise.
// For SplitOn the first child is skipped here; fit()/predict() evaluate it
// as the split feature. A child whose split has no samples is not evaluated,
// but the sibling pointer still advances past it.
// NOTE(review): `child_outputs` has no declaration in this listing (embedded
// line 205 is absent) -- confirm against the original header.
202 array<RetType,2> get_kids(const array<Dataset, 2>& d, TreeNode& tn, const W** weights=nullptr) const
203 {
204 using arg_type = NthType<1>;
206
207 TreeNode* sib = tn.first_child;
208 if constexpr (NT==NodeType::SplitOn)
209 sib = sib->next_sibling;
210
211 for (int i = 0; i < 2; ++i)
212 {
213 if (d.at(i).get_n_samples() > 0)
214 {
215 if constexpr (Fit)
216 child_outputs.at(i) = sib->fit<arg_type>(d.at(i));
217 else
218 child_outputs.at(i) = sib->predict<arg_type>(d.at(i), weights);
219 }
220 sib = sib->next_sibling;
221 }
222 return child_outputs;
223 };
224
// Learns the split for this node, then returns predict(d, tn).
// `threshold` aliases tn.data.W, so assigning to it stores the learned value
// on the node.
225 RetType fit(const Dataset& d, TreeNode& tn) const {
226 auto& threshold = tn.data.W;
227
228 // set feature and threshold
229 if constexpr (NT == NodeType::SplitOn)
230 {
231 // split on first child
232 FirstArg split_feature = tn.first_child->fit<FirstArg>(d);
233 // get the best splitting threshold
234 tie(threshold, ignore) = Split::best_threshold(split_feature, d.y, d.classification);
235 }
236 else
237 {
// NOTE(review): embedded line 239 is absent from this listing; as shown,
// an empty feature name is stored on the node. Presumably the
// feature/threshold search belongs between these two lines -- confirm.
238 string feature = "";
240 tn.data.set_feature(feature);
241 }
242
243 return predict(d, tn);
244 }
245
// Applies the stored split: reads the threshold (tn.data.W) and feature name
// from the node, evaluates the children through get_kids() and returns `out`.
// An empty feature name produces an all-true mask sized to the dataset.
// NOTE(review): `mask`, `data_splits` and `out` have no declaration in this
// listing, and the `else if constexpr` branch has no statement (embedded
// lines 252, 259, 262, 265 and 271 are absent) -- confirm against the
// original header before editing this function.
246 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
247 {
248 const auto& threshold = tn.data.W;
249 const auto& feature = tn.data.get_feature();
250
251 // split the data
253 if (feature == "")
254 {
255 mask.resize(d.get_n_samples());
256 mask.fill(true);
257 }
258 else if constexpr (NT==NodeType::SplitBest)
260 else {
// SplitOn: the split feature is the prediction of the first child
261 auto split_feature = tn.first_child->predict<FirstArg>(d, weights);
263 }
264
266
267 auto child_outputs = get_kids(data_splits, tn, weights);
268
269 // stitch together outputs
270 // fmt::print("stitching outputs\n");
272 /* auto out = mask.select(child_outputs.at(0), child_outputs.at(1)); */
273 /* cout << "returning " << std::get<RetType>(out) << endl; */
274
275 return out;
276 }
// Entry point: dispatches at compile time to fit() when Fit is true and to
// predict() otherwise. `weights` is only forwarded to predict().
277 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
278 if constexpr (Fit)
279 return fit(d,tn);
280 else
281 return predict(d,tn,weights);
282 }
283};
284
285
286#endif
void bind_engine(py::module &m, string name)
class tree_node_< Node > TreeNode
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
static float MAX_FLT
Definition init.h:61
Definition split.h:10
ArrayXb threshold_mask(const T &x, const float &threshold)
Applies a learned threshold to a feature, returning a mask.
Definition split.h:15
auto get_best_thresholds(const Dataset &d, std::index_sequence< Is... >)
Definition split.h:153
vector< float > get_thresholds(const T &x)
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
tuple< float, float > best_threshold(const T &x, const ArrayXf &y, bool classification)
Definition split.h:61
void get_best_threshold_by_type(const Dataset &d, auto &results)
Definition split.h:115
float gini_impurity_index(const ArrayXf &classes, const vector< float > &uc)
T stitch(array< T, 2 > &child_outputs, const ArrayXb &mask)
Stitches together outputs from left or right child based on threshold.
Definition split.h:172
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
NodeType NT
Definition nodetype.h:254
static constexpr bool is_in_v
Definition nodetype.h:252
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:277
array< RetType, 2 > get_kids(const array< Dataset, 2 > &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:202
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition split.h:246