Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
tree_node.cpp
Go to the documentation of this file.
#include "tree_node.h"

#include <stdexcept>
#include <utility>
2
3
4string TreeNode::get_model(bool pretty) const
5{
6 if (data.get_arg_count()==0)
7 return data.get_name();
8
9 vector<string> child_outputs;
10 auto sib = first_child;
11 for(int i = 0; i < data.get_arg_count(); ++i)
12 {
13 child_outputs.push_back(sib->get_model(pretty));
14 sib = sib->next_sibling;
15 }
16 return data.get_model(child_outputs);
17};
18
19
20string TreeNode::get_tree_model(bool pretty, string offset) const
21{
22 // TODO: pretty is not being used. either drop it or improve the function.
23 // offset seems to be for internal usage. TODO: check that, and write it in the docs
24
25 if (data.get_arg_count()==0)
26 return data.get_name();
27
28 string new_offset = " ";
29 string child_outputs = "\n";
30
31 auto sib = first_child;
32 for(int i = 0; i < data.get_arg_count(); ++i)
33 {
34 child_outputs += offset + "|- ";
35 string s = sib->get_tree_model(pretty, offset+new_offset);
36 sib = sib->next_sibling;
37 // if (sib == nullptr)
38 ReplaceStringInPlace(s, "\n"+offset, "\n"+offset+"|") ;
39 child_outputs += s;
40 if (sib != nullptr)
41 child_outputs += "\n";
42 }
43 if (Is<NodeType::SplitBest>(data.node_type)){
44 if (data.get_feature_type() != DataType::ArrayB) {
45 return fmt::format("If({}>={:.2f})", data.get_feature(), data.W) + child_outputs;
46 }
47 else {
48 return fmt::format("If({})", data.get_feature()) + child_outputs;
49 }
50 }
51 else if (Is<NodeType::SplitOn>(data.node_type)){
52 if (data.arg_types.at(0) == DataType::ArrayB)
53 {
54 // booleans dont use thresholds (they are used directly as mask in split)
55 return "If" + child_outputs;
56 }
57 else {
58 // integers or floating points (they have a threshold)
59 return fmt::format("If(>={:.2f})", data.W) + child_outputs;
60 }
61 }
62 else{
63 return data.get_name() + child_outputs;
64 }
65};
67// serialization for tree
68void to_json(json &j, const tree<Node> &t)
69{
70 j.clear();
71 // for (auto iter = t.begin(); iter!=t.end(); ++iter)
72 for (const auto &el : t)
73 {
74 j.push_back(el);
75 }
76}
77
81void from_json(const json &j, tree<Node> &t)
82{
83 vector<tree<Node>> stack;
84 for (int i = j.size(); i --> 0; )
85 {
86 auto node = j.at(i).get<Node>();
87 tree<Node> subtree;
88 auto root = subtree.insert(subtree.begin(), node);
89 for (auto at : node.arg_types)
90 {
91 auto spot = subtree.append_child(root);
92 auto arg = stack.back();
93 subtree.move_ontop(spot, arg.begin());
94 stack.pop_back();
95 }
96 stack.push_back(subtree);
97 }
98 t = stack.back();
99}
100
101unordered_map<NodeType, int> operator_complexities = {
102 // Unary
103 {NodeType::Abs , 4},
104 {NodeType::Acos , 6},
105 {NodeType::Asin , 6},
106 {NodeType::Atan , 6},
107 {NodeType::Cos , 6},
108 {NodeType::Cosh , 6},
109 {NodeType::Sin , 6},
110 {NodeType::Sinh , 6},
111 {NodeType::Tan , 6},
112 {NodeType::Tanh , 6},
113 {NodeType::Ceil , 5},
114 {NodeType::Floor , 5},
115 {NodeType::Exp , 5},
116 {NodeType::Log , 5},
117 {NodeType::Logabs , 10},
118 {NodeType::Log1p , 9},
119 {NodeType::Sqrt , 5},
120 {NodeType::Sqrtabs , 5},
121 {NodeType::Square , 4},
122 {NodeType::Logistic, 4},
123 {NodeType::OffsetSum, 3},
124
125 // timing masks
126 {NodeType::Before, 4},
127 {NodeType::After , 4},
128 {NodeType::During, 4},
129
130 // Reducers
131 {NodeType::Min , 4},
132 {NodeType::Max , 4},
133 {NodeType::Mean , 4},
134 {NodeType::Median , 4},
135 {NodeType::Sum , 4},
136 {NodeType::Prod , 4},
137
138 // Transformers
139 {NodeType::Softmax, 5},
140
141 // Binary
142 {NodeType::Add, 3},
143 {NodeType::Sub, 3},
144 {NodeType::Mul, 4},
145 {NodeType::Div, 5},
146 {NodeType::Pow, 5},
147
148 //split
149 {NodeType::SplitBest, 4},
150 {NodeType::SplitOn , 4},
151
152 // boolean
153 {NodeType::And, 3},
154 {NodeType::Or , 3},
155 {NodeType::Not, 3},
156
157 // leaves
158 {NodeType::MeanLabel, 1},
159 {NodeType::Constant , 2},
160 {NodeType::Terminal , 3},
161 {NodeType::ArgMax , 5},
162 {NodeType::Count , 4},
163
164 // custom
165 {NodeType::CustomUnaryOp , 5},
166 {NodeType::CustomBinaryOp, 5},
167 {NodeType::CustomSplit , 5}
168};
169
170int TreeNode::get_linear_complexity() const
171{
172 int tree_complexity = operator_complexities.at(data.node_type);
173
174 auto child = first_child;
175 for(int i = 0; i < data.get_arg_count(); ++i)
176 {
177 tree_complexity += child->get_linear_complexity();
178 child = child->next_sibling;
179 }
180
181 // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
182 if (data.get_is_weighted()
184 {
185 // ignoring weight if it has the value of neutral element of operation
186 if ((Is<NodeType::OffsetSum>(data.node_type) && data.W != 0.0)
187 || (data.W != 1.0))
188 return operator_complexities.at(NodeType::Mul) +
189 operator_complexities.at(NodeType::Constant) +
190 tree_complexity;
191 }
192
193 return tree_complexity;
194};
195
196int TreeNode::get_complexity() const
197{
198 int node_complexity = operator_complexities.at(data.node_type);
199 int children_complexity_sum = 0; // acumulator for children complexities
200
201 auto child = first_child;
202 for(int i = 0; i < data.get_arg_count(); ++i)
203 {
204 children_complexity_sum += child->get_complexity();
205 child = child->next_sibling;
206 }
207
208 // avoid multiplication by zero if the node is a terminal
209 children_complexity_sum = max(children_complexity_sum, 1);
210
211 // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
212 if (data.get_is_weighted()
214 {
215 // ignoring weight if it has the value of neutral element of operation
216 if ((Is<NodeType::OffsetSum>(data.node_type) && data.W != 0.0)
217 || (data.W != 1.0))
218 return operator_complexities.at(NodeType::Mul)*(
219 operator_complexities.at(NodeType::Constant) +
220 node_complexity*(children_complexity_sum)
221 );
222 }
223
224 return node_complexity*(children_complexity_sum);
225};
226
227int TreeNode::get_size(bool include_weight) const
228{
229 int acc = 1; // the node operator or terminal
230
231 // SplitBest has an optimizable decision tree consisting of 3 nodes
232 // (terminal, arithmetic comparison, value) that needs to be taken
233 // into account. Split on will have an random decision tree that can
234 // have different sizes, but will also have the arithmetic comparison
235 // and a value.
236 if (Is<NodeType::SplitBest>(data.node_type))
237 acc += 3;
238 else if (Is<NodeType::SplitOn>(data.node_type))
239 acc += 2;
240
241 if ( (include_weight && data.get_is_weighted()==true)
243 // Taking into account the weight and multiplication, if enabled.
244 // weighted constants still count as 1 (simpler than constant terminals)
245 acc += 2;
246
247 auto child = first_child;
248 for(int i = 0; i < data.get_arg_count(); ++i)
249 {
250 acc += child->get_size(include_weight);
251 child = child->next_sibling;
252 }
253
254 return acc;
255};
void ReplaceStringInPlace(std::string &subject, const std::string &search, const std::string &replace)
string find and replace in place
Definition utils.cpp:398
auto Isnt(DataType dt) -> bool
Definition node.h:43
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto Is(NodeType nt) -> bool
Definition node.h:291
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
class holding the data for a node in a tree.
Definition node.h:84
unordered_map< NodeType, int > operator_complexities