Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Node class design heavily inspired by Operon, (c) Heal Research
6https://github.com/heal-research/operon/
7*/
8
9#ifndef NODE_H
10#define NODE_H
11
12#include "../data/data.h"
13#include "nodetype.h"
14#include "../util/utils.h"
15#include <iostream>
16// #include "nodes/base.h"
17// #include "nodes/dx.h"
18// #include "nodes/split.h"
19// #include "nodes/terminal.h"
21
22/*
23Node overhaul:
24
25- Incorporating new design principles, learning much from operon:
26 - make Node trivial, so that it is easily copied around.
27 - use Enums and maps to define node information. This kind of abandons the
28 object oriented approach taken thus far, but it should make extensibility
29 easier and performance better in the long run.
30 - Leverage ceres for parameter optimization. No more defining analytical
31 derivatives for every function. Let ceres do that.
32 - sidenote: not sure ceres can handle the data flow of split nodes.
33 need to figure out.
34 - this also suggests turning TimeSeries back into EigenSparse matrices.
35 - forget all the runtime node generation. It saves space at the cost of
36 unclear code. I might as well just define all the nodes that are available,
37 plainly. At run-time this will be faster.
38 - keep an eye towards extensibility by defining a custom node registration
39 function that works.
40*/
41
42using Brush::DataType;
44
45namespace Brush{
46
47template <DataType... T>
48inline auto Isnt(DataType dt) -> bool { return !((dt == T) || ...); }
49
// Compile-time form of IsWeighable: whether values of data type DT can carry a
// node weight.
// NOTE(review): lines 52-56 of this predicate are missing from this listing.
// The `>(DT);` tail suggests an Isnt<...>(DT) exclusion list; confirm against node.h.
50template<DataType DT>
51inline auto IsWeighable() noexcept -> bool {
57 >(DT);
58}
// Runtime form of IsWeighable: whether data type dt can carry a node weight.
// Node::init() and the is_weighted getter/setter consult this with the node's
// return type.
// NOTE(review): lines 60-64 of the body are missing from this listing.
// The `>(dt);` tail suggests an Isnt<...>(dt) exclusion list; confirm against node.h.
59inline auto IsWeighable(DataType dt) noexcept -> bool {
65 >(dt);
66}
67
69 std::size_t operator()(std::vector<uint32_t> const& vec) const {
70 std::size_t seed = vec.size();
71 for(auto& i : vec) {
72 seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
73 }
74 return seed;
75 }
76 std::size_t operator()(std::vector<Brush::DataType> const& vec) const {
77 std::size_t seed = vec.size();
78 for(auto& i : vec) {
79 seed ^= uint32_t(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
80 }
81 return seed;
82 }
83};
84
89struct Node {
90
92 string name;
95
99 std::vector<DataType> arg_types;
101 std::size_t sig_hash;
103 std::size_t sig_dual_hash;
104
105 // TODO: node_is_fixed, weight_is_fixed, and is_weighted accessed via getters/setters
106
107 // The three flags below will help determine how to handle the node during mutation,
108 // and can also be changed by the locking mechanism. prob_change is user-defined and
109 // also interacts with these flags
110
120
122 float W;
123
125 bool center_op; // TODO: use center_op in printing
126
127 // /// @brief a node hash / unique ID for the node, except weights
128 // size_t node_hash;
130 using HashTuple = std::tuple<
131 UnderlyingNodeType, // node type
132 size_t, // sig_hash
133 bool, // is_weighted
134 string, // feature
135 bool, // node_is_fixed
136 bool, // weight_is_fixed
137 int // rounded W
138 // float // prob_change
139 >;
140
141 // Node(){init();};
142 Node() = default;
143
// Constructor used by the search space. It sets:
//   - the node type;
//   - its display name, from NodeTypeName;
//   - the return type and signature hashes, from the signature type S;
//   - the weighted flag and the feature name.
// It then calls init() for the remaining defaults.
// Only the type S of `signature` is used; the argument's value is never read.
// NodeTypeName is a std::map (per the symbol index), so operator[] inserts an
// empty name if `type` has no entry.
// NOTE(review): init() reassigns is_weighted = IsWeighable(ret_type). As a result,
// the `weighted` argument stored in the initializer list is overwritten.
// Confirm this is intended.
149 template<typename S>
150 explicit Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
151 : node_type(type)
152 , name(NodeTypeName[type])
153 , ret_type(S::get_ret_type())
155 , sig_hash(S::hash())
156 , sig_dual_hash(S::Dual::hash())
157 , is_weighted(weighted)
158 , feature(feature_name)
159 {
160 init();
161 }
162
// set_signature<S>(): re-derives ret_type, arg_types, sig_hash and sig_dual_hash
// from the signature type S.
// The declaration line (node.h:164) is missing from this listing; the name comes
// from the symbol index.
163 template<typename S>
165 {
166 ret_type = S::get_ret_type();
167 arg_types = S::get_arg_types();
168 sig_hash = S::hash();
169 sig_dual_hash = S::Dual::hash();
170 // set_node_hash();
171 }
172
173 void init(){
174 // starting weights with neutral element of the operation. offsetsum
175 // is the only node that does not multiply the weight --- instead, it adds it,
176 // so we need to handle it differently
177 W = (node_type == NodeType::OffsetSum) ? 0.0 : 1.0;
178
179 // set_node_hash();
180
181 // everything is unlocked. Special nodes (like the logistic root)
182 // should be fixed during its creation (check `vary` source code)
183 node_is_fixed=false;
184 weight_is_fixed=false;
185
186 set_prob_change(1.0);
187
188 // TODO: confirm that this is really necessary (intializing this variable) and transform this line into a ternary if so
189 // cant weight an boolean terminal
190 this->is_weighted = IsWeighable(this->ret_type);
191 }
192
196 string get_name(bool include_weight=true) const noexcept;
197 string get_model(const vector<string>&) const noexcept;
198
199 // get return type and argument types.
200 inline DataType get_ret_type() const { return ret_type; };
201 inline std::size_t args_type() const { return sig_hash; };
202 inline auto get_arg_types() const { return arg_types; };
203 inline size_t get_arg_count() const { return arg_types.size(); };
204
205 // void set_node_hash(){
206 // node_hash = std::hash<HashTuple>{}(HashTuple{
207 // NodeTypes::GetIndex(node_type),
208 // sig_hash,
209 // is_weighted,
210 // feature,
211 // fixed,
212 // W,
213 // prob_change
214 // });
215 // // fmt::print("nodetype:{}; hash tuple:{}; node_hash={}\n", node_type, tmp, node_hash);
216 // }
// Hash of this node's HashTuple. The comparison operators (==, <) rely on it.
// The last tuple slot depends on include_const:
//   - true: it holds int(W * 100), the weight scaled by 100 and truncated toward
//     zero. Weights that truncate to the same value (e.g. 0.501 and 0.509) hash
//     identically.
//   - false: it holds only 1/0 from get_is_weighted(). Nodes that differ only in
//     W then hash identically.
// NOTE(review): int(W * 100) is undefined behavior if W * 100 is outside int range.
// NOTE(review): std::hash has no standard specialization for std::tuple.
// One must be provided elsewhere in the project (not visible here).
// NOTE(review): tuple lines 221, 223, 225 and 226 are missing from this listing.
217 size_t get_node_hash(bool include_const=true) const {
218 // we ignore the constant only on simplification, because it will be tuned anyways
219
220 return std::hash<HashTuple>{}(HashTuple{
222 sig_hash,
224 feature,
227 // include weights only if we want exact matches.
228 // but we will indicate that constants are on using get is weighted,
229 // just so we differentiate whether the weight exists or is ignored
230 include_const ? int(W * 100) : (get_is_weighted() ? 1 : 0)
231 });
232 }
233
234 //comparison operators
235 inline auto operator==(const Node& rhs) const noexcept -> bool
236 {
237 // obs: this is declared as a member operator, so the lhs is implicit.
238 // If i were to declare this outsize class definition, then I would have
239 // to include lhs in the function signature
240
241 /* return CalculatedHashValue == rhs.CalculatedHashValue; */
242 return get_node_hash() == rhs.get_node_hash();
243 /* return (*this) == rhs; */
244 }
245
246 inline auto operator!=(const Node& rhs) const noexcept -> bool
247 {
248 return !((*this) == rhs);
249 }
250
251 inline auto operator<(const Node& rhs) const noexcept -> bool
252 {
253 /* return std::tie(HashValue, CalculatedHashValue) < std::tie(rhs.HashValue, rhs.CalculatedHashValue); */
254 return get_node_hash() < rhs.get_node_hash();
255 return (*this) < rhs;
256 }
257
258 inline auto operator<=(const Node& rhs) const noexcept -> bool
259 {
260 return ((*this) == rhs || (*this) < rhs);
261 }
262
263 inline auto operator>(const Node& rhs) const noexcept -> bool
264 {
265 return !((*this) <= rhs);
266 }
267
268 inline auto operator>=(const Node& rhs) const noexcept -> bool
269 {
270 return !((*this) < rhs);
271 }
272
274 // getters and setters
275 //TODO revisit
276 float get_prob_change() const { return node_is_fixed ? 0.0 : this->prob_change;};
277 void set_prob_change(float w){ this->prob_change = w;};
278 float get_prob_keep() const { return node_is_fixed ? 1.0 : 1.0-this->prob_change;};
279
280 inline void set_feature(string f){ feature = f; };
281 inline string get_feature() const { return feature; };
282
283 inline void set_feature_type(DataType ft){ this->feature_type = ft; };
284 inline DataType get_feature_type() const { return this->feature_type; };
285
286 inline void set_keep_split_feature(bool keep){ this->keep_split_feature = keep; };
287 inline bool get_keep_split_feature() const { return this->keep_split_feature; };
288
289 // Some types does not support weights, so we completely ignore the weights
290 // if is not weighable
291 inline bool get_is_weighted() const {
292 if (IsWeighable(this->ret_type))
293 return this->is_weighted;
294 return false;
295 };
296 inline void set_is_weighted(bool is_weighted){
297 if (IsWeighable(this->ret_type))
298 this->is_weighted = is_weighted;
299 };
300
301 // private:
303 string feature;
304
307
309 bool keep_split_feature = false; // TODO: unittests for keep_split_feature
310};
311
312template <NodeType... T>
313inline auto Is(NodeType nt) -> bool { return ((nt == T) || ...); }
314
315template <NodeType... T>
316inline auto Isnt(NodeType nt) -> bool { return !((nt == T) || ...); }
317
318// TODO: I think there are places where I can replace some logic with IsLeaf --> Check that.
319// TODO: create IsConstant, and add Constant and MeanLabel to it.
// Presumably reports whether nt is a leaf (terminal) node type.
// NOTE(review): the body (line 321) is missing from this listing; confirm against node.h.
320inline auto IsLeaf(NodeType nt) noexcept -> bool {
322}
323
// True when nt is one of a fixed list of node types whose first entry is NodeType::Add.
// NOTE(review): the remaining entries (lines 326-330) are missing from this listing.
324inline auto IsCommutative(NodeType nt) noexcept -> bool {
325 return Is<NodeType::Add,
331 >(nt);
332}
333
// True unless nt appears in an exclusion list whose last entry is NodeType::Not.
// NOTE(review): the earlier entries (lines 336-343) are missing from this listing.
334inline auto IsDifferentiable(NodeType nt) noexcept -> bool {
335 return Isnt<
344 NodeType::Not // TODO: should I include OffsetSum here? If I do so, then I should change the logic in the optimizer to not optimize the weight of OffsetSum nodes.
345 >(nt);
346}
// Compile-time predicate on node type NT: true unless NT appears in an Isnt<...>
// exclusion list.
// This overloads the DataType-templated IsWeighable<DT>(). The type of the
// explicit template argument selects which overload is used.
// NOTE(review): the list entries (lines 350-361) are missing from this listing.
347template<NodeType NT>
348inline auto IsWeighable() noexcept -> bool {
349 return Isnt<
362 >(NT);
363}
380
381ostream& operator<<(ostream& os, const Node& n);
382ostream& operator<<(ostream& os, const NodeType& nt);
383
384void from_json(const json &j, Node& p);
385void to_json(json& j, const Node& p);
386} // namespace Brush
387
388// format overload for Nodes
389template <> struct fmt::formatter<Brush::Node>: formatter<string_view> {
390 // parse is inherited from formatter<string_view>.
391 template <typename FormatContext>
392 auto format(Brush::Node x, FormatContext& ctx) const {
393 return formatter<string_view>::format(x.get_name(), ctx);
394 }
395};
396
397
398
399#endif
holds variable type data.
Definition data.h:51
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:114
NodeType
Definition nodetype.h:31
auto Isnt(DataType dt) -> bool
Definition node.h:48
auto IsWeighable() noexcept -> bool
Definition node.h:51
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:320
auto IsCommutative(NodeType nt) noexcept -> bool
Definition node.h:324
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:313
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
NodeType NT
Definition node.cpp:141
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
DataType DT
Definition types.h:188
auto IsDifferentiable(NodeType nt) noexcept -> bool
Definition node.h:334
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:127
class holding the data for a node in a tree.
Definition node.h:89
std::tuple< UnderlyingNodeType, size_t, bool, string, bool, bool, int > HashTuple
tuple type for hashing
Definition node.h:130
float get_prob_keep() const
Definition node.h:278
bool center_op
whether to center the operator in pretty printing
Definition node.h:125
std::vector< DataType > arg_types
argument data types
Definition node.h:99
void set_signature()
Definition node.h:164
bool get_keep_split_feature() const
Definition node.h:287
DataType feature_type
feature type for terminals or splitting nodes
Definition node.h:306
DataType get_ret_type() const
Definition node.h:200
void set_keep_split_feature(bool keep)
Definition node.h:286
size_t get_node_hash(bool include_const=true) const
Definition node.h:217
bool weight_is_fixed
whether the weight should be kept during variation. Notice that weight_is_fixed allows us to fix the w...
Definition node.h:114
float get_prob_change() const
Definition node.h:276
NodeType node_type
the node type
Definition node.h:94
DataType get_feature_type() const
Definition node.h:284
auto operator<=(const Node &rhs) const noexcept -> bool
Definition node.h:258
Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
Constructor used by search space.
Definition node.h:150
DataType ret_type
return data type
Definition node.h:97
auto operator>(const Node &rhs) const noexcept -> bool
Definition node.h:263
void set_is_weighted(bool is_weighted)
Definition node.h:296
auto get_arg_types() const
Definition node.h:202
auto operator==(const Node &rhs) const noexcept -> bool
Definition node.h:235
void set_feature(string f)
Definition node.h:280
string get_feature() const
Definition node.h:281
std::size_t sig_hash
a hash of the signature
Definition node.h:101
bool get_is_weighted() const
Definition node.h:291
void init()
Definition node.h:173
void set_feature_type(DataType ft)
Definition node.h:283
std::size_t args_type() const
Definition node.h:201
float prob_change
chance of node being selected for variation. This will take into account if the node is fixed,...
Definition node.h:119
auto operator>=(const Node &rhs) const noexcept -> bool
Definition node.h:268
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:122
bool is_weighted
whether this node is weighted (ignored in nodes that must have weights, such as meanLabel,...
Definition node.h:116
bool node_is_fixed
whether the node is replaceable. Weights are still optimized.
Definition node.h:112
bool keep_split_feature
fix the SplitBest feature when the node is fixed
Definition node.h:309
size_t get_arg_count() const
Definition node.h:203
string feature
feature name for terminals or splitting nodes
Definition node.h:303
auto operator!=(const Node &rhs) const noexcept -> bool
Definition node.h:246
string name
full name of the node, with types
Definition node.h:92
void set_prob_change(float w)
Definition node.h:277
Node()=default
auto operator<(const Node &rhs) const noexcept -> bool
Definition node.h:251
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:103
string get_model(const vector< string > &) const noexcept
Definition node.cpp:58
std::size_t operator()(std::vector< Brush::DataType > const &vec) const
Definition node.h:76
std::size_t operator()(std::vector< uint32_t > const &vec) const
Definition node.h:69
auto format(Brush::Node x, FormatContext &ctx) const
Definition node.h:392