Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
operator.h
Go to the documentation of this file.
1#ifndef OPERATOR_H
2#define OPERATOR_H
3
4#include "../init.h"
5#include "tree_node.h"
6#include "../util/utils.h"
8namespace Brush{
10namespace util{
11 ////////////////////////////////////////////////////////////////////////////////
12 /// @brief Get the weight for a node, read either from the node itself or from an external weight array.
13 /// @tparam T return type
14 /// @tparam Scalar scalar type of return type
/// @tparam W element type of the external weight array
15 /// @param tn tree node
16 /// @param weights optional pointer to a cursor into a weight array, used in place of the node weight;
///        when non-null, the cursor is advanced by one element after the weight is read
17 /// @return the weight, converted to Scalar
/// @throws (via HANDLE_ERROR_THROW) if the node's stored weight is NaN or equals numeric_limits::lowest(),
///         or if `weights` is non-null but `*weights` is null
18 template<typename T, typename Scalar, typename W>
// NOTE(review): original line 19 (presumably the requires-clause that separates this overload from the
// boolean overload below) is missing from this listing; restore it from the upstream header.
20 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
21 {
22 Scalar w;
23 // Prediction case: weight is stored in the node data.
24 if (weights == nullptr)
25 {
26 if constexpr (std::is_floating_point_v<decltype(tn.data.W)> || std::is_integral_v<decltype(tn.data.W)>) {
// NaN and lowest() are treated as "weight was never set".
27 if (std::isnan(tn.data.W) || tn.data.W == std::numeric_limits<decltype(tn.data.W)>::lowest()) {
28 HANDLE_ERROR_THROW("TreeNode weight (W) is not set or is invalid for node: " + tn.data.name);
29 }
30 }
32 w = Scalar(tn.data.W);
33 }
34 else
35 {
36 if (*weights == nullptr) {
// err_msg already carries the "Null pointer dereference" prefix, so it is thrown as-is
// (prepending the prefix again duplicated it in the reported message).
37 std::string err_msg = "Null pointer dereference: *weights is nullptr. "
38 "TreeNode ret_type: " + std::to_string(static_cast<int>(tn.data.ret_type)) +
39 ", name: " + tn.data.name;
40 HANDLE_ERROR_THROW(err_msg);
41 }
42
43 // NLS case 1: floating point weight is stored in weights
44 if constexpr (is_same_v<Scalar, W>)
45 w = **weights;
46
47 // NLS case 2: a Jet/Dual weight is stored in weights, but this constant is an
48 // integer type. Only the value part `.a` of the stored weight is read and cast.
49 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
50 using WScalar = typename Scalar::Scalar;
51 WScalar tmp = WScalar((**weights).a);
52 w = Scalar(tmp);
53 }
54 // NLS case 3: a Jet/Dual weight is stored in weights, matching Scalar type
55 else
56 w = Scalar(**weights);
57
// Advance the cursor so the next call reads the following weight.
58 *weights = *weights+1;
59 }
60 return w;
61 };
/// @brief Boolean overload of get_weight: boolean values cannot be weighted, so this always
///        returns Scalar(true). The `weights` cursor is neither read nor advanced here.
/// @throws (via HANDLE_ERROR_THROW) if the node is flagged as weighted
62 template<typename T, typename Scalar, typename W>
// NOTE(review): original line 63 (presumably the requires-clause selecting this overload for a
// boolean Scalar) is missing from this listing; restore it from the upstream header.
64 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
65 {
66 // we cannot weight a boolean feature. Nevertheless, we need to provide
67 // an implementation for get_weight behavior, so the metaprogramming
68 // doesn't fail to get a matching signature.
69
70 if (tn.data.get_is_weighted())
71 // Node's init() function avoids the creation of weighted nodes,
72 // and the setter for `is_weighted` prevents enabling weight on
73 // boolean values.
74 HANDLE_ERROR_THROW(fmt::format("boolean terminal is weighted, but "
75 "it should not be\n"));
76
77 // std::cout << "Returning weight: Scalar(true) for a boolean node tn " << tn.data.name << std::endl;
78 return Scalar(true);
79 };
80}
82// Operator class
83
/// @brief Core computation: applies a node's function to data (generic operator).
/// @tparam NT node type; selects the Function<NT> that is applied
/// @tparam S signature type; supplies ArgTypes, RetType, ArgCount, NthType, WeightType and FirstArg
/// @tparam Fit when true, children are fit (`fit`); otherwise they are predicted (`predict`) with the optional weights
/// @tparam E unused in this primary template. NOTE(review): presumably a hook for constrained specializations — confirm
89template<NodeType NT, typename S, bool Fit, typename E=void>
90struct Operator
91{
/// Argument types: those of the signature, except that a unary or n-ary operator
/// taking more than one argument packs its inputs into an Eigen array with one
/// column per argument.
101 using ArgTypes = conditional_t<
102 ((UnaryOp<NT> || NaryOp<NT>) && S::ArgCount > 1),
103 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
104 typename S::ArgTypes>;
105
/// return type of the operator
107 using RetType = typename S::RetType;
108
/// number of child arguments the operator consumes
110 static constexpr size_t ArgCount = S::ArgCount;
111
/// type of the Nth argument
113 template <std::size_t N>
114 using NthType = typename S::template NthType<N>;
115
/// element type of an external weight array
117 using W = typename S::WeightType;
118
119 /// @brief wrapper function for the node function
120 static constexpr auto F = [](const auto& ...args) {
121 Function<NT> f;
122 return f(args...);
123 };
124
125 Operator() = default;
127 // Utilities to grab child outputs.
128
/// @brief get a std::array or Eigen array of child outputs
/// @tparam T container for the outputs: a std::array (one entry per child) or an Eigen array (one column per child)
/// @param d the dataset
/// @param tn the tree node whose children are evaluated
/// @param weights optional pointer to a weight array, forwarded to the children when predicting
/// @return the child outputs
/// @throws (via HANDLE_ERROR_THROW) if tn has fewer than ArgCount children
130 template<typename T=ArgTypes> requires(is_std_array_v<T> || is_eigen_array_v<T>)
131 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
132 {
133 T child_outputs;
134 using arg_type = std::conditional_t<is_std_array_v<T>,
135 typename T::value_type, Array<typename S::FirstArg::Scalar, -1, 1>>;
// Eigen outputs get one row per sample; Eigen::NoChange leaves the column count as is.
136 if constexpr (is_eigen_array_v<T>)
137 child_outputs.resize(d.get_n_samples(), Eigen::NoChange);
138
139 TreeNode* sib = tn.first_child;
// size_t index: ArgCount is a size_t, so this avoids a signed/unsigned comparison.
140 for (std::size_t i = 0; i < ArgCount; ++i)
141 {
142 if (sib == nullptr)
143 HANDLE_ERROR_THROW("bad sibling ptr in get kids");
144 if constexpr (Fit){
145 if constexpr(is_std_array_v<T>)
146 child_outputs.at(i) = sib->fit<arg_type>(d);
147 else
148 child_outputs.col(i) = sib->fit<arg_type>(d);
149 }
150 else{
151 if constexpr(is_std_array_v<T>)
152 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
153 else
154 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
155 }
156 sib = sib->next_sibling;
// NOTE(review): original line 157 (presumably the closing brace of the for loop) is missing
// from this listing; restore it from the upstream header.
158 return child_outputs;
159 };
160
/// @brief get the output of a single child; used to build a tuple of kids
/// @tparam I zero-based index of the child
/// @param d the dataset
/// @param tn the parent tree node
/// @param weights optional pointer to a weight array, forwarded to the child when predicting
/// @return the I-th child's output, as NthType<I>
/// @throws (via HANDLE_ERROR_THROW) if tn has fewer than I+1 children
162 template<int I>
163 NthType<I> get_kid(const Dataset& d, TreeNode& tn, const W** weights ) const
164 {
165 TreeNode* sib = tn.first_child;
// Stop early on a missing sibling instead of dereferencing nullptr.
166 for (int i = 0; i < I && sib != nullptr; ++i)
167 {
168 sib= sib->next_sibling;
169 }
// Same guard as get_kids: a tree with too few children is reported as an error.
if (sib == nullptr)
HANDLE_ERROR_THROW("bad sibling ptr in get kid");
170 if constexpr(Fit)
171 return sib->fit<NthType<I>>(d);
172 else
173 return sib->predict<NthType<I>>(d,weights);
174 };
175
176 /**
177 * @brief Makes and returns a tuple of child outputs
178 *
179 * @tparam T a tuple
180 * @tparam Is integer sequence
181 * @param d dataset
182 * @param tn a tree node
 * @param weights optional pointer to a weight array, forwarded to each child
183 * @return a tuple with elements corresponding to each child node
184 */
185 template<typename T, size_t ...Is> requires(is_tuple_v<T>)
186 T get_kids_seq(const Dataset& d, TreeNode& tn, const W** weights, std::index_sequence<Is...>) const
187 {
188 return std::make_tuple(get_kid<Is>(d,tn,weights)...);
189 };
190
/// @brief get a tuple of child outputs, one element per child
192 /// @tparam T argument types
193 /// @param d the dataset
194 /// @param tn the tree node
195 /// @param weights option pointer to a weight array, used in place of node weight
196 /// @return a tuple of the child arguments
197 template<typename T=ArgTypes> requires(is_tuple_v<T>)
198 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
199 {
200 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
201 };
202
/// @brief Apply the node function in a functional style: std::apply(F, inputs).
206 /// @tparam T argument types
207 /// @param inputs the child node outputs
208 /// @return return values applying F to the inputs
209 template<typename T=ArgTypes> requires ( is_std_array_v<T> || is_tuple_v<T>)
210 RetType apply(const T& inputs) const
211 {
212 return std::apply(F, inputs);
213 }
214
/// @brief Apply the node function like a function, passing the whole Eigen array of inputs.
/// @tparam T argument types (an Eigen array)
/// @param inputs the child node outputs, one column per child
/// @return F(inputs)
219 template<typename T=ArgTypes> requires ( is_eigen_array_v<T> && !is_std_array_v<T>)
220 RetType apply(const T& inputs) const
221 {
222 return F(inputs);
223 }
224
228
/// @brief evaluate the operator on the data; main entry point.
/// When the node is weighted and Scalar is float or fJet, the output is multiplied by the weight.
/// @param d the dataset
229 /// @param tn tree node
230 /// @param weights option pointer to a weight array, used in place of node weight
231 /// @return output values from applying operator function
232 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
233 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
234 {
235 auto inputs = get_kids(d, tn, weights);
236 if constexpr (is_one_of_v<Scalar,float,fJet>)
237 {
238 if (tn.data.get_is_weighted())
239 {
240 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
241 return this->apply(inputs)*w;
242 }
243 }
244 return this->apply(inputs);
245 };
246
247 // overloaded version for offset sum
/// @brief Offset-sum variant of eval: when the node is weighted and Scalar is float or fJet,
/// the weight is added to the operator output instead of multiplied.
248 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
// NOTE(review): original line 249 (presumably the requires-clause selecting this overload for
// offset-sum nodes) is missing from this listing; restore it from the upstream header.
250 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
251 {
252 auto inputs = get_kids(d, tn, weights);
253 if constexpr (is_one_of_v<Scalar,float,fJet>)
254 {
255 if (tn.data.get_is_weighted())
256 {
257 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
258 return this->apply(inputs) + w;
259 }
260 }
261 return this->apply(inputs);
262 };
263};
264
/// @brief Operator for Terminal nodes: returns the node's feature from the dataset,
/// multiplied by the node weight when the node is weighted and Scalar is float or fJet.
267template<typename S, bool Fit>
268struct Operator<NodeType::Terminal, S, Fit>
269{
270 using RetType = typename S::RetType;
271 using W = typename S::WeightType;
272
273 // Standard C++ types
/// @brief Read feature tn.data.get_feature() from d as RetType, optionally weighted.
274 template<typename T=RetType, typename Scalar=typename T::Scalar>
// NOTE(review): original line 275 (presumably the requires-clause distinguishing this overload
// from the Jet overload below) is missing from this listing; restore it from the upstream header.
276 RetType eval(const Dataset& d, const TreeNode& tn, const W** weights=nullptr) const
277 {
278 if constexpr (is_one_of_v<Scalar,float,fJet>)
279 {
280 if (tn.data.get_is_weighted())
281 {
282 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
283 return this->get<RetType>(d, tn.data.get_feature())*w;
284 }
285 }
286 return this->get<RetType>(d,tn.data.get_feature());
287 };
288
289 // Jet types
/// @brief Read the feature as its non-Jet counterpart (UnJetify_t<RetType>) and cast it
/// element-wise to Scalar, optionally multiplying by the node weight.
290 template <typename T = RetType, typename Scalar=typename T::Scalar>
// NOTE(review): original line 291 (presumably this overload's requires-clause) is missing from
// this listing; restore it from the upstream header.
292 RetType eval(const Dataset &d, const TreeNode &tn, const W **weights = nullptr) const
293 {
294 using nonJetType = UnJetify_t<RetType>;
295 if constexpr (is_one_of_v<Scalar,float,fJet>)
296 {
297 if (tn.data.get_is_weighted())
298 {
299 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
300 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
301 }
302 }
303 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
304 };
305
306 // Accessing dataset directly
/// @brief Return feature `feature` of dataset `d` as type T.
/// @throws (via HANDLE_ERROR_THROW) if the stored variant does not hold a T
307 template<typename T>
308 auto get(const Dataset& d, const string& feature) const
309 {
310 if (std::holds_alternative<T>(d[feature]))
311 return std::get<T>(d[feature]);
312
313 HANDLE_ERROR_THROW(fmt::format("Failed to return type {} for '{}'. The feature's original ret type is {}.\n",
// NOTE(review): original line 314 (the argument for the first `{}` placeholder, presumably the
// requested type's name) is missing from this listing; restore it from the upstream header.
315 feature,
316 DataTypeName.at(d.get_feature_type(feature))
317 ));
318
// Only reached if HANDLE_ERROR_THROW does not throw; keeps every path returning a T.
319 return T();
320 }
321};
322
324// Constant Overloads
/// @brief Operator for Constant nodes: fills the output with a single value `w`
/// (presumably the node weight from util::get_weight — see the note in eval).
325template<typename S, bool Fit>
326struct Operator<NodeType::Constant, S, Fit>
327{
328 using RetType = typename S::RetType;
329 using W = typename S::WeightType;
330
/// @brief Evaluate the constant.
/// @return an array with d.get_n_samples() entries when RetType is 1-D (N == 1), otherwise
///         d.get_n_samples() x d.get_n_features() entries, every entry equal to `w`
331 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
332 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
333 {
334 // Scalar w = get_weight(tn, weights);
// NOTE(review): original line 335, which defines the `w` used below, is missing from this listing
// (presumably a util::get_weight call); restore it from the upstream header.
336
337 if constexpr (Fit)
338 {
339 // there is no need to fit a constant node, get_weight will be called
340 }
341
342 if constexpr (N == 1)
343 return RetType::Constant(d.get_n_samples(), w);
344 else
345 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
346 };
347
348};
349
351// MeanLabel overload
/// @brief Operator for MeanLabel nodes: a constant fit to the training labels, namely the
/// most frequent label for classification and the label mean otherwise.
352template<typename S, bool Fit>
// NOTE(review): original line 353 (presumably `struct Operator<NodeType::MeanLabel, S, Fit>`) is
// missing from this listing; restore it from the upstream header.
354{
355 using RetType = typename S::RetType;
356 using W = typename S::WeightType;
357
/// @brief Store the label mode (classification) or label mean (otherwise) in tn.data.W,
/// then return predict(d, tn).
/// @throws (via HANDLE_ERROR_THROW) if d.y is empty
358 RetType fit(const Dataset& d, TreeNode& tn) const {
// Both branches need at least one label. The mode path dereferences max_element's result,
// which is past-the-end (UB to dereference) on an empty range, and the mean of an empty
// array is not defined.
if (d.y.size() == 0)
HANDLE_ERROR_THROW("MeanLabel cannot be fit: the label array d.y is empty");
359 // we take the mode of the labels if it is a classification problem
360 if (d.classification)
361 {
362 std::unordered_map<float, int> counters;
363 for (float val : d.y)
++counters[val]; // operator[] inserts a zero count for a label seen for the first time
372
// NOTE(review): max_element returns the first maximum in iteration order, and unordered_map
// iteration order is unspecified, so ties may resolve differently across standard libraries.
373 auto mode = std::max_element(
374 counters.begin(), counters.end(),
375 [](const auto& a, const auto& b) { return a.second < b.second; }
376 );
377
378 tn.data.W = mode->first;
379 }
380 else
381 {
382 tn.data.W = d.y.mean();
383 }
384
385 return predict(d, tn);
386 };
387
/// @brief Fill the output with a single value `w`: d.get_n_samples() entries when RetType is
/// 1-D (N == 1), otherwise d.get_n_samples() x d.get_n_features() entries.
388 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
389 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
390 {
// NOTE(review): original line 391, which defines the `w` used below, is missing from this listing
// (presumably a util::get_weight call); restore it from the upstream header.
392 if constexpr (N == 1)
393 return RetType::Constant(d.get_n_samples(), w);
394 else
395 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
396 };
397
/// @brief Dispatch to fit() when Fit is true, otherwise to predict() with the given weights.
398 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
399 if constexpr (Fit)
400 return fit(d,tn);
401 else
402 return predict(d,tn,weights);
403 };
404};
405
407// Operator overloads
408// Split
409#include "split.h"
411// Dispatch functions
412template<typename R, NodeType NT, typename S, bool Fit, typename W>
413inline R DispatchOp(const Dataset& d, TreeNode& tn, const W** weights)
414{
415 const auto op = Operator<NT,S,Fit>{};
416 return op.eval(d, tn, weights);
417};
418
419template<typename R, NodeType NT, typename S, bool Fit>
420inline R DispatchOp(const Dataset& d, TreeNode& tn)
421{
422 const auto op = Operator<NT,S,Fit>{};
423 return op.eval(d, tn);
424};
425
426} // Brush
427
428#endif
holds variable type data.
Definition data.h:51
bool classification
whether this is a classification problem
Definition data.h:85
DataType get_feature_type(const string &name) const
Definition data.h:238
int get_n_samples() const
Definition data.h:225
int get_n_features() const
Definition data.h:231
ArrayXf y
length N array, the target label
Definition data.h:82
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
Get the weight of a node, read from the node itself or from an external weight array.
Definition operator.h:20
< nsga2 selection operator for getting the front
Definition bandit.cpp:3
static constexpr bool is_tuple_v
Definition types.h:259
NodeType
Definition nodetype.h:31
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
Definition operator.h:413
static constexpr bool is_eigen_array_v
Definition types.h:253
static constexpr bool is_one_of_v
Definition types.h:33
typename UnJetify< T >::type UnJetify_t
Definition signatures.h:56
map< DataType, string > DataTypeName
Definition data.cpp:14
static constexpr bool is_std_array_v
Definition types.h:245
static constexpr bool is_in_v
Definition nodetype.h:269
static constexpr bool NaryOp
Definition nodetype.h:320
static constexpr bool UnaryOp
Definition nodetype.h:274
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:332
RetType fit(const Dataset &d, TreeNode &tn) const
Definition operator.h:358
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:389
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:398
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:292
auto get(const Dataset &d, const string &feature) const
Definition operator.h:308
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:276
Core computation: applies a node's function to data.
Definition operator.h:91
Operator()=default
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Evaluate the operator on the data; this is the main entry point.
Definition operator.h:233
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:250
typename S::WeightType W
set weight type
Definition operator.h:117
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
Definition operator.h:186
static constexpr size_t ArgCount
stores the argument count of the operator
Definition operator.h:110
typename S::RetType RetType
return type of the operator
Definition operator.h:107
typename S::template NthType< N > NthType
utility for returning the type of the Nth argument
Definition operator.h:114
static constexpr auto F
wrapper function for the node function
Definition operator.h:120
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
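Set argument types to those of the signature, unless the operator is unary or n-ary and takes more than one argument; in that case the arguments are packed into an Eigen array with one column per argument.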
Definition operator.h:101
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
Definition operator.h:163
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
Definition operator.h:131
RetType apply(const T &inputs) const
Apply node function in a functional style.
Definition operator.h:210
RetType apply(const T &inputs) const
Apply the node function like a function.
Definition operator.h:220