Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.cpp
Go to the documentation of this file.
1#include "node.h"
2
3namespace Brush {
4
5ostream& operator<<(ostream& os, const NodeType& nt)
6{
7 os << "nt: " << nt << endl;
8 os << NodeTypeName.at(nt);
9 return os;
10}
11
12ostream& operator<<(ostream& os, const Node& n)
13{
14 os << n.get_name();
15 return os;
16}
17
// Builds the display string for this node.
//
// `include_weight` gates the weight prefix in the first branch, in the
// Constant branch and in the final generic weighted branch. The blocks on
// listing lines 37-41 and 43-48 do not consult it.
// A weight within 1e-6 of 1.0 counts as "one" and is never printed as a prefix.
//
// NOTE(review): listing lines 25, 36 and 42 are absent from this extract. They
// presumably hold the conditions that open the three bare `{` blocks below
// (likely Terminal / MeanLabel / OffsetSum checks) -- confirm against the real
// node.cpp before changing this function.
20auto Node::get_name(bool include_weight) const noexcept -> std::string
21{
// Absolute tolerance used to decide whether W is effectively 1.
22 constexpr float atol = 1e-6f;
23 const bool weight_is_one = std::fabs(W - 1.0f) <= atol;
24
// Branch condition (listing line 25) not visible. Prints the feature name,
// prefixed with "W*" only when the node is weighted, W is not ~1 and the
// caller asked for weights.
26 {
27 if (is_weighted && !weight_is_one && include_weight)
28 return fmt::format("{:.2f}*{}", W, feature);
29 else
30 return feature;
31 }
// Constants print their value only when weights are requested; otherwise
// control falls through to the later branches / the final `return name`.
32 else if (Is<NodeType::Constant>(node_type) && include_weight)
33 {
34 return fmt::format("{:.2f}", W);
35 }
// Branch condition (listing line 36) not visible.
// NOTE(review): the comment below mentions showing "(MeanLabel)", but the
// format string only emits the number -- the comment looks stale.
37 {
38 // this will show (MeanLabel) in the terminal name so we can differentiate
39 // a meanLabel from a constant.
40 return fmt::format("{:.2f}", W);
41 }
// Branch condition (listing line 42) not visible. Returns "W+Add" when the
// node is weighted with W not ~1, and the literal "Sum" otherwise.
43 {
44 if (is_weighted && !weight_is_one)
45 return fmt::format("{:.2f}+Add", W);
46
47 return "Sum";
48 }
// Any other weighted node: "W*name".
49 else if (is_weighted && !weight_is_one && include_weight)
50 {
51 return fmt::format("{:.2f}*{}", W, name);
52 }
53
// Unweighted (or weight hidden): plain node name.
54 return name;
55}
56
57
// Builds the textual model for this node given the already-rendered strings of
// its children. With no children the result is simply get_name().
//
// NOTE(review): listing lines 62, 70 and 88 are absent from this extract. They
// presumably hold the `else if (...) {` conditions that open the blocks closed
// on listing lines 69, 87 and 101 -- confirm against the real node.cpp.
//
// NOTE(review): the function is declared noexcept but calls `.at()`, which
// throws std::out_of_range on a bad index; a throw here would terminate the
// program. Confirm callers always pass the expected number of children.
58string Node::get_model(const vector<string>& children) const noexcept
59{
60 if (children.empty())
61 return get_name();
// (condition on listing line 62 not visible) Two-child conditional: the
// test is `feature >= W`, with children 0 and 1 as the two outcomes.
63 return fmt::format("If({}>={:.2f},{},{})",
64 feature,
65 W,
66 children.at(0),
67 children.at(1)
68 );
69 }
// (condition on listing line 70 not visible) Three-child conditional where
// child 0 supplies the condition and children 1 and 2 are the outcomes.
71 if (arg_types.at(0) == DataType::ArrayB)
72 {
73 // booleans dont use thresholds (they are used directly as mask in split)
74 return fmt::format("If({},{},{})",
75 children.at(0),
76 children.at(1),
77 children.at(2)
78 );
79 }
80 // integers or floating points (they have a threshold)
81 return fmt::format("If({}>={:.2f},{},{})",
82 children.at(0),
83 W,
84 children.at(1),
85 children.at(2)
86 );
87 }
// (condition on listing line 88 not visible) Renders "Add(W,c0,c1,...)";
// the weight is included only when the node is weighted and W differs from 1.
89 string args = "";
90
// NOTE(review): exact comparison against 1.0 here, whereas get_name() uses a
// 1e-6 tolerance -- the two can disagree for weights very close to 1.
91 if (is_weighted && W != 1.0)
92 args = fmt::format("{:.2f},", W);
93
// Comma-join the children. `children` is non-empty at this point (see the
// early return above), so `size()-1` cannot wrap around.
// NOTE(review): `int i` is compared against the unsigned `size()`.
94 for (int i = 0; i < children.size(); ++i){
95 args += children.at(i);
96 if (i < children.size()-1)
97 args += ",";
98 }
99
100 return fmt::format("Add({})", args);
101 }
// Generic case: "name(c0,c1,...)", where the name comes from get_name().
102 else{
103 string args = "";
104 for (int i = 0; i < children.size(); ++i){
105 args += children.at(i);
106 if (i < children.size()-1)
107 args += ",";
108 }
109
110 return fmt::format("{}({})", get_name(), args);
111 }
112
113}
114
116// serialization
117// serialization for Node
118// using json = nlohmann::json;
119
120void to_json(json& j, const Node& p)
121{
122 j = json{
123 {"name", p.name},
124 {"center_op", p.center_op},
125 {"node_is_fixed", p.node_is_fixed},
126 {"weight_is_fixed", p.weight_is_fixed},
127 {"prob_change", p.prob_change},
128 {"is_weighted", p.is_weighted},
129 {"W", p.W},
130 {"node_type", p.node_type},
131 {"sig_hash", p.sig_hash},
132 {"sig_dual_hash", p.sig_dual_hash},
133 {"ret_type", p.ret_type},
134 {"arg_types", p.arg_types},
135 {"feature", p.get_feature()},
136 {"feature_type", p.get_feature_type()}
137 // {"node_hash", p.get_node_hash()}
138 };
139}
140
141using NT = NodeType;
// Body of the function that assigns a default signature to `node` based on
// its node_type, for nodes whose signature was not supplied explicitly.
//
// NOTE(review): the function header (listing line 142) is absent from this
// extract, as are many interior listing lines (e.g. 167, 170-172, 196, 202,
// 211-225, 251, 273, 276, 283, 289, 292). Several branches below therefore
// appear empty or without a condition list; do not treat the visible text as
// the complete logic -- check the real node.cpp.
143{
144 // if (Is<
145 // NT::Add,
146 // NT::Mul,
147 // NT::Min,
148 // NT::Max
149 // >(nt))
150 // return Signature<ArrayXf(ArrayXf,ArrayXf)>{};
151 NT n = node.node_type;
// Unary float operators: one ArrayXf argument, ArrayXf result.
152 if (Is<
153 NT::Abs,
154 NT::Acos,
155 NT::Asin,
156 NT::Atan,
157 NT::Cos,
158 NT::Cosh,
159 NT::Sin,
160 NT::Sinh,
161 NT::Tan,
162 NT::Tanh,
163 NT::Ceil,
164 NT::Floor,
165 NT::Exp,
166 NT::Log,
168 NT::Log1p,
169 NT::Sqrt,
173 NT::OffsetSum, // unary version
175 >(n))
176 {
177 node.set_signature<Signature<ArrayXf(ArrayXf)>>();
178 }
// Binary float operators: two ArrayXf arguments, ArrayXf result.
179 else if (Is<
180 NT::Add,
181 NT::Sub,
182 NT::Mul,
183 NT::Div,
184 NT::Pow,
187 >(n))
188 {
189 node.set_signature<Signature<ArrayXf(ArrayXf,ArrayXf)>>();
190 }
// Branch body (listing line 196) not visible in this extract.
191 else if (Is<
192 NT::And,
193 NT::Or
194 >(n))
195 {
197 }
// Branch body (listing line 202) not visible in this extract.
198 else if (Is<
199 NT::Not
200 >(n))
201 {
203 }
// Comparison: two ArrayXf arguments, boolean array result.
204 else if (Is<
205 NT::Geq
206 >(n))
207 {
208 node.set_signature<Signature<ArrayXb(ArrayXf,ArrayXf)>>();
209 }
// The next four branches have node-type lists and/or bodies that are
// partly or wholly missing from this extract.
210 else if (Is<
212 >(n))
213 {
215 }
216 else if (Is<
218 NT::After,
220 >(n))
221 {
223 }
224 else if (Is<
226 >(n))
227 {
228 node.set_signature<Signature<ArrayXf(TimeSeriesf)>>();
229 }
230 else if (Is<
232 >(n))
233 {
234 node.set_signature<Signature<ArrayXi(ArrayXXf)>>();
235 }
// Node types for which no default is chosen: an error message is built.
// NOTE(review): the statement that consumes `msg` (listing line 251) is not
// visible; presumably it raises the error -- confirm.
236 else if (Is<
237 NT::Min,
238 NT::Max,
239 NT::Mean,
241 NT::Sum,
242 // NT::OffsetSum,
243 NT::Prod,
245 // NT::SplitOn,
247 >(n))
248 {
249 auto msg = fmt::format("Can't infer arguments for {} from json."
250 " Please provide them.\n",n);
252 }
253 else if (Is<NT::SplitOn>(n))
254 {
255 // lets make split on always defaults to floats (so it will work
256 // regardless of datatype).
257 // This only matters for weakly defined nodes when doing manual
258 // construction of brush programs as json objects, and this behavior
259 // can be ignored by avoiding the need of generating the signature
260 // (that is, defining the node with missing hash values and ret types)
261 node.set_signature<Signature<ArrayXf(ArrayXb,ArrayXf,ArrayXf)>>();
262 }
// Constants: the last character of `feature` selects the signature.
// 'F' and any unrecognised character give ArrayXf(); the 'I' and 'B' case
// bodies (listing lines 273 and 276) are not visible here.
// NOTE(review): `back()` on an empty string is undefined behaviour -- confirm
// `feature` is always non-empty for Constant nodes.
263 else if (Is<NT::Constant>(n))
264 {
265 // "feature" starts with "const"
266 char last_char = node.feature.back();
267
268 switch (last_char) {
269 case 'F':
270 node.set_signature<Signature<ArrayXf()>>();
271 break;
272 case 'I':
274 break;
275 case 'B':
277 break;
278 default:
279 node.set_signature<Signature<ArrayXf()>>();
280 }
281 }
// MeanLabel: the statement for this branch (listing line 283) is not visible.
282 else if (Is<NT::MeanLabel>(n))
284 else if (Is<NT::Terminal>(n))
285 {
286 // For terminals, use feature_type to determine the correct signature
// ArrayF and any unlisted type give ArrayXf(); the ArrayB and ArrayI
// case bodies (listing lines 289 and 292) are not visible here.
287 switch (node.get_feature_type()) {
288 case DataType::ArrayB:
290 break;
291 case DataType::ArrayI:
293 break;
294 case DataType::ArrayF:
295 default:
296 node.set_signature<Signature<ArrayXf()>>();
297 break;
298 }
299 }
// Fallback for every other node type: no arguments, ArrayXf result.
300 else
301 node.set_signature<Signature<ArrayXf()>>();
302}
303
// Populates a Node from a JSON object.
//
// "node_type" is the only key reported as an error when missing. "name",
// "center_op", "feature" and "feature_type" are read when present. If any of
// "ret_type", "arg_types", "sig_hash" or "sig_dual_hash" is missing, the
// `make_signature` flag is raised. After p.init(), the keys "node_is_fixed",
// "weight_is_fixed", "is_weighted", "prob_change" and "W" override whatever
// init() set.
//
// NOTE(review): listing lines 320 and 359 are absent from this extract. Line
// 320 is presumably the fallback for a missing "name", and line 359 presumably
// generates the default signature when `make_signature` is true -- confirm
// against the real node.cpp.
304void from_json(const json &j, Node& p)
305{
306 // This serialization tries to build the nodes with the fewest information possible,
307 // so interface is easier when doing manual generation of trees.
308
309 // First we start with required information, then we set the optional ones
310 // (they can be inferred from the required ones)
311
312 if (j.contains("node_type"))
313 j.at("node_type").get_to(p.node_type);
314 else
315 HANDLE_ERROR_THROW("Node json must contain node_type");
316
317 if (j.contains("name"))
318 j.at("name").get_to(p.name);
// The statement belonging to this `else` (listing line 320) is not visible.
319 else
321
322 if (j.contains("center_op"))
323 j.at("center_op").get_to(p.center_op);
324
325 // used in split nodes
326 if (j.contains("feature"))
327 {
328 j.at("feature").get_to(p.feature);
329 }
330
331 if (j.contains("feature_type"))
332 {
333 j.at("feature_type").get_to(p.feature_type);
334 }
335
336 // if node has a ret_type and arg_types, get them. if not we need to make
337 // a signature
338 bool make_signature=false;
339
// Any one of the four signature-related keys being absent raises the flag.
340 if (j.contains("ret_type"))
341 j.at("ret_type").get_to(p.ret_type);
342 else
343 make_signature=true;
344 if (j.contains("arg_types"))
345 j.at("arg_types").get_to(p.arg_types);
346 else
347 make_signature=true;
348 if (j.contains("sig_hash"))
349 j.at("sig_hash").get_to(p.sig_hash);
350 else
351 make_signature=true;
352 if (j.contains("sig_dual_hash"))
353 j.at("sig_dual_hash").get_to(p.sig_dual_hash);
354 else
355 make_signature=true;
356
// is_weighted is reset here; an explicit "is_weighted" key further down
// still overrides it. The second statement of this block (listing line 359)
// is not visible.
357 if (make_signature){
358 p.is_weighted = false; // TODO: remove this line
360 }
361
362 // after this point we set attributes that are modified in init
363 p.init();
364
365 // these below needs to be set after init(), since `init` sets these values
366 if (j.contains("node_is_fixed"))
367 {
368 j.at("node_is_fixed").get_to(p.node_is_fixed);
369 }
370
371 if (j.contains("weight_is_fixed"))
372 {
373 j.at("weight_is_fixed").get_to(p.weight_is_fixed);
374 }
375
376 if (j.contains("is_weighted"))
377 {
378 j.at("is_weighted").get_to(p.is_weighted);
379 }
380
381 if (j.contains("prob_change"))
382 {
383 j.at("prob_change").get_to(p.prob_change);
384 }
385
386 if (j.contains("W"))
387 {
388 j.at("W").get_to(p.W);
389 }
390
// NOTE(review): `new_json` is never read after this line; it serializes `p`
// and discards the result. Looks like leftover debugging code -- confirm and
// remove.
391 json new_json = p;
392}
393
394
395}
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
TimeSeries< float > TimeSeriesf
Definition types.h:112
nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto Is(NodeType nt) -> bool
Definition node.h:313
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
void init_node_with_default_signature(Node &node)
Definition node.cpp:142
NodeType NT
Definition node.cpp:141
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
class holding the data for a node in a tree.
Definition node.h:89
bool center_op
whether to center the operator in pretty printing
Definition node.h:125
std::vector< DataType > arg_types
argument data types
Definition node.h:99
void set_signature()
Definition node.h:164
DataType feature_type
feature type for terminals or splitting nodes
Definition node.h:306
bool weight_is_fixed
whether the weight should be kept during variation. Notice that weight_is_fixed allows us to fix the w...
Definition node.h:114
NodeType node_type
the node type
Definition node.h:94
DataType get_feature_type() const
Definition node.h:284
DataType ret_type
return data type
Definition node.h:97
string get_feature() const
Definition node.h:281
std::size_t sig_hash
a hash of the signature
Definition node.h:101
void init()
Definition node.h:173
float prob_change
chance of node being selected for variation. This will take into account if the node is fixed,...
Definition node.h:119
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:122
bool is_weighted
whether this node is weighted (ignored in nodes that must have weights, such as meanLabel,...
Definition node.h:116
bool node_is_fixed
whether the node is replaceable. Weights are still optimized.
Definition node.h:112
string feature
feature name for terminals or splitting nodes
Definition node.h:303
string name
full name of the node, with types
Definition node.h:92
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:103
string get_model(const vector< string > &) const noexcept
Definition node.cpp:58