Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
tree_node.cpp
Go to the documentation of this file.
#include "tree_node.h"

#include <stdexcept>
2
3
4string TreeNode::get_model(bool pretty) const
5{
6 if (data.get_arg_count()==0)
7 return data.get_name();
8
9 vector<string> child_outputs;
10 auto sib = first_child;
11 for(int i = 0; i < data.get_arg_count(); ++i)
12 {
13 child_outputs.push_back(sib->get_model(pretty));
14 sib = sib->next_sibling;
15 }
16 return data.get_model(child_outputs);
17};
18
19
20string TreeNode::get_tree_model(bool pretty, string offset) const
21{
22 if (data.get_arg_count()==0)
23 return data.get_name();
24
25 string new_offset = " ";
26 string child_outputs = "\n";
27
28 auto sib = first_child;
29 for(int i = 0; i < data.get_arg_count(); ++i)
30 {
31 child_outputs += offset + "|-";
32 string s = sib->get_tree_model(pretty, offset+new_offset);
33 sib = sib->next_sibling;
34 if (sib == nullptr)
35 ReplaceStringInPlace(s, "\n"+offset, "\n"+offset+"|") ;
37 if (sib != nullptr)
38 child_outputs += "\n";
39 }
40
41 return data.get_name() + child_outputs;
42};
44// serialization for tree
45void to_json(json &j, const tree<Node> &t)
46{
47 j.clear();
48 // for (auto iter = t.begin(); iter!=t.end(); ++iter)
49 for (const auto &el : t)
50 {
51 j.push_back(el);
52 }
53}
54
58void from_json(const json &j, tree<Node> &t)
59{
60 vector<tree<Node>> stack;
61 for (int i = j.size(); i --> 0; )
62 {
63 auto node = j.at(i).get<Node>();
65 auto root = subtree.insert(subtree.begin(), node);
66 for (auto at : node.arg_types)
67 {
68 auto spot = subtree.append_child(root);
69 auto arg = stack.back();
70 subtree.move_ontop(spot, arg.begin());
71 stack.pop_back();
72 }
73 stack.push_back(subtree);
74 }
75 t = stack.back();
76}
77
78unordered_map<NodeType, int> operator_complexities = {
79 // Unary
80 {NodeType::Abs , 3},
81 {NodeType::Acos , 5},
82 {NodeType::Asin , 5},
83 {NodeType::Atan , 5},
84 {NodeType::Cos , 5},
85 {NodeType::Cosh , 5},
86 {NodeType::Sin , 5},
87 {NodeType::Sinh , 5},
88 {NodeType::Tan , 5},
89 {NodeType::Tanh , 5},
90 {NodeType::Ceil , 4},
91 {NodeType::Floor , 4},
92 {NodeType::Exp , 4},
93 {NodeType::Log , 4},
94 {NodeType::Logabs , 12},
95 {NodeType::Log1p , 8},
96 {NodeType::Sqrt , 4},
97 {NodeType::Sqrtabs , 4},
98 {NodeType::Square , 3},
99 {NodeType::Logistic, 3},
100 {NodeType::OffsetSum, 2},
101
102 // timing masks
103 {NodeType::Before, 3},
104 {NodeType::After , 3},
105 {NodeType::During, 3},
106
107 // Reducers
108 {NodeType::Min , 3},
109 {NodeType::Max , 3},
110 {NodeType::Mean , 3},
111 {NodeType::Median , 3},
112 {NodeType::Sum , 2},
113 {NodeType::Prod , 3},
114
115 // Transformers
116 {NodeType::Softmax, 4},
117
118 // Binary
119 {NodeType::Add, 2},
120 {NodeType::Sub, 2},
121 {NodeType::Mul, 3},
122 {NodeType::Div, 4},
123 {NodeType::Pow, 5},
124
125 //split
126 {NodeType::SplitBest, 4},
127 {NodeType::SplitOn , 4},
128
129 // boolean
130 {NodeType::And, 2},
131 {NodeType::Or , 2},
132 {NodeType::Not, 2},
133
134 // leaves
135 {NodeType::MeanLabel, 1},
136 {NodeType::Constant , 1},
137 {NodeType::Terminal , 2},
138 {NodeType::ArgMax , 5},
139 {NodeType::Count , 3},
140
141 // custom
142 {NodeType::CustomUnaryOp , 5},
143 {NodeType::CustomBinaryOp, 5},
144 {NodeType::CustomSplit , 5}
145};
146
147int TreeNode::get_complexity() const
148{
149 int node_complexity = operator_complexities.at(data.node_type);
150 int children_complexity_sum = 0; // acumulator for children complexities
151
152 auto child = first_child;
153 for(int i = 0; i < data.get_arg_count(); ++i)
154 {
155 children_complexity_sum += child->get_complexity();
156 child = child->next_sibling;
157 }
158
159 // avoid multiplication by zero if the node is a terminal
161
162 // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
163 if (data.get_is_weighted()
164 && !(Is<NodeType::Constant>(data.node_type)
165 || (Is<NodeType::MeanLabel>(data.node_type)
166 || Is<NodeType::OffsetSum>(data.node_type)) )
167 )
168 return operator_complexities.at(NodeType::Mul)*(
169 operator_complexities.at(NodeType::Constant) +
171 );
172
174};
175
176int TreeNode::get_size(bool include_weight) const
177{
178 int acc = 1; // the node operator or terminal
179
180 // SplitBest has an optimizable decision tree consisting of 3 nodes
181 // (terminal, arithmetic comparison, value) that needs to be taken
182 // into account. Split on will have an random decision tree that can
183 // have different sizes, but will also have the arithmetic comparison
184 // and a value.
185 if (Is<NodeType::SplitBest>(data.node_type))
186 acc += 3;
187 else if (Is<NodeType::SplitOn>(data.node_type))
188 acc += 2;
189
190 if ( (include_weight && data.get_is_weighted()==true)
192 // Taking into account the weight and multiplication, if enabled.
193 // weighted constants still count as 1 (simpler than constant terminals)
194 acc += 2;
195
196 auto child = first_child;
197 for(int i = 0; i < data.get_arg_count(); ++i)
198 {
199 acc += child->get_size(include_weight);
200 child = child->next_sibling;
201 }
202
203 return acc;
204};
void bind_engine(py::module &m, string name)
void ReplaceStringInPlace(std::string &subject, const std::string &search, const std::string &replace)
Find and replace a substring of a string in place.
Definition utils.cpp:398
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:24
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Class holding the data for a node in a tree.
Definition node.h:84
unordered_map< NodeType, int > operator_complexities
Definition tree_node.cpp:78