Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Node class design heavily inspired by Operon, (c) Heal Research
6https://github.com/heal-research/operon/
7*/
8
9#ifndef NODE_H
10#define NODE_H
11
12#include "../data/data.h"
13#include "nodetype.h"
14#include "../util/utils.h"
15#include <iostream>
16// #include "nodes/base.h"
17// #include "nodes/dx.h"
18// #include "nodes/split.h"
19// #include "nodes/terminal.h"
21/*
22Node overhaul:
23
24- Incorporating new design principles, learning much from operon:
25 - make Node trivial, so that it is easily copied around.
26 - use Enums and maps to define node information. This kind of abandons the object oriented approach taken thus far, but it should make extensibility easier and performance better in the long run.
27 - Leverage ceres for parameter optimization. No more defining analytical
28 derivatives for every function. Let ceres do that.
29 - sidenote: not sure ceres can handle the data flow of split nodes.
30 need to figure out.
31 - this also suggests turning TimeSeries back into EigenSparse matrices.
32 - forget all the runtime node generation. It saves space at the cost of
33 unclear code. I might as well just define all the nodes that are available, plainly. At run-time this will be faster.
34 - keep an eye towards extensibility by defining a custom node registration function that works.
35
36*/
37using Brush::DataType;
39
40namespace Brush{
41
42template <DataType... T>
43inline auto Isnt(DataType dt) -> bool { return !((dt == T) || ...); }
44
45template<DataType DT>
46inline auto IsWeighable() noexcept -> bool {
52 >(DT);
53}
54inline auto IsWeighable(DataType dt) noexcept -> bool {
60 >(dt);
61}
62
64 std::size_t operator()(std::vector<uint32_t> const& vec) const {
65 std::size_t seed = vec.size();
66 for(auto& i : vec) {
67 seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
68 }
69 return seed;
70 }
71 std::size_t operator()(std::vector<Brush::DataType> const& vec) const {
72 std::size_t seed = vec.size();
73 for(auto& i : vec) {
74 seed ^= uint32_t(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
75 }
76 return seed;
77 }
78};
79
84struct Node {
85
87 string name;
93 bool fixed;
97 std::size_t sig_hash;
99 std::size_t sig_dual_hash;
103 std::vector<DataType> arg_types;
107 float W;
108 // /// @brief a node hash / unique ID for the node, except weights
109 // size_t node_hash;
111 using HashTuple = std::tuple<
112 UnderlyingNodeType, // node type
113 size_t, // sig_hash
114 bool, // is_weighted
115 string, // feature
116 bool, // fixed
117 int // rounded W
118 // float // prob_change
119 >;
120
121 // Node(){init();};
122 Node() = default;
123
124
130 template<typename S>
131 explicit Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
132 : node_type(type)
133 , name(NodeTypeName[type])
134 , ret_type(S::get_ret_type())
136 , sig_hash(S::hash())
137 , sig_dual_hash(S::Dual::hash())
138 , is_weighted(weighted)
139 , feature(feature_name)
140 {
141 init();
142 }
143
144 template<typename S>
146 {
147 ret_type = S::get_ret_type();
148 arg_types = S::get_arg_types();
149 sig_hash = S::hash();
150 sig_dual_hash = S::Dual::hash();
151 // set_node_hash();
152 }
153
154 void init(){
155
156 // starting weights with neutral element of the operation. offsetsum
157 // is the only node that does not multiply the weight --- instead, it adds it
158 W = (node_type == NodeType::OffsetSum) ? 0.0 : 1.0;
159
160 // set_node_hash();
161 fixed=false;
162 set_prob_change(1.0);
163
164 // TODO: confirm that this is really necessary (intializing this variable) and transform this line into a ternary if so
165 // cant weight an boolean terminal
166 if (!IsWeighable(this->ret_type))
167 this->is_weighted = false;
168 else
169 this->is_weighted = true;
170 }
171
175 string get_name(bool include_weight=true) const noexcept;
176 string get_model(const vector<string>&) const noexcept;
177
178 // get return type and argument types.
179 inline DataType get_ret_type() const { return ret_type; };
180 inline std::size_t args_type() const { return sig_hash; };
181 inline auto get_arg_types() const { return arg_types; };
182 inline size_t get_arg_count() const { return arg_types.size(); };
183
184 // void set_node_hash(){
185 // node_hash = std::hash<HashTuple>{}(HashTuple{
186 // NodeTypes::GetIndex(node_type),
187 // sig_hash,
188 // is_weighted,
189 // feature,
190 // fixed,
191 // W,
192 // prob_change
193 // });
194 // // fmt::print("nodetype:{}; hash tuple:{}; node_hash={}\n", node_type, tmp, node_hash);
195 // }
196 size_t get_node_hash() const {
197 return std::hash<HashTuple>{}(HashTuple{
199 sig_hash,
201 feature,
202 fixed,
203 int(W*100)
204 });
205 }
206
207 //comparison operators
208 inline auto operator==(const Node& rhs) const noexcept -> bool
209 {
210 /* return CalculatedHashValue == rhs.CalculatedHashValue; */
211 return get_node_hash() == rhs.get_node_hash();
212 /* return (*this) == rhs; */
213 }
214
215 inline auto operator!=(const Node& rhs) const noexcept -> bool
216 {
217 return !((*this) == rhs);
218 }
219
220 inline auto operator<(const Node& rhs) const noexcept -> bool
221 {
222 /* return std::tie(HashValue, CalculatedHashValue) < std::tie(rhs.HashValue, rhs.CalculatedHashValue); */
223 return get_node_hash() < rhs.get_node_hash();
224 return (*this) < rhs;
225 }
226
227 inline auto operator<=(const Node& rhs) const noexcept -> bool
228 {
229 return ((*this) == rhs || (*this) < rhs);
230 }
231
232 inline auto operator>(const Node& rhs) const noexcept -> bool
233 {
234 return !((*this) <= rhs);
235 }
236
237 inline auto operator>=(const Node& rhs) const noexcept -> bool
238 {
239 return !((*this) < rhs);
240 }
241
243 // getters and setters
244 //TODO revisit
245 float get_prob_change() const { return fixed ? 0.0 : this->prob_change;};
246 void set_prob_change(float w){ this->prob_change = w;};
247 float get_prob_keep() const { return fixed ? 1.0 : 1.0-this->prob_change;};
248
249 inline void set_feature(string f){ feature = f; };
250 inline string get_feature() const { return feature; };
251
252 inline void set_feature_type(DataType ft){ feature_type = ft; };
253 inline DataType get_feature_type() const { return feature_type; };
254
255 inline bool get_is_weighted() const {return this->is_weighted;};
256 inline void set_is_weighted(bool is_weighted){
257 // cant change the weight of a boolean terminal
258 if (IsWeighable(this->ret_type))
259 this->is_weighted = is_weighted;
260 };
261
262 private:
264 string feature;
265
268
269};
270
271template <NodeType... T>
272inline auto Is(NodeType nt) -> bool { return ((nt == T) || ...); }
273
274template <NodeType... T>
275inline auto Isnt(NodeType nt) -> bool { return !((nt == T) || ...); }
276
277// TODO: I think there are places where I can replace some logic with IsLeaf --> Check that.
278// TODO: create IsConstant, and add Constant and MeanLabel to it.
279inline auto IsLeaf(NodeType nt) noexcept -> bool {
281}
282
283inline auto IsCommutative(NodeType nt) noexcept -> bool {
284 return Is<NodeType::Add,
288 >(nt);
289}
290
291inline auto IsDifferentiable(NodeType nt) noexcept -> bool {
292 return Isnt<
301 NodeType::Not // TODO: should I include OffsetSum here? If I do so, then I should change the logic in the optimizer to not optimize the weight of OffsetSum nodes.
302 >(nt);
303}
304template<NodeType NT>
305inline auto IsWeighable() noexcept -> bool {
306 return Isnt<
319 >(NT);
320}
337
338ostream& operator<<(ostream& os, const Node& n);
339ostream& operator<<(ostream& os, const NodeType& nt);
340
341
342
343void from_json(const json &j, Node& p);
344void to_json(json& j, const Node& p);
345} // namespace Brush
346
347// format overload for Nodes
348template <> struct fmt::formatter<Brush::Node>: formatter<string_view> {
349 // parse is inherited from formatter<string_view>.
350 template <typename FormatContext>
351 auto format(Brush::Node x, FormatContext& ctx) const {
352 return formatter<string_view>::format(x.get_name(), ctx);
353 }
354};
355
356
357
358#endif
holds variable type data.
Definition data.h:51
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ostream & operator<<(ostream &os, DataType n)
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:114
NodeType
Definition nodetype.h:31
auto Isnt(DataType dt) -> bool
Definition node.h:43
auto IsWeighable() noexcept -> bool
Definition node.h:46
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:25
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:279
auto IsCommutative(NodeType nt) noexcept -> bool
Definition node.h:283
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:272
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
DataType DT
Definition types.h:188
auto IsDifferentiable(NodeType nt) noexcept -> bool
Definition node.h:291
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:126
class holding the data for a node in a tree.
Definition node.h:84
float get_prob_keep() const
Definition node.h:247
bool center_op
whether to center the operator in pretty printing
Definition node.h:89
size_t get_node_hash() const
Definition node.h:196
std::vector< DataType > arg_types
argument data types
Definition node.h:103
void set_signature()
Definition node.h:145
DataType feature_type
feature type for terminals or splitting nodes
Definition node.h:267
DataType get_ret_type() const
Definition node.h:179
bool fixed
whether node is modifiable
Definition node.h:93
float get_prob_change() const
Definition node.h:245
NodeType node_type
the node type
Definition node.h:95
DataType get_feature_type() const
Definition node.h:253
auto operator<=(const Node &rhs) const noexcept -> bool
Definition node.h:227
Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
Constructor used by search space.
Definition node.h:131
DataType ret_type
return data type
Definition node.h:101
auto operator>(const Node &rhs) const noexcept -> bool
Definition node.h:232
void set_is_weighted(bool is_weighted)
Definition node.h:256
auto get_arg_types() const
Definition node.h:181
auto operator==(const Node &rhs) const noexcept -> bool
Definition node.h:208
void set_feature(string f)
Definition node.h:249
string get_feature() const
Definition node.h:250
std::size_t sig_hash
a hash of the signature
Definition node.h:97
bool get_is_weighted() const
Definition node.h:255
void init()
Definition node.h:154
void set_feature_type(DataType ft)
Definition node.h:252
std::size_t args_type() const
Definition node.h:180
float prob_change
chance of node being selected for variation
Definition node.h:91
auto operator>=(const Node &rhs) const noexcept -> bool
Definition node.h:237
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weight of the node (a single float); also used for splitting thresholds.
Definition node.h:107
bool is_weighted
whether this node is weighted
Definition node.h:105
std::tuple< UnderlyingNodeType, size_t, bool, string, bool, int > HashTuple
tuple type for hashing
Definition node.h:111
size_t get_arg_count() const
Definition node.h:182
string feature
feature name for terminals or splitting nodes
Definition node.h:264
auto operator!=(const Node &rhs) const noexcept -> bool
Definition node.h:215
string name
full name of the node, with types
Definition node.h:87
void set_prob_change(float w)
Definition node.h:246
Node()=default
auto operator<(const Node &rhs) const noexcept -> bool
Definition node.h:220
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:99
string get_model(const vector< string > &) const noexcept
Definition node.cpp:53
std::size_t operator()(std::vector< Brush::DataType > const &vec) const
Definition node.h:71
std::size_t operator()(std::vector< uint32_t > const &vec) const
Definition node.h:64
auto format(Brush::Node x, FormatContext &ctx) const
Definition node.h:351