Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.cpp
Go to the documentation of this file.
1#include "node.h"
2
3namespace Brush {
4
5ostream& operator<<(ostream& os, const NodeType& nt)
6{
7 os << "nt: " << nt << endl;
8 os << NodeTypeName.at(nt);
9 return os;
10}
11
12ostream& operator<<(ostream& os, const Node& n)
13{
14 os << n.get_name();
15 return os;
16}
17
// get_name: returns a printable string for this node.
// include_weight controls whether the weight W may be folded into the text.
// Several branch headers of this function (original lines 22, 33 and 40)
// are missing from this listing; the notes below hedge on what they test.
20auto Node::get_name(bool include_weight) const noexcept -> std::string
21{
23 {
 // Branch whose condition (line 22) is not visible here. It prints
 // `feature`, so it presumably covers terminal nodes -- confirm.
 // Yields "W*feature" (W with 2 decimals) only when the node is weighted,
 // W != 1 and include_weight is set; otherwise the bare feature name.
24 if (is_weighted && W != 1.0 && include_weight)
25 return fmt::format("{:.2f}*{}",W,feature);
26 else
27 return feature;
28 }
 // Constant nodes, when include_weight is set: print only the weight value.
29 else if (Is<NodeType::Constant>(node_type) && include_weight)
30 {
31 return fmt::format("{:.2f}", W);
32 }
34 {
 // Condition (line 33) not visible here; the comment below suggests this
 // is the MeanLabel case -- confirm. Prints only W with 2 decimals.
35 // this will show (MeanLabel) in the terminal name
36 // return fmt::format("{:.2f} ({})", W, feature);
37
38 return fmt::format("{:.2f}", W);
39 }
 // Condition (line 40) not visible here; the "Sum" text suggests the
 // OffsetSum node -- confirm. Prints "W+Sum" when weighted with W != 1,
 // otherwise "Sum"; the visible body does not consult include_weight.
41 if (is_weighted && W != 1.0)
42 return fmt::format("{:.2f}+Sum", W);
43
44 return fmt::format("Sum");
45 }
 // Any other weighted node: "W*name" when include_weight is set. Unlike the
 // feature branch above, this prints the weight even when W == 1.
46 else if (is_weighted && include_weight)
47 return fmt::format("{:.2f}*{}",W,name);
48
 // Fallback: the node's full name.
49 return name;
50}
51
// get_model: renders this node as a model string, given the already-rendered
// strings of its children (in order). With no children it falls back to
// get_name(). Some branch headers (original lines 56, 64 and 82) are missing
// from this listing; the notes below hedge on what they test.
// NOTE(review): this function is noexcept, so an std::out_of_range thrown by
// children.at()/arg_types.at() (e.g. a split node with too few children)
// calls std::terminate instead of propagating.
52string Node::get_model(const vector<string>& children) const noexcept
53{
54 if (children.empty())
55 return get_name();
 // Branch header (line 56) not visible. Renders "If(feature>=W,c0,c1)":
 // a threshold W on a named feature with two children -- presumably a
 // split node that stores its own feature; confirm.
57 return fmt::format("If({}>={:.2f},{},{})",
58 feature,
59 W,
60 children.at(0),
61 children.at(1)
62 );
63 }
 // Branch header (line 64) not visible. Uses three children, the first
 // being the condition whose type is arg_types.at(0) -- presumably SplitOn;
 // confirm.
65 if (arg_types.at(0) == DataType::ArrayB)
66 {
67 // booleans dont use thresholds (they are used directly as mask in split)
68 return fmt::format("If({},{},{})",
69 children.at(0),
70 children.at(1),
71 children.at(2)
72 );
73 }
74 // integers or floating points (they have a threshold)
75 return fmt::format("If({}>={:.2f},{},{})",
76 children.at(0),
77 W,
78 children.at(1),
79 children.at(2)
80 );
81 }
 // Branch header (line 82) not visible. Renders "Sum(W,c0,c1,...)", with the
 // weight as a leading argument only when weighted and W != 1 -- presumably
 // the OffsetSum node; confirm.
83 string args = "";
84
85 if (is_weighted && W != 1.0)
86 args = fmt::format("{:.2f},", W);
87
 // Comma-join the children.
 // NOTE(review): `int i` is compared with children.size() (size_t), a
 // signed/unsigned comparison; a size_t index would be cleaner.
