Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
tree_node.cpp
Go to the documentation of this file.
#include "tree_node.h"

#include <stdexcept>
#include <utility>
2
3
4string TreeNode::get_model(bool pretty) const
5{
6 if (data.get_arg_count()==0)
7 return data.get_name();
8
9 vector<string> child_outputs;
10 auto sib = first_child;
11 for(int i = 0; i < data.get_arg_count(); ++i)
12 {
13 child_outputs.push_back(sib->get_model(pretty));
14 sib = sib->next_sibling;
15 }
16 return data.get_model(child_outputs);
17};
18
19
20string TreeNode::get_tree_model(bool pretty, string offset) const
21{
22 if (data.get_arg_count()==0)
23 return data.get_name();
24
25 string new_offset = " ";
26 string child_outputs = "\n";
27
28 auto sib = first_child;
29 for(int i = 0; i < data.get_arg_count(); ++i)
30 {
31 child_outputs += offset + "|- ";
32 string s = sib->get_tree_model(pretty, offset+new_offset);
33 sib = sib->next_sibling;
34 // if (sib == nullptr)
35 ReplaceStringInPlace(s, "\n"+offset, "\n"+offset+"|") ;
36 child_outputs += s;
37 if (sib != nullptr)
38 child_outputs += "\n";
39 }
40
41 if (Is<NodeType::SplitBest>(data.node_type)){
42 if (data.get_feature_type() == DataType::ArrayB)
43 return fmt::format("If({})", data.get_feature()) + child_outputs;
44
45 return fmt::format("If({}>{:.2f})", data.get_feature(), data.W) +
46 child_outputs;
47 }
48 else if (Is<NodeType::SplitOn>(data.node_type)){
49 if (data.arg_types.at(0) == DataType::ArrayB)
50 {
51 // booleans dont use thresholds (they are used directly as mask in split)
52 return "If" + child_outputs;
53 }
54 // integers or floating points (they have a threshold)
55 return fmt::format("If(>{:.2f})", data.W) + child_outputs;
56 }
57 else{
58 return data.get_name() + child_outputs;
59 }
60};
62// serialization for tree
63void to_json(json &j, const tree<Node> &t)
64{
65 j.clear();
66 // for (auto iter = t.begin(); iter!=t.end(); ++iter)
67 for (const auto &el : t)
68 {
69 j.push_back(el);
70 }
71}
72
76void from_json(const json &j, tree<Node> &t)
77{
78 vector<tree<Node>> stack;
79 for (int i = j.size(); i --> 0; )
80 {
81 auto node = j.at(i).get<Node>();
82 tree<Node> subtree;
83 auto root = subtree.insert(subtree.begin(), node);
84 for (auto at : node.arg_types)
85 {
86 auto spot = subtree.append_child(root);
87 auto arg = stack.back();
88 subtree.move_ontop(spot, arg.begin());
89 stack.pop_back();
90 }
91 stack.push_back(subtree);
92 }
93 t = stack.back();
94}
95
96unordered_map<NodeType, int> operator_complexities = {
97 // Unary
98 {NodeType::Abs , 4},
99 {NodeType::Acos , 6},
100 {NodeType::Asin , 6},
101 {NodeType::Atan , 6},
102 {NodeType::Cos , 6},
103 {NodeType::Cosh , 6},
104 {NodeType::Sin , 6},
105 {NodeType::Sinh , 6},
106 {NodeType::Tan , 6},
107 {NodeType::Tanh , 6},
108 {NodeType::Ceil , 5},
109 {NodeType::Floor , 5},
110 {NodeType::Exp , 5},
111 {NodeType::Log , 5},
112 {NodeType::Logabs , 10},
113 {NodeType::Log1p , 9},
114 {NodeType::Sqrt , 5},
115 {NodeType::Sqrtabs , 5},
116 {NodeType::Square , 4},
117 {NodeType::Logistic, 4},
118 {NodeType::OffsetSum, 3},
119
120 // timing masks
121 {NodeType::Before, 4},
122 {NodeType::After , 4},
123 {NodeType::During, 4},
124
125 // Reducers
126 {NodeType::Min , 4},
127 {NodeType::Max , 4},
128 {NodeType::Mean , 4},
129 {NodeType::Median , 4},
130 {NodeType::Sum , 4},
131 {NodeType::Prod , 4},
132
133 // Transformers
134 {NodeType::Softmax, 5},
135
136 // Binary
137 {NodeType::Add, 3},
138 {NodeType::Sub, 3},
139 {NodeType::Mul, 4},
140 {NodeType::Div, 5},
141 {NodeType::Pow, 5},
142
143 //split
144 {NodeType::SplitBest, 4},
145 {NodeType::SplitOn , 4},
146
147 // boolean
148 {NodeType::And, 3},
149 {NodeType::Or , 3},
150 {NodeType::Not, 3},
151
152 // leaves
153 {NodeType::MeanLabel, 1},
154 {NodeType::Constant , 2},
155 {NodeType::Terminal , 3},
156 {NodeType::ArgMax , 5},
157 {NodeType::Count , 4},
158
159 // custom
160 {NodeType::CustomUnaryOp , 5},
161 {NodeType::CustomBinaryOp, 5},
162 {NodeType::CustomSplit , 5}
163};
164
/// Additive ("linear") complexity of the subtree rooted at this node.
/// It is this node's operator_complexities entry plus the get_linear_complexity()
/// of each of its data.get_arg_count() children.
/// When the node is weighted and the weight branch below is taken, the
/// Mul and Constant entries are added too (the `w` and `*` of `w*node`).
/// Throws std::out_of_range (from unordered_map::at) for a node type missing
/// from operator_complexities.
165int TreeNode::get_linear_complexity() const
166{
167 int tree_complexity = operator_complexities.at(data.node_type);
168
169 auto child = first_child;
170 for(int i = 0; i < data.get_arg_count(); ++i)
171 {
172 tree_complexity += child->get_linear_complexity();
173 child = child->next_sibling;
174 }
175
176 // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
177 if (data.get_is_weighted()
// NOTE(review): source line 178 (the rest of this condition, including its
// closing parenthesis) is missing from this listing. Per the comment above,
// it presumably excludes constants and mean labels; confirm against the source.
179 {
180 // ignoring weight if it has the value of neutral element of operation
// NOTE(review): for an OffsetSum with W == 0.0 the second clause (W != 1.0)
// is still true, so the weight is counted anyway. The intent was likely
// `(OffsetSum && W != 0.0) || (!OffsetSum && W != 1.0)`; confirm.
181 if ((Is<NodeType::OffsetSum>(data.node_type) && data.W != 0.0)
182 || (data.W != 1.0))
183 return operator_complexities.at(NodeType::Mul) +
184 operator_complexities.at(NodeType::Constant) +
185 tree_complexity;
186 }
187
188 return tree_complexity;
189};
190
/// Multiplicative complexity of the subtree rooted at this node.
/// It is this node's operator_complexities entry times the sum of its
/// children's get_complexity() values. The sum is floored at 1, so a node
/// without children keeps its own cost.
/// When the node is weighted and the weight branch below is taken, the
/// result is Mul * (Constant + node * children) instead.
/// Throws std::out_of_range (from unordered_map::at) for a node type missing
/// from operator_complexities.
/// NOTE(review): costs multiply down every root-to-leaf path and are held in
/// `int`; very deep trees could overflow. Confirm the expected tree sizes.
191int TreeNode::get_complexity() const
192{
193 int node_complexity = operator_complexities.at(data.node_type);
194 int children_complexity_sum = 0; // accumulator for children complexities
195
196 auto child = first_child;
197 for(int i = 0; i < data.get_arg_count(); ++i)
198 {
199 children_complexity_sum += child->get_complexity();
200 child = child->next_sibling;
201 }
202
203 // avoid multiplication by zero if the node is a terminal
204 children_complexity_sum = max(children_complexity_sum, 1);
205
206 // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
207 if (data.get_is_weighted()
// NOTE(review): source line 208 (the rest of this condition, including its
// closing parenthesis) is missing from this listing. Per the comment above,
// it presumably excludes constants and mean labels; confirm against the source.
209 {
210 // ignoring weight if it has the value of neutral element of operation
// NOTE(review): for an OffsetSum with W == 0.0 the second clause (W != 1.0)
// is still true, so the weight is counted anyway. The intent was likely
// `(OffsetSum && W != 0.0) || (!OffsetSum && W != 1.0)`; confirm.
211 if ((Is<NodeType::OffsetSum>(data.node_type) && data.W != 0.0)
212 || (data.W != 1.0))
213 return operator_complexities.at(NodeType::Mul)*(
214 operator_complexities.at(NodeType::Constant) +
215 node_complexity*(children_complexity_sum)
216 );
217 }
218
219 return node_complexity*(children_complexity_sum);
220};
221
/// Number of nodes the subtree rooted at this node contributes to program size:
///   - 1 for the node itself;
///   - +3 for SplitBest or +2 for SplitOn (their implicit comparison parts,
///     see below);
///   - +2 when include_weight is set and the weight condition holds (the
///     weight and the multiplication);
///   - plus get_size(include_weight) of each of the data.get_arg_count() children.
222int TreeNode::get_size(bool include_weight) const
223{
224 int acc = 1; // the node operator or terminal
225
226 // SplitBest has an optimizable decision tree consisting of 3 nodes
227 // (terminal, arithmetic comparison, value) that needs to be taken
228 // into account. SplitOn will have a random decision tree that can
229 // have different sizes, but will also have the arithmetic comparison
230 // and a value.
231 if (Is<NodeType::SplitBest>(data.node_type))
232 acc += 3;
233 else if (Is<NodeType::SplitOn>(data.node_type))
234 acc += 2;
235
236 if ( (include_weight && data.get_is_weighted()==true)
// NOTE(review): source line 237 (the rest of this condition and its closing
// parenthesis) is missing from this listing. Judging by the comment below, it
// presumably excludes constants; confirm against the original source.
238 // Taking into account the weight and multiplication, if enabled.
239 // weighted constants still count as 1 (simpler than constant terminals)
240 acc += 2;
241
242 auto child = first_child;
243 for(int i = 0; i < data.get_arg_count(); ++i)
244 {
245 acc += child->get_size(include_weight);
246 child = child->next_sibling;
247 }
248
249 return acc;
250};
void ReplaceStringInPlace(std::string &subject, const std::string &search, const std::string &replace)
string find and replace in place
Definition utils.cpp:398
auto Isnt(DataType dt) -> bool
Definition node.h:43
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:25
auto Is(NodeType nt) -> bool
Definition node.h:272
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
class holding the data for a node in a tree.
Definition node.h:84
unordered_map< NodeType, int > operator_complexities
Definition tree_node.cpp:96