Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
operator.h
Go to the documentation of this file.
1#ifndef OPERATOR_H
2#define OPERATOR_H
3
#include <typeinfo>

#include "../init.h"
#include "tree_node.h"
#include "../util/utils.h"
8namespace Brush{
10namespace util{
11 ////////////////////////////////////////////////////////////////////////////////
12 /// @brief get weight
13 /// @tparam T return type
14 /// @tparam Scalar scalar type of return type
15 /// @param tn tree node
16 /// @param weights option pointer to a weight array, used in place of node weight
17 /// @return
18 template<typename T, typename Scalar, typename W>
20 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
21 {
22 Scalar w;
23 // Prediction case: weight is stored in the node data.
24 if (weights == nullptr)
25 {
26 w = Scalar(tn.data.W);
27 }
28 else
29 {
30 // NLS case 1: floating point weight is stored in weights
31 if constexpr (is_same_v<Scalar, W>)
32 w = **weights;
33 // NLS case 2: a Jet/Dual weight is stored in weights, but this constant is a
34 // integer type. We need to do some casting
35 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
36 using WScalar = typename Scalar::Scalar;
37 WScalar tmp = WScalar((**weights).a);
38 w = Scalar(tmp);
39 }
40 // NLS case 3: a Jet/Dual weight is stored in weights, matching Scalar type
41 else
42 w = Scalar(**weights);
43 *weights = *weights+1;
44
45 }
46 return w;
47 };
48 template<typename T, typename Scalar, typename W>
50 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
51 {
52 // we cannot weight a boolean feature. Nevertheless, we need to provide
53 // an implementation for get_weight behavior, so the metaprogramming
54 // doesn't fail to get a matching signature.
55
56 if (tn.data.get_is_weighted())
57 // Node's init() function avoids the creation of weighted nodes,
58 // and the setter for `is_weighted` prevent enabling weight on
59 // boolean values.
60 HANDLE_ERROR_THROW(fmt::format("boolean terminal is weighted, but "
61 "it should not\n"));
63 return Scalar(true);
64 };
65}
67// Operator class
68
74template<NodeType NT, typename S, bool Fit, typename E=void>
75struct Operator
76{
86 using ArgTypes = conditional_t<
87 ((UnaryOp<NT> || NaryOp<NT>) && S::ArgCount > 1),
88 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
89 typename S::ArgTypes>;
90
92 using RetType = typename S::RetType;
93
95 static constexpr size_t ArgCount = S::ArgCount;
96
98 template <std::size_t N>
99 using NthType = typename S::template NthType<N>;
100
102 using W = typename S::WeightType;
103
105 static constexpr auto F = [](const auto& ...args) {
106 Function<NT> f;
107 return f(args...);
108 };
109
110 Operator() = default;
112 // Utilities to grab child outputs.
113
115 template<typename T=ArgTypes> requires(is_std_array_v<T> || is_eigen_array_v<T>)
116 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
117 {
118 T child_outputs;
119 using arg_type = std::conditional_t<is_std_array_v<T>,
120 typename T::value_type, Array<typename S::FirstArg::Scalar, -1, 1>>;
121 if constexpr (is_eigen_array_v<T>)
122 child_outputs.resize(d.get_n_samples(), Eigen::NoChange);
123
124 TreeNode* sib = tn.first_child;
125 for (int i = 0; i < ArgCount; ++i)
126 {
127 if (sib == nullptr)
128 HANDLE_ERROR_THROW("bad sibling ptr in get kids");
129 if constexpr (Fit){
130 if constexpr(is_std_array_v<T>)
131 child_outputs.at(i) = sib->fit<arg_type>(d);
132 else
133 child_outputs.col(i) = sib->fit<arg_type>(d);
134 }
135 else{
136 if constexpr(is_std_array_v<T>)
137 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
138 else
139 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
140 }
141 sib = sib->next_sibling;
142 }
143 return child_outputs;
144 };
145
147 template<int I>
148 NthType<I> get_kid(const Dataset& d, TreeNode& tn, const W** weights ) const
149 {
150 TreeNode* sib = tn.first_child;
151 for (int i = 0; i < I; ++i)
152 {
153 sib= sib->next_sibling;
155 if constexpr(Fit)
156 return sib->fit<NthType<I>>(d);
157 else
158 return sib->predict<NthType<I>>(d,weights);
159 };
160
170 template<typename T, size_t ...Is> requires(is_tuple_v<T>)
171 T get_kids_seq(const Dataset& d, TreeNode& tn, const W** weights, std::index_sequence<Is...>) const
172 {
173 return std::make_tuple(get_kid<Is>(d,tn,weights)...);
174 };
175
182 template<typename T=ArgTypes> requires(is_tuple_v<T>)
183 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
184 {
185 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
186 };
187
194 template<typename T=ArgTypes> requires ( is_std_array_v<T> || is_tuple_v<T>)
195 RetType apply(const T& inputs) const
196 {
197 return std::apply(F, inputs);
199
203 /// @return return values applying F to the inputs
204 template<typename T=ArgTypes> requires ( is_eigen_array_v<T> && !is_std_array_v<T>)
205 RetType apply(const T& inputs) const
206 {
207 return F(inputs);
208 }
209
217 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
218 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
219 {
220 auto inputs = get_kids(d, tn, weights);
221 if constexpr (is_one_of_v<Scalar,float,fJet>)
222 {
223 if (tn.data.get_is_weighted())
224 {
225 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
226 return this->apply(inputs)*w;
227 }
228 }
229 return this->apply(inputs);
230 };
231
232 // overloaded version for offset sum
233 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
235 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
236 {
237 auto inputs = get_kids(d, tn, weights);
238 if constexpr (is_one_of_v<Scalar,float,fJet>)
239 {
240 if (tn.data.get_is_weighted())
241 {
242 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
243 return this->apply(inputs) + w;
244 }
245 }
246 return this->apply(inputs);
247 };
248};
249
252template<typename S, bool Fit>
253struct Operator<NodeType::Terminal, S, Fit>
254{
255 using RetType = typename S::RetType;
256 using W = typename S::WeightType;
258 // Standard C++ types
259 template<typename T=RetType, typename Scalar=typename T::Scalar>
261 RetType eval(const Dataset& d, const TreeNode& tn, const W** weights=nullptr) const
262 {
263 if constexpr (is_one_of_v<Scalar,float,fJet>)
264 {
265 if (tn.data.get_is_weighted())
266 {
267 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
268 return this->get<RetType>(d, tn.data.get_feature())*w;
269 }
270 }
271 return this->get<RetType>(d,tn.data.get_feature());
272 };
273
274 // Jet types
275 template <typename T = RetType, typename Scalar=typename T::Scalar>
277 RetType eval(const Dataset &d, const TreeNode &tn, const W **weights = nullptr) const
278 {
279 using nonJetType = UnJetify_t<RetType>;
280 if constexpr (is_one_of_v<Scalar,float,fJet>)
281 {
282 if (tn.data.get_is_weighted())
283 {
284 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
285 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
286 }
287 }
288 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
289 };
290
291 // Accessing dataset directly
292 template<typename T>
293 auto get(const Dataset& d, const string& feature) const
294 {
295 if (std::holds_alternative<T>(d[feature]))
296 return std::get<T>(d[feature]);
297
298 HANDLE_ERROR_THROW(fmt::format("Failed to return type {} for '{}'\n",
300 feature
301 ));
302
303 return T();
304 }
305};
306
308// Constant Overloads
309template<typename S, bool Fit>
310struct Operator<NodeType::Constant, S, Fit>
311{
312 using RetType = typename S::RetType;
313 using W = typename S::WeightType;
314
315 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
316 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
317 {
318 // Scalar w = get_weight(tn, weights);
320 if constexpr (N == 1)
321 return RetType::Constant(d.get_n_samples(), w);
322 else
323 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
324 };
325
326};
327
329// MeanLabel overload
330template<typename S, bool Fit>
332{
333 using RetType = typename S::RetType;
334 using W = typename S::WeightType;
335
336 RetType fit(const Dataset& d, TreeNode& tn) const {
337 // we take the mode of the labels if it is a classification problem
338 if (d.classification)
339 {
340 std::unordered_map<float, int> counters;
341 for (float val : d.y) {
342 if (counters.find(val) != counters.end()) {
343 counters[val] += 1;
344 }
345 else
346 {
347 counters[val] = 1;
348 }
349 }
350
351 auto mode = std::max_element(
352 counters.begin(), counters.end(),
353 [](const auto& a, const auto& b) { return a.second < b.second; }
354 );
355
356 tn.data.W = mode->first;
357 }
358 else
359 {
360 tn.data.W = d.y.mean();
361 }
362
363 return predict(d, tn);
364 };
365
366 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
367 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
368 {
370 if constexpr (N == 1)
371 return RetType::Constant(d.get_n_samples(), w);
372 else
373 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
374 };
375
376 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
377 if constexpr (Fit)
378 return fit(d,tn);
379 else
380 return predict(d,tn,weights);
381 };
382};
383
385// Operator overloads
386// Split
387#include "split.h"
389// Dispatch functions
390template<typename R, NodeType NT, typename S, bool Fit, typename W>
391inline R DispatchOp(const Dataset& d, TreeNode& tn, const W** weights)
392{
393 const auto op = Operator<NT,S,Fit>{};
394 return op.eval(d, tn, weights);
395};
396
397template<typename R, NodeType NT, typename S, bool Fit>
398inline R DispatchOp(const Dataset& d, TreeNode& tn)
399{
400 const auto op = Operator<NT,S,Fit>{};
401 return op.eval(d, tn);
402};
403
404} // Brush
405
406#endif
holds variable type data.
Definition data.h:51
bool classification
whether this is a classification problem
Definition data.h:83
int get_n_samples() const
Definition data.h:222
int get_n_features() const
Definition data.h:228
ArrayXf y
length N array, the target label
Definition data.h:80
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
Definition operator.h:20
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
static constexpr bool is_tuple_v
Definition types.h:259
NodeType
Definition nodetype.h:31
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
Definition operator.h:391
static constexpr bool is_eigen_array_v
Definition types.h:253
static constexpr bool is_one_of_v
Definition types.h:33
typename UnJetify< T >::type UnJetify_t
Definition signatures.h:56
static constexpr bool is_std_array_v
Definition types.h:245
static constexpr bool is_in_v
Definition nodetype.h:268
static constexpr bool NaryOp
Definition nodetype.h:313
static constexpr bool UnaryOp
Definition nodetype.h:273
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:316
RetType fit(const Dataset &d, TreeNode &tn) const
Definition operator.h:336
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:367
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:376
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:277
auto get(const Dataset &d, const string &feature) const
Definition operator.h:293
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:261
Core computation of a node's function to data.
Definition operator.h:76
Operator()=default
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
Definition operator.h:218
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:235
typename S::WeightType W
set weight type
Definition operator.h:102
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
Definition operator.h:171
static constexpr size_t ArgCount
stores the argument count of the operator
Definition operator.h:95
typename S::RetType RetType
return type of the operator
Definition operator.h:92
typename S::template NthType< N > NthType
utility for returning the type of the Nth argument
Definition operator.h:99
static constexpr auto F
wrapper function for the node function
Definition operator.h:105
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
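set argument types to those of the signature, unless the operator is unary or n-ary with more than one argument, in which case the arguments are held in a single array with one column per argument.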
Definition operator.h:86
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
Definition operator.h:148
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
Definition operator.h:183
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
Definition operator.h:116
RetType apply(const T &inputs) const
Apply node function in a functional style.
Definition operator.h:195
RetType apply(const T &inputs) const
Apply the node function like a function.
Definition operator.h:205