88 for (int i = 0; i < children.size(); ++i){
89 args += children.at(i);
90 if (i < children.size()-1)
91 args += ",";
92 }
93
94 return fmt::format("Sum({})", args);
95 }
96 else{
 // Generic operator: "<get_name()>(c0,c1,...)". get_name() is called with
 // include_weight=true, so weighted nodes may carry a weight prefix.
97 string args = "";
98 for (int i = 0; i < children.size(); ++i){
99 args += children.at(i);
100 if (i < children.size()-1)
101 args += ",";
102 }
103
104 return fmt::format("{}({})", get_name(), args);
105 }
106
107}
108
110// serialization
111// serialization for Node
112// using json = nlohmann::json;
113
114void to_json(json& j, const Node& p)
115{
116 j = json{
117 {"name", p.name},
118 {"center_op", p.center_op},
119 {"fixed", p.fixed},
120 {"prob_change", p.prob_change},
121 {"is_weighted", p.is_weighted},
122 {"W", p.W},
123 {"node_type", p.node_type},
124 {"sig_hash", p.sig_hash},
125 {"sig_dual_hash", p.sig_dual_hash},
126 {"ret_type", p.ret_type},
127 {"arg_types", p.arg_types},
128 {"feature", p.get_feature()},
129 {"feature_type", p.get_feature_type()}
130 // {"node_hash", p.get_node_hash()}
131 };
132}
133
// Short alias for NodeType, used by the signature-inference code that follows.
134using NT = NodeType;
// init_node_with_default_signature: assigns `node` a default signature based
// only on node.node_type; from_json needs one when the json lacks explicit
// type information. Its header line (135) is missing from this listing; the
// cross-reference gives it as `void init_node_with_default_signature(Node &node)`.
// Several list entries and branch bodies are also missing (see notes below).
136{
137 // if (Is<
138 // NT::Add,
139 // NT::Mul,
140 // NT::Min,
141 // NT::Max
142 // >(nt))
143 // return Signature<ArrayXf(ArrayXf,ArrayXf)>{};
144 NT n = node.node_type;
145 if (Is<
146 NT::Abs,
147 NT::Acos,
148 NT::Asin,
149 NT::Atan,
150 NT::Cos,
151 NT::Cosh,
152 NT::Sin,
153 NT::Sinh,
154 NT::Tan,
155 NT::Tanh,
156 NT::Ceil,
157 NT::Floor,
158 NT::Exp,
159 NT::Log,
161 NT::Log1p,
162 NT::Sqrt,
166 NT::OffsetSum, // unary version
168 >(n))
169 {
 // Unary float operators (entries 160, 163-165, 167 of this list are not
 // visible here): one float array in, one float array out.
170 node.set_signature<Signature<ArrayXf(ArrayXf)>>();
171 }
172 else if (Is<
173 NT::Add,
174 NT::Sub,
175 NT::Mul,
176 NT::Div,
177 NT::Pow,
180 >(n))
181 {
 // Binary float operators (entries 178-179 not visible): two float
 // arrays in, one out.
182 node.set_signature<Signature<ArrayXf(ArrayXf,ArrayXf)>>();
183 }
184 else if (Is<
185 NT::And,
186 NT::Or
187 >(n))
188 {
 // Body (line 189) not visible here; presumably a boolean binary
 // signature -- confirm.
190 }
191 else if (Is<
192 NT::Not
193 >(n))
194 {
 // Body (line 196) not visible here; presumably a boolean unary
 // signature -- confirm.
196 }
197 else if (Is<
198 NT::Min,
199 NT::Max,
200 NT::Mean,
202 NT::Sum,
203 // NT::OffsetSum,
204 NT::Prod,
208 >(n))
209 {
 // Operators listed here have no single default arity, so an error
 // message is built. The statement that uses `msg` (line 212) is not
 // visible; presumably HANDLE_ERROR_THROW(msg) -- confirm.
210 auto msg = fmt::format("Can't infer arguments for {} from json."
211 " Please provide them.\n",n);
213 }
214 else if (Is<NT::SplitOn>(n))
215 {
 // Boolean condition plus two float branches.
216 node.set_signature<Signature<ArrayXf(ArrayXb,ArrayXf,ArrayXf)>>();
217 }
218 else if (Is<NT::Constant>(n))
219 {
 // The constant's return type is chosen from the last character of
 // node.feature: 'F' and any unrecognized character give a float array;
 // the 'I' and 'B' bodies (lines 228, 231) are not visible here.
 // NOTE(review): std::string::back() is undefined behaviour on an empty
 // string; if `feature` can be empty here (e.g. a json without a
 // "feature" key), this needs a guard.
220 // "feature" starts with "const"
221 char last_char = node.feature.back();
222
223 switch (last_char) {
224 case 'F':
225 node.set_signature<Signature<ArrayXf()>>();
226 break;
227 case 'I':
229 break;
230 case 'B':
232 break;
233 default:
234 node.set_signature<Signature<ArrayXf()>>();
235 }
236 }
 // MeanLabel: its statement (line 238) is not visible in this listing.
237 else if (Is<NT::MeanLabel>(n))
239 else
 // Everything else defaults to a float-array terminal signature.
240 node.set_signature<Signature<ArrayXf()>>();
241}
242
// from_json: builds a Node from json. Only "node_type" is required
// (HANDLE_ERROR_THROW otherwise); every other key is optional. If any of
// ret_type / arg_types / sig_hash / sig_dual_hash is absent, a default
// signature is requested. p.init() is then called, and fixed / is_weighted /
// prob_change / W are applied afterwards so the json values override init().
243void from_json(const json &j, Node& p)
244{
245 // This serialization tries to build the nodes with the fewest information possible,
246 // so interface is easier when doing manual generation of trees.
247
248 // First we start with required information, then we set the optional ones
249 // (they can be inferred from the required ones)
250
251 if (j.contains("node_type"))
252 j.at("node_type").get_to(p.node_type);
253 else
254 HANDLE_ERROR_THROW("Node json must contain node_type");
255
 // "name" is optional. The fallback statement of the else branch (line 259)
 // is not visible in this listing; presumably a default derived from
 // node_type -- confirm.
256 if (j.contains("name"))
257 j.at("name").get_to(p.name);
258 else
260
261 if (j.contains("center_op"))
262 j.at("center_op").get_to(p.center_op);
263
264 // used in split nodes
265 if (j.contains("feature"))
266 {
267 j.at("feature").get_to(p.feature);
268 }
269
270 if (j.contains("feature_type"))
271 {
272 j.at("feature_type").get_to(p.feature_type);
273 }
274
275 // if node has a ret_type and arg_types, get them. if not we need to make
276 // a signature
277 bool make_signature=false;
278
279 if (j.contains("ret_type"))
280 j.at("ret_type").get_to(p.ret_type);
281 else
282 make_signature=true;
283 if (j.contains("arg_types"))
284 j.at("arg_types").get_to(p.arg_types);
285 else
286 make_signature=true;
287 if (j.contains("sig_hash"))
288 j.at("sig_hash").get_to(p.sig_hash);
289 else
290 make_signature=true;
291 if (j.contains("sig_dual_hash"))
292 j.at("sig_dual_hash").get_to(p.sig_dual_hash);
293 else
294 make_signature=true;
295
 // The call that builds the default signature (line 298) is not visible
 // here; presumably init_node_with_default_signature(p) -- confirm.
296 if (make_signature){
297 p.is_weighted = false; // TODO: remove this line
299 }
300
301 // after this point we set attributes that are modified in init
302 p.init();
303
304 // these 4 below needs to be set after init(), since it resets these values
305 if (j.contains("fixed"))
306 {
307 j.at("fixed").get_to(p.fixed);
308 }
309
310 if (j.contains("is_weighted"))
311 {
312 j.at("is_weighted").get_to(p.is_weighted);
313 }
314
315 if (j.contains("prob_change"))
316 {
317 j.at("prob_change").get_to(p.prob_change);
318 }
319
320 if (j.contains("W"))
321 {
322 j.at("W").get_to(p.W);
323 }
324
 // NOTE(review): `new_json` is never used; this line only runs to_json(p)
 // and discards the result. It looks like a leftover debugging aid --
 // consider removing it.
325 json new_json = p;
326}
327
328
329}
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto Is(NodeType nt) -> bool
Definition node.h:291
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
void init_node_with_default_signature(Node &node)
Definition node.cpp:135
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
class holding the data for a node in a tree.
Definition node.h:84
bool center_op
whether to center the operator in pretty printing
Definition node.h:110
std::vector< DataType > arg_types
argument data types
Definition node.h:94
void set_signature()
Definition node.h:148
DataType feature_type
feature type for terminals or splitting nodes
Definition node.h:284
bool fixed
whether the node is replaceable. Weights are still optimized.
Definition node.h:101
NodeType node_type
the node type
Definition node.h:89
DataType get_feature_type() const
Definition node.h:262
DataType ret_type
return data type
Definition node.h:92
string get_feature() const
Definition node.h:259
std::size_t sig_hash
a hash of the signature
Definition node.h:96
void init()
Definition node.h:157
float prob_change
chance of node being selected for variation
Definition node.h:105
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:107
bool is_weighted
whether this node is weighted (ignored in nodes that must have weights, such as meanLabel,...
Definition node.h:103
string feature
feature name for terminals or splitting nodes
Definition node.h:281
string name
full name of the node, with types
Definition node.h:87
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:98
string get_model(const vector< string > &) const noexcept
Definition node.cpp:52