Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.cpp
Go to the documentation of this file.
1#include "node.h"
2
3namespace Brush {
4
5ostream& operator<<(ostream& os, const NodeType& nt)
6{
7 os << "nt: " << nt << endl;
8 os << NodeTypeName.at(nt);
9 return os;
10}
11
12ostream& operator<<(ostream& os, const Node& n)
13{
14 os << n.get_name();
15 return os;
16}
17
// Returns a string version of this node for printing.
// include_weight: when false, the "W*" prefix is omitted for the first
// (feature) branch and for the generic weighted-operator case at the end; the
// Constant branch is only taken when include_weight is true.
// NOTE(review): the leading numbers are the original file's line numbers;
// lines 23, 34 and 41 (the conditions opening three of the branches below) are
// absent from this listing.
20auto Node::get_name(bool include_weight) const noexcept -> std::string
21{
22
// NOTE(review): line 23 (condition guarding this block) is missing here.
24 {
// Weighted with a non-unit weight: print "W*feature"; otherwise just the feature.
25 if (is_weighted && W != 1.0 && include_weight)
26 return fmt::format("{:.2f}*{}",W,feature);
27 else
28 return feature;
29 }
30 else if (Is<NodeType::Constant>(node_type) && include_weight)
31 {
// A constant is printed as its value W with two decimals.
32 return fmt::format("{:.2f}", W);
33 }
// NOTE(review): line 34 (condition guarding this block) is missing here.
35 {
36 // this will show (MeanLabel) in the terminal name
37 // return fmt::format("{:.2f} ({})", W, feature);
38
39 return fmt::format("{:.2f}", W);
40 }
// NOTE(review): line 41 (condition opening this block) is missing here.
// This branch does not consult include_weight: W is shown as an offset
// ("W+Sum") whenever the node is weighted with W != 1.0.
42 if (is_weighted && W != 1.0)
43 return fmt::format("{:.2f}+Sum", W);
44
45 return fmt::format("Sum");
46 }
// Any other weighted node: "W*name".
47 else if (is_weighted && include_weight)
48 return fmt::format("{:.2f}*{}",W,name);
49
// Fallback: the plain node name.
50 return name;
51}
52
// Builds the model string for this node given the already-built strings of
// its children. With no children it returns get_name(); otherwise it formats
// an "If(...)", "Sum(...)" or "<name>(...)" expression around the children.
// NOTE(review): the leading numbers are the original file's line numbers;
// lines 57, 65 and 83 (the conditions opening three of the branches below) are
// absent from this listing.
// NOTE(review): the function is noexcept but calls children.at(i) and
// arg_types.at(0), which throw on an out-of-range index — confirm callers
// always pass the expected number of children.
53string Node::get_model(const vector<string>& children) const noexcept
54{
// Leaf case: no children, so the node's own name is the whole model.
55 if (children.empty())
56 return get_name();
// NOTE(review): line 57 (condition opening this block) is missing here.
// Two-child split form: the threshold W is compared against `feature`.
58 return fmt::format("If({}>{:.2f},{},{})",
59 feature,
60 W,
61 children.at(0),
62 children.at(1)
63 );
64 }
// NOTE(review): line 65 (condition opening this block) is missing here.
// Three-child split form: the first child is the condition operand.
66 if (arg_types.at(0) == DataType::ArrayB)
67 {
68 // booleans dont use thresholds (they are used directly as mask in split)
69 return fmt::format("If({},{},{})",
70 children.at(0),
71 children.at(1),
72 children.at(2)
73 );
74 }
75 // integers or floating points (they have a threshold)
76 return fmt::format("If({}>{:.2f},{},{})",
77 children.at(0),
78 W,
79 children.at(1),
80 children.at(2)
81 );
82 }
// NOTE(review): line 83 (condition opening this block) is missing here.
84 string args = "";
85
// A non-unit weight is emitted as the leading argument of Sum(...).
86 if (is_weighted && W != 1.0)
87 args = fmt::format("{:.2f},", W);
88
// Comma-join the children. children is non-empty here (the empty case
// returned above), so children.size()-1 cannot wrap around.
// NOTE(review): `int i` is compared against an unsigned size().
89 for (int i = 0; i < children.size(); ++i){
90 args += children.at(i);
91 if (i < children.size()-1)
92 args += ",";
93 }
94
95 return fmt::format("Sum({})", args);
96 }
97 else{
// Generic operator: "<get_name()>(child0,child1,...)".
98 string args = "";
99 for (int i = 0; i < children.size(); ++i){
100 args += children.at(i);
101 if (i < children.size()-1)
102 args += ",";
103 }
104
105 return fmt::format("{}({})", get_name(), args);
106 }
107
108}
109
111// serialization
112// serialization for Node
113// using json = nlohmann::json;
114
115void to_json(json& j, const Node& p)
116{
117 j = json{
118 {"name", p.name},
119 {"center_op", p.center_op},
120 {"fixed", p.fixed},
121 {"prob_change", p.prob_change},
122 {"is_weighted", p.is_weighted},
123 {"W", p.W},
124 {"node_type", p.node_type},
125 {"sig_hash", p.sig_hash},
126 {"sig_dual_hash", p.sig_dual_hash},
127 {"ret_type", p.ret_type},
128 {"arg_types", p.arg_types},
129 {"feature", p.get_feature()}
130 // {"node_hash", p.get_node_hash()}
131 };
132}
133
// Short alias for NodeType, used in the type lists below.
134using NT = NodeType;
// NOTE(review): the leading numbers are the original file's line numbers.
// Line 135 — the declaration of this function — is absent from this listing;
// the page index gives it as
//   void init_node_with_default_signature(Node &node)
// Lines 159-160, 163-165, 167, 178-179, 189, 201, 205, 210 and 213 are absent
// too, so some type lists and branch bodies below are incomplete as shown.
//
// Assigns a default signature to `node` based on node.node_type:
//   - the first type list            -> ArrayXf(ArrayXf)
//   - Add/Sub/Mul/Div/Pow (+ absent) -> ArrayXf(ArrayXf,ArrayXf)
//   - And/Or                         -> body not visible in this listing
//   - Min/Max/Mean/Sum/Prod (+ absent) -> builds an error message
//   - a type list not visible here   -> ArrayXf(ArrayXb,ArrayXf,ArrayXf)
//   - anything else                  -> ArrayXf()
136{
137 // if (Is<
138 // NT::Add,
139 // NT::Mul,
140 // NT::Min,
141 // NT::Max
142 // >(nt))
143 // return Signature<ArrayXf(ArrayXf,ArrayXf)>{};
144 NT n = node.node_type;
// One float array in, one float array out.
145 if (Is<
146 NT::Abs,
147 NT::Acos,
148 NT::Asin,
149 NT::Atan,
150 NT::Cos,
151 NT::Cosh,
152 NT::Sin,
153 NT::Sinh,
154 NT::Tan,
155 NT::Tanh,
156 NT::Ceil,
157 NT::Floor,
158 NT::Exp,
159 NT::Log,
161 NT::Log1p,
162 NT::Sqrt,
166 NT::OffsetSum, // unary version
168 >(n))
169 {
170 node.set_signature<Signature<ArrayXf(ArrayXf)>>();
171 }
// Two float arrays in, one float array out.
172 else if (Is<
173 NT::Add,
174 NT::Sub,
175 NT::Mul,
176 NT::Div,
177 NT::Pow,
180 >(n))
181 {
182 node.set_signature<Signature<ArrayXf(ArrayXf,ArrayXf)>>();
183 }
184 else if (Is<
185 NT::And,
186 NT::Or
187 >(n))
188 {
// NOTE(review): line 189 (the body of this branch) is missing here.
190 }
191 // else if (Is<
192 // NT::Not
193 // >(n))
194 // {
195 // node.set_signature<Signature<ArrayXb(ArrayXb)>>();
196 // }
197 else if (Is<
198 NT::Min,
199 NT::Max,
200 NT::Mean,
202 NT::Sum,
203 // NT::OffsetSum, // n-ary version
204 NT::Prod,
206 >(n))
207 {
// For these types an error message is built instead of a signature being set.
// NOTE(review): line 210, which presumably reports `msg`, is missing here —
// confirm against the original file.
208 auto msg = fmt::format("Can't infer arguments for {} from json."
209 " Please provide them.\n",n);
211 }
212 else if (Is<
// NOTE(review): line 213 (the node type(s) tested here) is missing.
214 >(n))
215 {
// Boolean array first, then two float arrays.
216 node.set_signature<Signature<ArrayXf(ArrayXb,ArrayXf,ArrayXf)>>();
217 }
218 else{
// Fallback: a signature with no arguments.
219 node.set_signature<Signature<ArrayXf()>>();
220 }
221
222}
223
// Populates Node `p` from json `j`.
// "node_type" is required (HANDLE_ERROR_THROW is invoked when it is absent);
// the other keys are read only when present. If any of "ret_type",
// "arg_types", "sig_hash" or "sig_dual_hash" is absent, make_signature is set.
// After p.init(), the keys "fixed", "is_weighted", "prob_change" and "W" are
// applied on top.
// NOTE(review): the leading numbers are the original file's line numbers;
// lines 234 and 272 are absent from this listing (see the notes below).
224void from_json(const json &j, Node& p)
225{
226 if (j.contains("node_type"))
227 j.at("node_type").get_to(p.node_type);
228 else
229 HANDLE_ERROR_THROW("Node json must contain node_type");
230
231 if (j.contains("name"))
232 j.at("name").get_to(p.name);
233 else
// NOTE(review): line 234 (the body of this `else`) is missing here; as shown,
// the `else` would bind to the following `if` — confirm against the original.
235
236 if (j.contains("center_op"))
237 j.at("center_op").get_to(p.center_op);
238
// feature and feature_type go through the node's setters rather than get_to.
239 if (j.contains("feature"))
240 {
241 // j.at("feature").get_to(p.feature);
242 p.set_feature(j.at("feature"));
243 }
244 if (j.contains("feature_type"))
245 {
246 p.set_feature_type(j.at("feature_type"));
247 }
248
249 // if node has a ret_type and arg_types, get them. if not we need to make
250 // a signature
251 bool make_signature=false;
252
253 if (j.contains("ret_type"))
254 j.at("ret_type").get_to(p.ret_type);
255 else
256 make_signature=true;
257 if (j.contains("arg_types"))
258 j.at("arg_types").get_to(p.arg_types);
259 else
260 make_signature=true;
261 if (j.contains("sig_hash"))
262 j.at("sig_hash").get_to(p.sig_hash);
263 else
264 make_signature=true;
265 if (j.contains("sig_dual_hash"))
266 j.at("sig_dual_hash").get_to(p.sig_dual_hash);
267 else
268 make_signature=true;
269
270 if (make_signature){
271 p.is_weighted = false;
// NOTE(review): line 272 is missing here; presumably it builds the default
// signature for p — confirm against the original file.
273 }
274 p.init();
275
276 // these 4 below needs to be set after init(), since it resets these values
277 if (j.contains("fixed"))
278 {
279 j.at("fixed").get_to(p.fixed);
280 }
281
282 if (j.contains("is_weighted"))
283 {
284 j.at("is_weighted").get_to(p.is_weighted);
285 }
286
287 if (j.contains("prob_change"))
288 {
289 j.at("prob_change").get_to(p.prob_change);
290 }
291
292 if (j.contains("W"))
293 {
294 j.at("W").get_to(p.W);
295 }
296
// NOTE(review): new_json is not read anywhere in the visible code; this
// converts p back to json and discards the result — looks like a leftover.
297 json new_json = p;
298}
299
300
301}
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:25
auto Is(NodeType nt) -> bool
Definition node.h:272
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
void init_node_with_default_signature(Node &node)
Definition node.cpp:135
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
class holding the data for a node in a tree.
Definition node.h:84
bool center_op
whether to center the operator in pretty printing
Definition node.h:89
std::vector< DataType > arg_types
argument data types
Definition node.h:103
void set_signature()
Definition node.h:145
bool fixed
whether node is modifiable
Definition node.h:93
NodeType node_type
the node type
Definition node.h:95
DataType ret_type
return data type
Definition node.h:101
void set_feature(string f)
Definition node.h:249
string get_feature() const
Definition node.h:250
std::size_t sig_hash
a hash of the signature
Definition node.h:97
void init()
Definition node.h:154
void set_feature_type(DataType ft)
Definition node.h:252
float prob_change
chance of node being selected for variation
Definition node.h:91
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
The weight of the node; also used as the threshold for splitting nodes.
Definition node.h:107
bool is_weighted
whether this node is weighted
Definition node.h:105
string feature
feature name for terminals or splitting nodes
Definition node.h:264
string name
full name of the node, with types
Definition node.h:87
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:99
string get_model(const vector< string > &) const noexcept
Definition node.cpp:53