Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
operator.h
Go to the documentation of this file.
1#ifndef OPERATOR_H
2#define OPERATOR_H
3
4#include "../init.h"
5#include "tree_node.h"
6#include "../util/utils.h"
8namespace Brush{
10namespace util{
11 ////////////////////////////////////////////////////////////////////////////////
12 /// @brief get weight
13 /// @tparam T return type
14 /// @tparam Scalar scalar type of return type
15 /// @param tn tree node
16 /// @param weights option pointer to a weight array, used in place of node weight
17 /// @return
18 template<typename T, typename Scalar, typename W>
20 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
21 {
22 Scalar w;
23 // Prediction case: weight is stored in the node data.
24 if (weights == nullptr)
25 {
26 if constexpr (std::is_floating_point_v<decltype(tn.data.W)> || std::is_integral_v<decltype(tn.data.W)>) {
27 if (std::isnan(tn.data.W) || tn.data.W == std::numeric_limits<decltype(tn.data.W)>::lowest()) {
28 HANDLE_ERROR_THROW("TreeNode weight (W) is not set or is invalid for node: " + tn.data.name);
29 }
30 }
32 try //TODO: remove this try catch after debugging it
33 {
34 w = Scalar(tn.data.W);
35 }
36 catch (const std::exception& e) {
37 std::string err_msg = "Null pointer dereference: *weights is nullptr. "
38 "TreeNode ret_type: " + std::to_string(static_cast<int>(tn.data.ret_type)) +
39 ", name: " + tn.data.name;
40 std::cerr << "[EXCEPTION] get_weight: caught std::exception: " << e.what() << err_msg << std::endl;
41 throw; // Re-throw to allow crash
42 }
43 }
44 else
45 {
46 try //TODO: remove this try catch after debugging it
47 {
48 if (*weights == nullptr) {
49 std::string err_msg = "Null pointer dereference: *weights is nullptr. "
50 "TreeNode ret_type: " + std::to_string(static_cast<int>(tn.data.ret_type)) +
51 ", name: " + tn.data.name;
52 HANDLE_ERROR_THROW("Null pointer dereference: *weights is nullptr. " + err_msg);
53 }
54
55 // NLS case 1: floating point weight is stored in weights
56 if constexpr (is_same_v<Scalar, W>)
57 w = **weights;
59 // NLS case 2: a Jet/Dual weight is stored in weights, but this constant is a
60 // integer type. We need to do some casting
61 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
62 using WScalar = typename Scalar::Scalar;
63 WScalar tmp = WScalar((**weights).a);
64 w = Scalar(tmp);
65 }
66 // NLS case 3: a Jet/Dual weight is stored in weights, matching Scalar type
67 else
68 w = Scalar(**weights);
69
70 *weights = *weights+1;
71 }
72 catch (const std::exception& e) {
73 std::string err_msg = "Null pointer dereference: *weights is nullptr. "
74 "TreeNode ret_type: " + std::to_string(static_cast<int>(tn.data.ret_type)) +
75 ", name: " + tn.data.name;
76 std::cerr << "[EXCEPTION] get_weight: caught std::exception: " << e.what() << err_msg << std::endl;
77 throw; // Re-throw to allow crash
78 }
79 }
80 return w;
81 };
82 template<typename T, typename Scalar, typename W>
84 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
85 {
86 // we cannot weight a boolean feature. Nevertheless, we need to provide
87 // an implementation for get_weight behavior, so the metaprogramming
88 // doesn't fail to get a matching signature.
89
90 if (tn.data.get_is_weighted())
91 // Node's init() function avoids the creation of weighted nodes,
92 // and the setter for `is_weighted` prevent enabling weight on
93 // boolean values.
94 HANDLE_ERROR_THROW(fmt::format("boolean terminal is weighted, but "
95 "it should not\n"));
96
97 // std::cout << "Returning weight: Scalar(true) for a boolean node tn " << tn.data.name << std::endl;
98 return Scalar(true);
99 };
100}
102// Operator class
103
109template<NodeType NT, typename S, bool Fit, typename E=void>
110struct Operator
111{
119 * array and the operator is applied to that array
120 */
121 using ArgTypes = conditional_t<
122 ((UnaryOp<NT> || NaryOp<NT>) && S::ArgCount > 1),
123 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
124 typename S::ArgTypes>;
125
127 using RetType = typename S::RetType;
128
130 static constexpr size_t ArgCount = S::ArgCount;
131
133 template <std::size_t N>
134 using NthType = typename S::template NthType<N>;
135
137 using W = typename S::WeightType;
138
140 static constexpr auto F = [](const auto& ...args) {
141 Function<NT> f;
142 return f(args...);
143 };
144
145 Operator() = default;
147 // Utilities to grab child outputs.
148
150 template<typename T=ArgTypes> requires(is_std_array_v<T> || is_eigen_array_v<T>)
151 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
152 {
153 T child_outputs;
154 using arg_type = std::conditional_t<is_std_array_v<T>,
155 typename T::value_type, Array<typename S::FirstArg::Scalar, -1, 1>>;
156 if constexpr (is_eigen_array_v<T>)
157 child_outputs.resize(d.get_n_samples(), Eigen::NoChange);
158
159 TreeNode* sib = tn.first_child;
160 for (int i = 0; i < ArgCount; ++i)
161 {
162 if (sib == nullptr)
163 HANDLE_ERROR_THROW("bad sibling ptr in get kids");
164 if constexpr (Fit){
165 if constexpr(is_std_array_v<T>)
166 child_outputs.at(i) = sib->fit<arg_type>(d);
167 else
168 child_outputs.col(i) = sib->fit<arg_type>(d);
169 }
170 else{
171 if constexpr(is_std_array_v<T>)
172 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
173 else
174 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
175 }
176 sib = sib->next_sibling;
177 }
178 return child_outputs;
179 };
180
182 template<int I>
183 NthType<I> get_kid(const Dataset& d, TreeNode& tn, const W** weights ) const
184 {
185 TreeNode* sib = tn.first_child;
186 for (int i = 0; i < I; ++i)
187 {
188 sib= sib->next_sibling;
189 }
190 if constexpr(Fit)
191 return sib->fit<NthType<I>>(d);
192 else
193 return sib->predict<NthType<I>>(d,weights);
194 };
205 template<typename T, size_t ...Is> requires(is_tuple_v<T>)
206 T get_kids_seq(const Dataset& d, TreeNode& tn, const W** weights, std::index_sequence<Is...>) const
207 {
208 return std::make_tuple(get_kid<Is>(d,tn,weights)...);
209 };
210
217 template<typename T=ArgTypes> requires(is_tuple_v<T>)
218 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
219 {
220 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
221 };
222
224
228
229 template<typename T=ArgTypes> requires ( is_std_array_v<T> || is_tuple_v<T>)
230 RetType apply(const T& inputs) const
231 {
232 return std::apply(F, inputs);
233 }
234
239 template<typename T=ArgTypes> requires ( is_eigen_array_v<T> && !is_std_array_v<T>)
240 RetType apply(const T& inputs) const
241 {
242 return F(inputs);
243 }
244
252 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
253 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
254 {
255 auto inputs = get_kids(d, tn, weights);
256 if constexpr (is_one_of_v<Scalar,float,fJet>)
257 {
258 if (tn.data.get_is_weighted())
259 {
260 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
261 return this->apply(inputs)*w;
262 }
263 }
264 return this->apply(inputs);
265 };
266
267 // overloaded version for offset sum
268 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
270 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
271 {
272 auto inputs = get_kids(d, tn, weights);
273 if constexpr (is_one_of_v<Scalar,float,fJet>)
274 {
275 if (tn.data.get_is_weighted())
276 {
277 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
278 return this->apply(inputs) + w;
279 }
280 }
281 return this->apply(inputs);
282 };
283};
284
// Specialization of Operator for NodeType::Terminal. Instead of evaluating
// children it reads a feature from the dataset, named by
// tn.data.get_feature(), and optionally scales it by the node weight.
// NOTE(review): the listing this block was taken from omits original lines
// 295, 311, 334 and 336, so the code below is incomplete as shown; each gap
// is marked where it occurs.
287template<typename S, bool Fit>
288struct Operator<NodeType::Terminal, S, Fit>
289{
290 using RetType = typename S::RetType;
291 using W = typename S::WeightType;
292
293 // Standard C++ types
// Returns the feature fetched as RetType. When Scalar is float or fJet and
// the node is weighted, the feature is multiplied by the weight obtained
// from util::get_weight (which may advance the `weights` cursor).
// NOTE(review): original line 295 is missing between the template header
// and the signature -- presumably a requires-clause that keeps this
// overload apart from the Jet one below; confirm.
294 template<typename T=RetType, typename Scalar=typename T::Scalar>
296 RetType eval(const Dataset& d, const TreeNode& tn, const W** weights=nullptr) const
297 {
298 if constexpr (is_one_of_v<Scalar,float,fJet>)
299 {
300 if (tn.data.get_is_weighted())
301 {
302 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
303 return this->get<RetType>(d, tn.data.get_feature())*w;
304 }
305 }
306 return this->get<RetType>(d,tn.data.get_feature());
307 };
308
309 // Jet types
// Same contract as the overload above, except that the feature is fetched
// as UnJetify_t<RetType> and cast element-wise to Scalar before the
// optional multiplication by the weight.
// NOTE(review): original line 311 is missing between the template header
// and the signature -- presumably the matching requires-clause; confirm.
310 template <typename T = RetType, typename Scalar=typename T::Scalar>
312 RetType eval(const Dataset &d, const TreeNode &tn, const W **weights = nullptr) const
313 {
314 using nonJetType = UnJetify_t<RetType>;
315 if constexpr (is_one_of_v<Scalar,float,fJet>)
316 {
317 if (tn.data.get_is_weighted())
318 {
319 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
320 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
321 }
322 }
323 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
324 };
325
326 // Accessing dataset directly
// Looks up `feature` in the dataset and returns a copy of the stored value
// when the variant at d[feature] currently holds type T. Otherwise the
// failure is reported through HANDLE_ERROR_THROW.
327 template<typename T>
328 auto get(const Dataset& d, const string& feature) const
329 {
330 if (std::holds_alternative<T>(d[feature]))
331 return std::get<T>(d[feature]);
332
// NOTE(review): the format string has three `{}` placeholders, but only
// `feature` is visible as an argument; original lines 334 and 336
// (presumably the requested and the stored type names) are missing from the
// listing.
333 HANDLE_ERROR_THROW(fmt::format("Failed to return type {} for '{}'. The feature's original ret type is {}.\n",
335 feature,
337 ));
338
// Only reached if HANDLE_ERROR_THROW returns; presumably present so that
// every path has a return value of the deduced type -- confirm.
339 return T();
340 }
341};
342
344// Constant Overloads
// Specialization of Operator for NodeType::Constant. eval() returns an array
// filled with a single value `w`, sized from the dataset.
345template<typename S, bool Fit>
346struct Operator<NodeType::Constant, S, Fit>
347{
348 using RetType = typename S::RetType;
349 using W = typename S::WeightType;
350
// T, Scalar and N default to the return type, its scalar type and its number
// of dimensions; N selects the one- or two-dimensional fill below.
351 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
352 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
353 {
354 // Scalar w = get_weight(tn, weights);
// NOTE(review): `w` is used below but never declared in this listing;
// original line 355 is missing here. It presumably obtains `w` through
// util::get_weight, as the commented-out line above and the note in the Fit
// branch suggest -- confirm against the real header.
356
357 if constexpr (Fit)
358 {
359 // there is no need to fit a constant node, get_weight will be called
360 }
361
// One-dimensional return types get n_samples entries; otherwise the result
// is n_samples by n_features. Every entry equals w.
362 if constexpr (N == 1)
363 return RetType::Constant(d.get_n_samples(), w);
364 else
365 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
366 };
367
368};
369
371// MeanLabel overload
372template<typename S, bool Fit>
374{
375 using RetType = typename S::RetType;
376 using W = typename S::WeightType;
377
378 RetType fit(const Dataset& d, TreeNode& tn) const {
379 // we take the mode of the labels if it is a classification problem
380 if (d.classification)
381 {
382 std::unordered_map<float, int> counters;
383 for (float val : d.y) {
384 if (counters.find(val) != counters.end()) {
385 counters[val] += 1;
386 }
387 else
388 {
389 counters[val] = 1;
390 }
391 }
392
393 auto mode = std::max_element(
394 counters.begin(), counters.end(),
395 [](const auto& a, const auto& b) { return a.second < b.second; }
396 );
397
398 tn.data.W = mode->first;
399 }
400 else
401 {
402 tn.data.W = d.y.mean();
403 }
404
405 return predict(d, tn);
406 };
407
// Returns an array filled with the single value `w`: n_samples entries when
// the return type is one-dimensional (N == 1), otherwise n_samples by
// n_features.
408 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
409 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
410 {
// NOTE(review): `w` is used below but never declared in this listing;
// original line 411 is missing here. It presumably obtains `w` from the node
// or from `weights` through util::get_weight -- confirm against the real
// header.
412 if constexpr (N == 1)
413 return RetType::Constant(d.get_n_samples(), w);
414 else
415 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
416 };
417
418 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
419 if constexpr (Fit)
420 return fit(d,tn);
421 else
422 return predict(d,tn,weights);
423 };
424};
425
427// Operator overloads
428// Split
429#include "split.h"
431// Dispatch functions
432template<typename R, NodeType NT, typename S, bool Fit, typename W>
433inline R DispatchOp(const Dataset& d, TreeNode& tn, const W** weights)
434{
435 const auto op = Operator<NT,S,Fit>{};
436 return op.eval(d, tn, weights);
437};
438
439template<typename R, NodeType NT, typename S, bool Fit>
440inline R DispatchOp(const Dataset& d, TreeNode& tn)
441{
442 const auto op = Operator<NT,S,Fit>{};
443 return op.eval(d, tn);
444};
445
446} // Brush
447
448#endif
holds variable type data.
Definition data.h:51
bool classification
whether this is a classification problem
Definition data.h:83
int get_n_samples() const
Definition data.h:222
int get_n_features() const
Definition data.h:228
ArrayXf y
length N array, the target label
Definition data.h:80
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
Definition operator.h:20
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
static constexpr bool is_tuple_v
Definition types.h:259
NodeType
Definition nodetype.h:31
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
Definition operator.h:433
static constexpr bool is_eigen_array_v
Definition types.h:253
static constexpr bool is_one_of_v
Definition types.h:33
typename UnJetify< T >::type UnJetify_t
Definition signatures.h:56
static constexpr bool is_std_array_v
Definition types.h:245
static constexpr bool is_in_v
Definition nodetype.h:269
static constexpr bool NaryOp
Definition nodetype.h:320
static constexpr bool UnaryOp
Definition nodetype.h:274
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:352
RetType fit(const Dataset &d, TreeNode &tn) const
Definition operator.h:378
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:409
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:418
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:312
auto get(const Dataset &d, const string &feature) const
Definition operator.h:328
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:296
Core computation of a node's function to data.
Definition operator.h:111
Operator()=default
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
Definition operator.h:253
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:270
typename S::WeightType W
set weight type
Definition operator.h:137
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
Definition operator.h:206
static constexpr size_t ArgCount
stores the argument count of the operator
Definition operator.h:130
typename S::RetType RetType
return type of the operator
Definition operator.h:127
typename S::template NthType< N > NthType
utility for returning the type of the Nth argument
Definition operator.h:134
static constexpr auto F
wrapper function for the node function
Definition operator.h:140
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
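set argument types to those of the signature unless the node is a unary or n-ary operator with more than one argument; in that case the arguments are gathered into a single array and the operator is applied to that array.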
Definition operator.h:121
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
Definition operator.h:183
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
Definition operator.h:218
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
Definition operator.h:151
RetType apply(const T &inputs) const
Apply node function in a functional style.
Definition operator.h:230
RetType apply(const T &inputs) const
Apply the node function like a function.
Definition operator.h:240