Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
operator.h
Go to the documentation of this file.
1#ifndef OPERATOR_H
2#define OPERATOR_H
3
4#include "../init.h"
5#include "tree_node.h"
6#include "../util/utils.h"
7
8namespace Brush{
10namespace util{
// Returns the weight to use for node `tn`, converted to `Scalar`.
//   T       - not referenced in the body of this overload.
//   Scalar  - type the weight is returned as.
//   W       - element type of the external weight array.
//   tn      - node whose stored weight (tn.data.W) is used when `weights` is null.
//   weights - optional cursor into an external weight array. When non-null, the
//             element it currently points at is read and the cursor is then
//             advanced by one element, so the caller sees the updated position.
// NOTE(review): the embedded listing numbers jump from 18 to 20, so a line
// between the template head and the signature appears to be missing here.
// As shown, this declaration is identical to the boolean overload that
// follows it; presumably the lost line is a constraint that tells the two
// apart - confirm against the original header.
18 template<typename T, typename Scalar, typename W>
20 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
21 {
22 Scalar w;
23 // Prediction case: no external weights were passed, use the node's own weight.
24 if (weights == nullptr)
25 {
26 w = Scalar(tn.data.W);
27 }
28 else
29 {
30 // NLS case 1: the external weight already has type Scalar; copy it directly.
31 if constexpr (is_same_v<Scalar, W>)
32 w = **weights;
33 // NLS case 2: the weights are fJet but Scalar is iJet. Take the weight's `.a`
34 // member, convert it to iJet's own scalar type, then build the iJet from it.
35 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
36 using WScalar = typename Scalar::Scalar;
37 WScalar tmp = WScalar((**weights).a);
38 w = Scalar(tmp);
39 }
40 // NLS case 3: any other combination; convert the weight to Scalar explicitly.
41 else
42 w = Scalar(**weights);
// Runs after every one of the three NLS cases: move the caller's cursor to
// the next weight.
43 *weights = *weights+1;
44
45 }
46 return w;
47 };
// Overload of get_weight for boolean values: never reads a weight and always
// returns Scalar(true). `weights` is accepted but not used.
// NOTE(review): the embedded listing numbers jump from 48 to 50, so a line
// between the template head and the signature appears to be missing here
// (presumably the constraint that restricts this overload to boolean Scalar
// types) - confirm against the original header.
48 template<typename T, typename Scalar, typename W>
50 Scalar get_weight(const TreeNode& tn, const W** weights=nullptr)
51 {
52 // We cannot weight a boolean feature. Nevertheless, we need to provide
53 // an implementation of get_weight so the metaprogramming
54 // doesn't fail to find a matching signature.
55
56 if (tn.data.get_is_weighted())
57 // Node's init() function avoids the creation of weighted nodes,
58 // and the setter for `is_weighted` prevents enabling weight on
59 // boolean values, so reaching this branch is reported as an error.
60 HANDLE_ERROR_THROW(fmt::format("boolean terminal is weighted, but "
61 "it should not\n"));
62
63 return Scalar(true);
64 };
65}
67// Operator class
68
74template<NodeType NT, typename S, bool Fit, typename E=void>
75struct Operator
76{
87 ((UnaryOp<NT> || NaryOp<NT>) && S::ArgCount > 1),
88 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
89 typename S::ArgTypes>;
90
92 using RetType = typename S::RetType;
93
95 static constexpr size_t ArgCount = S::ArgCount;
96
98 template <std::size_t N>
99 using NthType = typename S::NthType<N>;
100
102 using W = typename S::WeightType;
103
// Wrapper for the node function: a generic lambda that forwards any number of
// arguments to `f` and returns its result.
// NOTE(review): `f` is not declared in the lines visible here, and the embedded
// listing numbers skip from 105 to 107, so its declaration is presumably on the
// missing line - confirm against the original header.
105 static constexpr auto F = [](const auto& ...args) {
107 return f(args...);
108 };
109
110 Operator() = default;
112 // Utilities to grab child outputs.
113
115 template<typename T=ArgTypes> requires(is_std_array_v<T> || is_eigen_array_v<T>)
116 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
117 {
119 using arg_type = std::conditional_t<is_std_array_v<T>,
120 typename T::value_type, Array<typename S::FirstArg::Scalar, -1, 1>>;
121 if constexpr (is_eigen_array_v<T>)
122 child_outputs.resize(d.get_n_samples(), Eigen::NoChange);
123
124 TreeNode* sib = tn.first_child;
125 for (int i = 0; i < ArgCount; ++i)
126 {
127 if (sib == nullptr)
128 HANDLE_ERROR_THROW("bad sibling ptr in get kids");
129 if constexpr (Fit){
130 if constexpr(is_std_array_v<T>)
131 child_outputs.at(i) = sib->fit<arg_type>(d);
132 else
133 child_outputs.col(i) = sib->fit<arg_type>(d);
134 }
135 else{
136 if constexpr(is_std_array_v<T>)
137 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
138 else
139 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
140 }
141 sib = sib->next_sibling;
142 }
143 return child_outputs;
144 };
145
147 template<int I>
148 NthType<I> get_kid(const Dataset& d,TreeNode& tn, const W** weights ) const
149 {
150 auto sib = tree<TreeNode>::sibling_iterator(tn.first_child) ;
151 sib += I;
152 if constexpr(Fit)
153 return sib->fit<NthType<I>>(d);
154 else
155 return sib->predict<NthType<I>>(d,weights);
156 };
157
167 template<typename T, size_t ...Is> requires(is_tuple_v<T>)
168 T get_kids_seq(const Dataset& d, TreeNode& tn, const W** weights, std::index_sequence<Is...>) const
169 {
170 return std::make_tuple(get_kid<Is>(d,tn,weights)...);
171 };
172
179 template<typename T=ArgTypes> requires(is_tuple_v<T>)
180 T get_kids(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
181 {
182 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
183 };
184
186
191 template<typename T=ArgTypes> requires ( is_std_array_v<T> || is_tuple_v<T>)
192 RetType apply(const T& inputs) const
193 {
194 return std::apply(F, inputs);
195 }
196
201 template<typename T=ArgTypes> requires ( is_eigen_array_v<T> && !is_std_array_v<T>)
202 RetType apply(const T& inputs) const
203 {
204 return F(inputs);
205 }
206
214 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
215 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
216 {
217 auto inputs = get_kids(d, tn, weights);
218 if constexpr (is_one_of_v<Scalar,float,fJet>)
219 {
220 if (tn.data.get_is_weighted())
221 {
222 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
223 return this->apply(inputs)*w;
224 }
225 }
226 return this->apply(inputs);
227 };
228
229 // Overloaded version for offset sum: same flow as the eval above, except that
// the node's weight is ADDED to the result instead of multiplied into it.
// NOTE(review): the embedded listing numbers jump from 230 to 232, so a line
// between the template head and the signature appears to be missing. As shown,
// this declaration is identical to the eval above; presumably the lost line is
// the constraint that selects this overload - confirm against the original
// header.
230 template<typename T=ArgTypes, typename Scalar=RetType::Scalar>
232 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
233 {
// the children are evaluated before this node's own weight is read
234 auto inputs = get_kids(d, tn, weights);
// only float and fJet results can carry a weight
235 if constexpr (is_one_of_v<Scalar,float,fJet>)
236 {
237 if (tn.data.get_is_weighted())
238 {
239 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
240 return this->apply(inputs) + w;
241 }
242 }
// unweighted node, or a scalar type that is never weighted
243 return this->apply(inputs);
244 };
245};
246
249template<typename S, bool Fit>
251{
252 using RetType = typename S::RetType;
253 using W = typename S::WeightType;
254
255 // Standard C++ types
// Evaluates a terminal: fetches the feature named by tn.data.get_feature()
// from the dataset as RetType. When Scalar is float or fJet and the node is
// weighted, the fetched values are multiplied by the node's weight.
// NOTE(review): the embedded listing numbers jump from 256 to 258, so a line
// between the template head and the signature appears to be missing
// (presumably a constraint separating this overload from the Jet one below) -
// confirm against the original header.
256 template<typename T=RetType, typename Scalar=typename T::Scalar>
258 RetType eval(const Dataset& d, const TreeNode& tn, const W** weights=nullptr) const
259 {
// only float and fJet results can carry a weight
260 if constexpr (is_one_of_v<Scalar,float,fJet>)
261 {
262 if (tn.data.get_is_weighted())
263 {
264 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
265 return this->get<RetType>(d, tn.data.get_feature())*w;
266 }
267 }
// unweighted node, or a scalar type that is never weighted
268 return this->get<RetType>(d,tn.data.get_feature());
269 };
270
271 // Jet types
// Evaluates a terminal whose return type uses a Jet scalar: the feature is
// fetched from the dataset as `nonJetType` and cast element-wise to Scalar.
// When Scalar is float or fJet and the node is weighted, the cast values are
// multiplied by the node's weight.
// NOTE(review): the embedded listing numbers skip 273 and 276. `nonJetType` is
// not declared in the lines visible here, so its alias is presumably on the
// missing line 276, and line 273 presumably held the constraint separating
// this overload from the one above - confirm against the original header.
272 template <typename T = RetType, typename Scalar=typename T::Scalar>
274 RetType eval(const Dataset &d, const TreeNode &tn, const W **weights = nullptr) const
275 {
// only float and fJet results can carry a weight
277 if constexpr (is_one_of_v<Scalar,float,fJet>)
278 {
279 if (tn.data.get_is_weighted())
280 {
281 auto w = util::get_weight<RetType,Scalar,W>(tn, weights);
282 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
283 }
284 }
// unweighted node, or a scalar type that is never weighted
285 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
286 };
287
288 // Accessing dataset directly
// Returns the value stored under `feature` in the dataset when that entry
// currently holds alternative T; otherwise reports an error through
// HANDLE_ERROR_THROW.
//   T       - the alternative type to extract.
//   d       - dataset, indexed by feature name.
//   feature - name of the feature to fetch.
289 template<typename T>
290 auto get(const Dataset& d, const string& feature) const
291 {
292 if (std::holds_alternative<T>(d[feature]))
293 return std::get<T>(d[feature]);
294
// NOTE(review): the format string has two placeholders but only `feature` is
// visible as an argument, and the embedded listing numbers skip 296 - the
// first argument (the type description) is presumably on the missing line.
295 HANDLE_ERROR_THROW(fmt::format("Failed to return type {} for '{}'\n",
297 feature
298 ));
299
// Gives this path a return value of the same type as the one above; only
// reached if HANDLE_ERROR_THROW returns (presumably it throws - confirm).
300 return T();
301 }
302};
303
305// Constant Overloads
306template<typename S, bool Fit>
308{
309 using RetType = typename S::RetType;
310 using W = typename S::WeightType;
311
312 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
313 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
314 {
315 // Scalar w = get_weight(tn, weights);
316 Scalar w = util::get_weight<RetType,Scalar,W>(tn, weights);
317 if constexpr (N == 1)
318 return RetType::Constant(d.get_n_samples(), w);
319 else
320 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
321 };
322
323};
324
326// MeanLabel overload
327template<typename S, bool Fit>
329{
330 using RetType = typename S::RetType;
331 using W = typename S::WeightType;
332
333 RetType fit(const Dataset& d, TreeNode& tn) const {
334 tn.data.W = d.y.mean();
335 return predict(d, tn);
336 };
337
338 template<typename T=RetType, typename Scalar=T::Scalar, int N=T::NumDimensions>
339 RetType predict(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const
340 {
341 Scalar w = util::get_weight<RetType,Scalar,W>(tn, weights);
342 if constexpr (N == 1)
343 return RetType::Constant(d.get_n_samples(), w);
344 else
345 return RetType::Constant(d.get_n_samples(), d.get_n_features(), w);
346 };
347
348 RetType eval(const Dataset& d, TreeNode& tn, const W** weights=nullptr) const {
349 if constexpr (Fit)
350 return fit(d,tn);
351 else
352 return predict(d,tn,weights);
353 };
354};
355
357// Operator overloads
358// Split
359#include "split.h"
361// Dispatch functions
362template<typename R, NodeType NT, typename S, bool Fit, typename W>
363inline R DispatchOp(const Dataset& d, TreeNode& tn, const W** weights)
364{
365 const auto op = Operator<NT,S,Fit>{};
366 return op.eval(d, tn, weights);
367};
368
369template<typename R, NodeType NT, typename S, bool Fit>
370inline R DispatchOp(const Dataset& d, TreeNode& tn)
371{
372 const auto op = Operator<NT,S,Fit>{};
373 return op.eval(d, tn);
374};
375
376} // Brush
377
378#endif
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
Definition operator.h:20
nsga2 selection operator for getting the front
Definition data.cpp:12
NodeType
Definition nodetype.h:31
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
Definition operator.h:363
typename UnJetify< T >::type UnJetify_t
Definition signatures.h:56
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:313
RetType fit(const Dataset &d, TreeNode &tn) const
Definition operator.h:333
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:339
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:348
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:274
auto get(const Dataset &d, const string &feature) const
Definition operator.h:290
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:258
Core computation of a node's function to data.
Definition operator.h:76
Operator()=default
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
Definition operator.h:215
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
Definition operator.h:232
typename S::WeightType W
set weight type
Definition operator.h:102
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
Definition operator.h:168
static constexpr size_t ArgCount
stores the argument count of the operator
Definition operator.h:95
typename S::NthType< N > NthType
utility for returning the type of the Nth argument
Definition operator.h:99
typename S::RetType RetType
return type of the operator
Definition operator.h:92
static constexpr auto F
wrapper function for the node function
Definition operator.h:105
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
set argument types to those of the signature unless:
Definition operator.h:86
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
Definition operator.h:148
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
Definition operator.h:180
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
Definition operator.h:116
RetType apply(const T &inputs) const
Apply node function in a functional style.
Definition operator.h:192
RetType apply(const T &inputs) const
Apply the node function like a function.
Definition operator.h:202