Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Node class design heavily inspired by Operon, (c) Heal Research
6https://github.com/heal-research/operon/
7*/
8
9#ifndef NODE_H
10#define NODE_H
11
12#include "../data/data.h"
13#include "nodetype.h"
14#include "../util/utils.h"
15#include <iostream>
16// #include "nodes/base.h"
17// #include "nodes/dx.h"
18// #include "nodes/split.h"
19// #include "nodes/terminal.h"
21/*
22Node overhaul:
23
24- Incorporating new design principles, learning much from operon:
25 - make Node trivial, so that it is easily copied around.
26 - use Enums and maps to define node information. This kind of abandons the object oriented approach taken thus far, but it should make extensibility easier and performance better in the long run.
27 - Leverage ceres for parameter optimization. No more defining analytical
28 derivatives for every function. Let ceres do that.
29 - sidenote: not sure ceres can handle the data flow of split nodes.
30 need to figure out.
31 - this also suggests turning TimeSeries back into EigenSparse matrices.
32 - forget all the runtime node generation. It saves space at the cost of
33 unclear code. I might as well just define all the nodes that are available, plainly. At run-time this will be faster.
34 - keep an eye towards extensibility by defining a custom node registration function that works.
35
36*/
37using Brush::DataType;
39
40namespace Brush{
41
42template <DataType... T>
43inline auto Isnt(DataType dt) -> bool { return !((dt == T) || ...); }
44
45template<DataType DT>
46inline auto IsWeighable() noexcept -> bool {
52 >(DT);
53}
54inline auto IsWeighable(DataType dt) noexcept -> bool {
60 >(dt);
61}
62
64 std::size_t operator()(std::vector<uint32_t> const& vec) const {
65 std::size_t seed = vec.size();
66 for(auto& i : vec) {
67 seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
68 }
69 return seed;
70 }
71 std::size_t operator()(std::vector<Brush::DataType> const& vec) const {
72 std::size_t seed = vec.size();
73 for(auto& i : vec) {
74 seed ^= uint32_t(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
75 }
76 return seed;
77 }
78};
79
84struct Node {
85
87 string name;
90
94 std::vector<DataType> arg_types;
96 std::size_t sig_hash;
98 std::size_t sig_dual_hash;
99
101 bool fixed;
107 float W;
108
110 bool center_op; // TODO: use center_op in printing
111
112 // /// @brief a node hash / unique ID for the node, except weights
113 // size_t node_hash;
115 using HashTuple = std::tuple<
116 UnderlyingNodeType, // node type
117 size_t, // sig_hash
118 bool, // is_weighted
119 string, // feature
120 bool, // fixed
121 int // rounded W
122 // float // prob_change
123 >;
124
125 // Node(){init();};
126 Node() = default;
127
133 template<typename S>
134 explicit Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
135 : node_type(type)
136 , name(NodeTypeName[type])
137 , ret_type(S::get_ret_type())
139 , sig_hash(S::hash())
140 , sig_dual_hash(S::Dual::hash())
141 , is_weighted(weighted)
142 , feature(feature_name)
143 {
144 init();
145 }
146
147 template<typename S>
149 {
150 ret_type = S::get_ret_type();
151 arg_types = S::get_arg_types();
152 sig_hash = S::hash();
153 sig_dual_hash = S::Dual::hash();
154 // set_node_hash();
155 }
156
157 void init(){
158
159 // starting weights with neutral element of the operation. offsetsum
160 // is the only node that does not multiply the weight --- instead, it adds it
161 W = (node_type == NodeType::OffsetSum) ? 0.0 : 1.0;
162
163 // set_node_hash();
164 fixed=false;
165 set_prob_change(1.0);
166
167 // TODO: confirm that this is really necessary (intializing this variable) and transform this line into a ternary if so
168 // cant weight an boolean terminal
169 this->is_weighted = IsWeighable(this->ret_type);
170 }
171
175 string get_name(bool include_weight=true) const noexcept;
176 string get_model(const vector<string>&) const noexcept;
177
178 // get return type and argument types.
179 inline DataType get_ret_type() const { return ret_type; };
180 inline std::size_t args_type() const { return sig_hash; };
181 inline auto get_arg_types() const { return arg_types; };
182 inline size_t get_arg_count() const { return arg_types.size(); };
183
184 // void set_node_hash(){
185 // node_hash = std::hash<HashTuple>{}(HashTuple{
186 // NodeTypes::GetIndex(node_type),
187 // sig_hash,
188 // is_weighted,
189 // feature,
190 // fixed,
191 // W,
192 // prob_change
193 // });
194 // // fmt::print("nodetype:{}; hash tuple:{}; node_hash={}\n", node_type, tmp, node_hash);
195 // }
196 size_t get_node_hash(bool include_const=true) const {
197 // we ignore the constant only on simplification, because it will be tuned anyways
198
199 return std::hash<HashTuple>{}(HashTuple{
201 sig_hash,
203 feature,
204 fixed,
205 // include weights only if we want exact matches.
206 // but we will indicate that constants are on using get is weighted,
207 // just so we differentiate whether the weight exists or is ignored
208 include_const ? int(W * 100) : (get_is_weighted() ? 1 : 0)
209 });
210 }
211
212 //comparison operators
213 inline auto operator==(const Node& rhs) const noexcept -> bool
214 {
215 // obs: this is declared as a member operator, so the lhs is implicit.
216 // If i were to declare this outsize class definition, then I would have
217 // to include lhs in the function signature
218
219 /* return CalculatedHashValue == rhs.CalculatedHashValue; */
220 return get_node_hash() == rhs.get_node_hash();
221 /* return (*this) == rhs; */
222 }
223
224 inline auto operator!=(const Node& rhs) const noexcept -> bool
225 {
226 return !((*this) == rhs);
227 }
228
229 inline auto operator<(const Node& rhs) const noexcept -> bool
230 {
231 /* return std::tie(HashValue, CalculatedHashValue) < std::tie(rhs.HashValue, rhs.CalculatedHashValue); */
232 return get_node_hash() < rhs.get_node_hash();
233 return (*this) < rhs;
234 }
235
236 inline auto operator<=(const Node& rhs) const noexcept -> bool
237 {
238 return ((*this) == rhs || (*this) < rhs);
239 }
240
241 inline auto operator>(const Node& rhs) const noexcept -> bool
242 {
243 return !((*this) <= rhs);
244 }
245
246 inline auto operator>=(const Node& rhs) const noexcept -> bool
247 {
248 return !((*this) < rhs);
249 }
250
252 // getters and setters
253 //TODO revisit
254 float get_prob_change() const { return fixed ? 0.0 : this->prob_change;};
255 void set_prob_change(float w){ this->prob_change = w;};
256 float get_prob_keep() const { return fixed ? 1.0 : 1.0-this->prob_change;};
257
258 inline void set_feature(string f){ feature = f; };
259 inline string get_feature() const { return feature; };
260
261 inline void set_feature_type(DataType ft){ this->feature_type = ft; };
262 inline DataType get_feature_type() const { return this->feature_type; };
263
264 inline void set_keep_split_feature(bool keep){ this->keep_split_feature = keep; };
265 inline bool get_keep_split_feature() const { return this->keep_split_feature; };
266
267 // Some types does not have weights, so we completely ignore the weights
268 // if is not weighable
269 inline bool get_is_weighted() const {
270 if (IsWeighable(this->ret_type))
271 return this->is_weighted;
272 return false;
273 };
274 inline void set_is_weighted(bool is_weighted){
275 if (IsWeighable(this->ret_type))
276 this->is_weighted = is_weighted;
277 };
278
279 // private:
281 string feature;
282
285
287 bool keep_split_feature = false; // TODO: unittests for keep_split_feature
288};
289
290template <NodeType... T>
291inline auto Is(NodeType nt) -> bool { return ((nt == T) || ...); }
292
293template <NodeType... T>
294inline auto Isnt(NodeType nt) -> bool { return !((nt == T) || ...); }
295
296// TODO: I think there are places where I can replace some logic with IsLeaf --> Check that.
297// TODO: create IsConstant, and add Constant and MeanLabel to it.
298inline auto IsLeaf(NodeType nt) noexcept -> bool {
300}
301
302inline auto IsCommutative(NodeType nt) noexcept -> bool {
303 return Is<NodeType::Add,
309 >(nt);
310}
311
312inline auto IsDifferentiable(NodeType nt) noexcept -> bool {
313 return Isnt<
322 NodeType::Not // TODO: should I include OffsetSum here? If I do so, then I should change the logic in the optimizer to not optimize the weight of OffsetSum nodes.
323 >(nt);
324}
325template<NodeType NT>
326inline auto IsWeighable() noexcept -> bool {
327 return Isnt<
340 >(NT);
341}
358
359ostream& operator<<(ostream& os, const Node& n);
360ostream& operator<<(ostream& os, const NodeType& nt);
361
362void from_json(const json &j, Node& p);
363void to_json(json& j, const Node& p);
364} // namespace Brush
365
366// format overload for Nodes
367template <> struct fmt::formatter<Brush::Node>: formatter<string_view> {
368 // parse is inherited from formatter<string_view>.
369 template <typename FormatContext>
370 auto format(Brush::Node x, FormatContext& ctx) const {
371 return formatter<string_view>::format(x.get_name(), ctx);
372 }
373};
374
375
376
377#endif
holds variable type data.
Definition data.h:51
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:114
NodeType
Definition nodetype.h:31
auto Isnt(DataType dt) -> bool
Definition node.h:43
auto IsWeighable() noexcept -> bool
Definition node.h:46
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:298
auto IsCommutative(NodeType nt) noexcept -> bool
Definition node.h:302
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:291
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
DataType DT
Definition types.h:188
auto IsDifferentiable(NodeType nt) noexcept -> bool
Definition node.h:312
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:127
class holding the data for a node in a tree.
Definition node.h:84
float get_prob_keep() const
Definition node.h:256
bool center_op
whether to center the operator in pretty printing
Definition node.h:110
std::vector< DataType > arg_types
argument data types
Definition node.h:94
void set_signature()
Definition node.h:148
bool get_keep_split_feature() const
Definition node.h:265
DataType feature_type
feature type for terminals or splitting nodes
Definition node.h:284
DataType get_ret_type() const
Definition node.h:179
void set_keep_split_feature(bool keep)
Definition node.h:264
bool fixed
whether the node is replaceable. Weights are still optimized.
Definition node.h:101
size_t get_node_hash(bool include_const=true) const
Definition node.h:196
float get_prob_change() const
Definition node.h:254
NodeType node_type
the node type
Definition node.h:89
DataType get_feature_type() const
Definition node.h:262
auto operator<=(const Node &rhs) const noexcept -> bool
Definition node.h:236
Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
Constructor used by search space.
Definition node.h:134
DataType ret_type
return data type
Definition node.h:92
auto operator>(const Node &rhs) const noexcept -> bool
Definition node.h:241
void set_is_weighted(bool is_weighted)
Definition node.h:274
auto get_arg_types() const
Definition node.h:181
auto operator==(const Node &rhs) const noexcept -> bool
Definition node.h:213
void set_feature(string f)
Definition node.h:258
string get_feature() const
Definition node.h:259
std::size_t sig_hash
a hash of the signature
Definition node.h:96
bool get_is_weighted() const
Definition node.h:269
void init()
Definition node.h:157
void set_feature_type(DataType ft)
Definition node.h:261
std::size_t args_type() const
Definition node.h:180
float prob_change
chance of node being selected for variation
Definition node.h:105
auto operator>=(const Node &rhs) const noexcept -> bool
Definition node.h:246
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:107
bool is_weighted
whether this node is weighted (ignored in nodes that must have weights, such as meanLabel,...
Definition node.h:103
std::tuple< UnderlyingNodeType, size_t, bool, string, bool, int > HashTuple
tuple type for hashing
Definition node.h:115
bool keep_split_feature
fix the SplitBest feature when the node is fixed
Definition node.h:287
size_t get_arg_count() const
Definition node.h:182
string feature
feature name for terminals or splitting nodes
Definition node.h:281
auto operator!=(const Node &rhs) const noexcept -> bool
Definition node.h:224
string name
full name of the node, with types
Definition node.h:87
void set_prob_change(float w)
Definition node.h:255
Node()=default
auto operator<(const Node &rhs) const noexcept -> bool
Definition node.h:229
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:98
string get_model(const vector< string > &) const noexcept
Definition node.cpp:52
std::size_t operator()(std::vector< Brush::DataType > const &vec) const
Definition node.h:71
std::size_t operator()(std::vector< uint32_t > const &vec) const
Definition node.h:64
auto format(Brush::Node x, FormatContext &ctx) const
Definition node.h:370