Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.cpp
Go to the documentation of this file.
1#include "node.h"
2
3namespace Brush {
4
5ostream& operator<<(ostream& os, const NodeType& nt)
6{
7 os << "nt: " << nt << endl;
8 os << NodeTypeName.at(nt);
9 return os;
10}
11
12ostream& operator<<(ostream& os, const Node& n)
13{
14 os << n.get_name();
15 return os;
16}
17
20auto Node::get_name(bool include_weight) const noexcept -> std::string
21{
22
23 if (Is<NodeType::Terminal>(node_type))
24 {
25 if (is_weighted && W != 1.0 && include_weight)
26 return fmt::format("{:.2f}*{}",W,feature);
27 else
28 return feature;
29 }
30 else if (Is<NodeType::Constant>(node_type) && include_weight)
31 {
32 return fmt::format("{:.2f}", W);
33 }
34 else if (Is<NodeType::MeanLabel>(node_type))
35 {
36 if (include_weight)
37 return fmt::format("{:.2f}*{}", W, feature);
38
39 return feature;
40 }
41 else if (Is<NodeType::OffsetSum>(node_type)){
42 return fmt::format("{}+Sum", W);
43 }
44 else if (is_weighted && include_weight)
45 return fmt::format("{:.2f}*{}",W,name);
46
47 return name;
48}
49
50string Node::get_model(const vector<string>& children) const noexcept
51{
52 if (children.empty())
53 return get_name();
54 else if (Is<NodeType::SplitBest>(node_type)){
55 return fmt::format("If({}>{:.2f},{},{})",
56 feature,
57 W,
58 children.at(0),
59 children.at(1)
60 );
61 }
62 else if (Is<NodeType::SplitOn>(node_type)){
63 if (arg_types.at(0) == DataType::ArrayB)
64 {
65 // booleans dont use thresholds (they are used directly as mask in split)
66 return fmt::format("If({},{},{})",
67 children.at(0),
68 children.at(1),
69 children.at(2)
70 );
71 }
72 // integers or floating points (they have a threshold)
73 return fmt::format("If({}>{:.2f},{},{})",
74 children.at(0),
75 W,
76 children.at(1),
77 children.at(2)
78 );
79 }
80 else if (Is<NodeType::OffsetSum>(node_type)){
81 // weight is part of the model
82 string args = fmt::format("{},", W);
83
84 for (int i = 0; i < children.size(); ++i){
85 args += children.at(i);
86 if (i < children.size()-1)
87 args += ",";
88 }
89
90 return fmt::format("Sum({})", args);
91 }
92 else{
93 string args = "";
94 for (int i = 0; i < children.size(); ++i){
95 args += children.at(i);
96 if (i < children.size()-1)
97 args += ",";
98 }
99
100 return fmt::format("{}({})", get_name(), args);
101 }
102
103}
104
106// serialization
107// serialization for Node
108// using json = nlohmann::json;
109
110void to_json(json& j, const Node& p)
111{
112 j = json{
113 {"name", p.name},
114 {"center_op", p.center_op},
115 {"prob_change", p.prob_change},
116 {"fixed", p.fixed},
117 {"node_type", p.node_type},
118 {"sig_hash", p.sig_hash},
119 {"sig_dual_hash", p.sig_dual_hash},
120 {"ret_type", p.ret_type},
121 {"arg_types", p.arg_types},
122 {"is_weighted", p.is_weighted},
123 {"W", p.W},
124 {"feature", p.get_feature()}
125 // {"node_hash", p.get_node_hash()}
126 };
127}
128
// Assign a default signature to `node`, chosen by its node_type.
// NOTE(review): this rendered listing is missing the function header
// (listing line 130; the Doxygen index gives it as
// `void init_node_with_default_signature(Node &node)`) and lines 155,
// 158-160, 162, 173-174, 196, 200, 205 and 208. Some node types and
// statements below are therefore not visible here.
129using NT = NodeType;
131{
132 // if (Is<
133 // NT::Add,
134 // NT::Mul,
135 // NT::Min,
136 // NT::Max
137 // >(nt))
138 // return Signature<ArrayXf(ArrayXf,ArrayXf)>{};
139 NT n = node.node_type;
// Unary float operators -> ArrayXf(ArrayXf). OffsetSum gets its unary form.
140 if (Is<
141 NT::Abs,
142 NT::Acos,
143 NT::Asin,
144 NT::Atan,
145 NT::Cos,
146 NT::Cosh,
147 NT::Sin,
148 NT::Sinh,
149 NT::Tan,
150 NT::Tanh,
151 NT::Ceil,
152 NT::Floor,
153 NT::Exp,
154 NT::Log,
156 NT::Log1p,
157 NT::Sqrt,
161 NT::OffsetSum, // unary version
163 >(n))
164 {
165 node.set_signature<Signature<ArrayXf(ArrayXf)>>();
166 }
// Binary float operators -> ArrayXf(ArrayXf,ArrayXf).
167 else if (Is<
168 NT::Add,
169 NT::Sub,
170 NT::Mul,
171 NT::Div,
172 NT::Pow,
175 >(n))
176 {
177 node.set_signature<Signature<ArrayXf(ArrayXf,ArrayXf)>>();
178 }
// Binary boolean operators -> ArrayXb(ArrayXb,ArrayXb).
179 else if (Is<
180 NT::And,
181 NT::Or
182 >(n))
183 {
184 node.set_signature<Signature<ArrayXb(ArrayXb,ArrayXb)>>();
185 }
186 // else if (Is<
187 // NT::Not
188 // >(n))
189 // {
190 // node.set_signature<Signature<ArrayXb(ArrayXb)>>();
191 // }
// Node types whose arguments cannot be inferred without explicit signature
// data in the json (per the message built below); no default is set here.
192 else if (Is<
193 NT::Min,
194 NT::Max,
195 NT::Mean,
197 NT::Sum,
198 // NT::OffsetSum, // n-ary version
199 NT::Prod,
201 >(n))
202 {
203 auto msg = fmt::format("Can't infer arguments for {} from json."
204 " Please provide them.\n",n);
// NOTE(review): in the visible lines `msg` is built but never used. The
// missing listing line 205 presumably raises it (e.g. HANDLE_ERROR_THROW(msg))
// -- confirm; otherwise these node types are left without a signature.
206 }
// Three float arguments -> ArrayXf(ArrayXf,ArrayXf,ArrayXf). The node
// type(s) tested here sit on the missing listing line 208.
207 else if (Is<
209 >(n))
210 {
211 node.set_signature<Signature<ArrayXf(ArrayXf,ArrayXf,ArrayXf)>>();
212 }
// Fallback for every other node type (presumably terminals/constants):
// no arguments, returns ArrayXf.
213 else{
214 node.set_signature<Signature<ArrayXf()>>();
215 }
216
217}
218
// Deserialize a Node from json.
// `node_type` is required (HANDLE_ERROR_THROW is invoked when it is absent).
// `name` falls back to NodeTypeName[node_type], and `is_weighted` defaults to
// false. The remaining fields are read only when present. If any of ret_type,
// arg_types, sig_hash or sig_dual_hash is absent, make_signature is set.
219void from_json(const json &j, Node& p)
220{
221
222 if (j.contains("node_type"))
223 j.at("node_type").get_to(p.node_type);
224 else
225 HANDLE_ERROR_THROW("Node json must contain node_type");
226
227 if (j.contains("name"))
228 j.at("name").get_to(p.name);
229 else
// NOTE(review): operator[] on the global NodeTypeName map inserts an empty
// entry for an unknown node_type; .at() would throw instead.
230 p.name = NodeTypeName[p.node_type];
231
232 if (j.contains("center_op"))
233 j.at("center_op").get_to(p.center_op);
234
235 if (j.contains("fixed"))
236 j.at("fixed").get_to(p.fixed);
237
238 if (j.contains("feature"))
239 {
240 // j.at("feature").get_to(p.feature);
241 p.set_feature(j.at("feature"));
242 }
243 if (j.contains("is_weighted"))
244 j.at("is_weighted").get_to(p.is_weighted);
245 else
246 p.is_weighted=false;
247
248 if (j.contains("prob_change"))
249 j.at("prob_change").get_to(p.prob_change);
250
251
252 // if node has a ret_type and arg_types, get them. if not we need to make
253 // a signature
// (sig_hash and sig_dual_hash are required as well: any missing one of the
// four sets make_signature.)
254 bool make_signature=false;
255
256 if (j.contains("ret_type"))
257 j.at("ret_type").get_to(p.ret_type);
258 else
259 make_signature=true;
260 if (j.contains("arg_types"))
261 j.at("arg_types").get_to(p.arg_types);
262 else
263 make_signature=true;
264 if (j.contains("sig_hash"))
265 j.at("sig_hash").get_to(p.sig_hash);
266 else
267 make_signature=true;
268 if (j.contains("sig_dual_hash"))
269 j.at("sig_dual_hash").get_to(p.sig_dual_hash);
270 else
271 make_signature=true;
272
273 if (make_signature){
// NOTE(review): the body of this branch (listing line 274) is missing from
// this rendered listing; presumably it calls
// init_node_with_default_signature(p) -- confirm.
275 }
// NOTE(review): W is read after init(), presumably so that a serialized
// weight overrides any value init() assigns -- confirm against Node::init().
276 p.init();
277
278 if (j.contains("W"))
279 j.at("W").get_to(p.W);
280
281
// NOTE(review): `new_json` is never used; this serializes the node back to
// json for no visible effect and looks like leftover debugging code.
282 json new_json = p;
283}
284
285
286}
void bind_engine(py::module &m, string name)
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
NSGA-II selection operator for getting the front.
Definition data.cpp:12
ostream & operator<<(ostream &os, DataType n)
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:24
auto Is(NodeType nt) -> bool
Definition node.h:260
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
void init_node_with_default_signature(Node &node)
Definition node.cpp:130
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
Class holding the data for a node in a tree.
Definition node.h:84
string get_name(bool include_weight=true) const noexcept
Gets a string version of the node for printing.
Definition node.cpp:20
string get_model(const vector< string > &) const noexcept
Definition node.cpp:50