Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
node.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Node class design heavily inspired by Operon, (c) Heal Research
6https://github.com/heal-research/operon/
7*/
8
9#ifndef NODE_H
10#define NODE_H
11
12#include "../data/data.h"
13#include "nodetype.h"
14#include "../util/utils.h"
15#include <iostream>
16// #include "nodes/base.h"
17// #include "nodes/dx.h"
18// #include "nodes/split.h"
19// #include "nodes/terminal.h"
21/*
22Node overhaul:
23
24- Incorporating new design principles, learning much from operon:
25 - make Node trivial, so that it is easily copied around.
26 - use Enums and maps to define node information. This kind of abandons the object oriented approach taken thus far, but it should make extensibility easier and performance better in the long run.
27 - Leverage ceres for parameter optimization. No more defining analytical
28 derivatives for every function. Let ceres do that.
29 - sidenote: not sure ceres can handle the data flow of split nodes.
30 need to figure out.
31 - this also suggests turning TimeSeries back into EigenSparse matrices.
32 - forget all the runtime node generation. It saves space at the cost of
33 unclear code. I might as well just define all the nodes that are available, plainly. At run-time this will be faster.
34 - keep an eye towards extensibility by defining a custom node registration function that works.
35
36*/
37using Brush::DataType;
39
40namespace Brush{
41
42template <DataType... T>
43inline auto Isnt(DataType dt) -> bool { return !((dt == T) || ...); }
44
45template<DataType DT>
62
64 std::size_t operator()(std::vector<uint32_t> const& vec) const {
65 std::size_t seed = vec.size();
66 for(auto& i : vec) {
67 seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
68 }
69 return seed;
70 }
71 std::size_t operator()(std::vector<Brush::DataType> const& vec) const {
72 std::size_t seed = vec.size();
73 for(auto& i : vec) {
74 seed ^= uint32_t(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
75 }
76 return seed;
77 }
78};
79
84struct Node {
85
/// full name of the node, with types
87 string name;
/// marks the node as fixed: while set, get_prob_change() returns 0 and
/// get_prob_keep() returns 1
93 bool fixed;
/// a hash of the signature
97 std::size_t sig_hash;
/// a hash of the dual of the signature (for NLS)
99 std::size_t sig_dual_hash;
/// argument data types
103 std::vector<DataType> arg_types;
/// the weights of the node. also used for splitting thresholds.
107 float W;
108 // /// @brief a node hash / unique ID for the node, except weights
109 // size_t node_hash;
/// tuple type for hashing; get_node_hash() builds one of these and hashes it
111 using HashTuple = std::tuple<
112 UnderlyingNodeType, // node type
113 size_t, // sig_hash
114 bool, // is_weighted
115 string, // feature
116 bool, // fixed
117 int // rounded W
118 // float // prob_change
119 >;
120
121 // Node(){init();};
/// Default constructor. It does not call init().
/// NOTE(review): fixed, sig_hash, sig_dual_hash and W have no default member
/// initializers here, so a default-initialized Node leaves them indeterminate
/// until they are assigned — confirm callers always set them.
122 Node() = default;
123
124
/// @brief Constructor used by search space.
/// @tparam S signature type; provides get_ret_type(), hash() and Dual::hash().
/// @param type the node type; also the key used to look up the name in NodeTypeName.
/// @param signature only its type S is used by the initializers shown here.
/// @param weighted NOTE(review): presumably the initial is_weighted value — confirm.
/// @param feature_name NOTE(review): presumably the initial feature name — confirm.
/// NOTE(review): declared noexcept, but the NodeTypeName[type] lookup and the
/// string copy can allocate; an exception here would terminate the program.
130 template<typename S>
131 explicit Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
132 : node_type(type)
133 , name(NodeTypeName[type])
134 , ret_type(S::get_ret_type())
136 , sig_hash(S::hash())
137 , sig_dual_hash(S::Dual::hash())
140 {
// Apply the remaining defaults (weight, fixed flag, change probability).
141 init();
142 }
143
/// @brief Re-derives the signature-dependent members from signature type S.
/// Assigns ret_type, arg_types, sig_hash and sig_dual_hash from
/// S::get_ret_type(), S::get_arg_types(), S::hash() and S::Dual::hash().
/// @tparam S the signature type to read from.
144 template<typename S>
146 {
147 ret_type = S::get_ret_type();
148 arg_types = S::get_arg_types();
149 sig_hash = S::hash();
150 sig_dual_hash = S::Dual::hash();
151 // set_node_hash();
152 }
153
154 void init(){
155
156 W = 1.0;
157 // set_node_hash();
158 fixed=false;
159 set_prob_change(1.0);
160
161 // cant weight an boolean terminal
162 if (!IsWeighable(this->ret_type))
163 this->is_weighted = false;
164 }
165
/// @brief gets a string version of the node for printing.
/// @param include_weight NOTE(review): presumably controls whether the weight
///        appears in the string; defined in node.cpp — confirm.
169 string get_name(bool include_weight=true) const noexcept;
/// NOTE(review): defined in node.cpp; presumably renders this node as a model
/// expression using the given strings as its arguments — confirm.
170 string get_model(const vector<string>&) const noexcept;
171
172 // get return type and argument types.
173 inline DataType get_ret_type() const { return ret_type; };
174 inline std::size_t args_type() const { return sig_hash; };
175 inline auto get_arg_types() const { return arg_types; };
176 inline size_t get_arg_count() const { return arg_types.size(); };
177
178 // void set_node_hash(){
179 // node_hash = std::hash<HashTuple>{}(HashTuple{
180 // NodeTypes::GetIndex(node_type),
181 // sig_hash,
182 // is_weighted,
183 // feature,
184 // fixed,
185 // W,
186 // prob_change
187 // });
188 // // fmt::print("nodetype:{}; hash tuple:{}; node_hash={}\n", node_type, tmp, node_hash);
189 // }
/// @brief Computes a hash of the node from a freshly built HashTuple.
/// The comparison operators of Node are all defined in terms of this value.
/// NOTE(review): std::hash<HashTuple> is not a standard specialization;
/// presumably it is provided elsewhere in the project — confirm.
/// @return the hash of the tuple.
190 size_t get_node_hash() const {
191 return std::hash<HashTuple>{}(HashTuple{
193 sig_hash,
195 feature,
196 fixed,
// W is scaled by 100 and truncated to int, so weights that agree after
// that truncation contribute the same value to the hash.
197 int(W*100)
198 });
199 }
201 //comparison operators
202 inline auto operator==(const Node& rhs) const noexcept -> bool
203 {
204 /* return CalculatedHashValue == rhs.CalculatedHashValue; */
205 return get_node_hash() == rhs.get_node_hash();
206 /* return (*this) == rhs; */
207 }
208
209 inline auto operator!=(const Node& rhs) const noexcept -> bool
210 {
211 return !((*this) == rhs);
212 }
213
214 inline auto operator<(const Node& rhs) const noexcept -> bool
215 {
216 /* return std::tie(HashValue, CalculatedHashValue) < std::tie(rhs.HashValue, rhs.CalculatedHashValue); */
217 return get_node_hash() < rhs.get_node_hash();
218 return (*this) < rhs;
219 }
220
221 inline auto operator<=(const Node& rhs) const noexcept -> bool
222 {
223 return ((*this) == rhs || (*this) < rhs);
224 }
225
226 inline auto operator>(const Node& rhs) const noexcept -> bool
227 {
228 return !((*this) <= rhs);
229 }
230
231 inline auto operator>=(const Node& rhs) const noexcept -> bool
232 {
233 return !((*this) < rhs);
234 }
235
237 // getters and setters
238 //TODO revisit
239 float get_prob_change() const { return fixed ? 0.0 : this->prob_change;};
240 void set_prob_change(float w){ this->prob_change = w;};
241 float get_prob_keep() const { return fixed ? 1.0 : 1.0-this->prob_change;};
242
243 inline void set_feature(string f){ feature = f; };
244 inline string get_feature() const { return feature; };
245
246 inline bool get_is_weighted() const {return this->is_weighted;};
247 inline void set_is_weighted(bool is_weighted){
248 // cant change the weight of a boolean terminal
249 if (IsWeighable(this->ret_type))
250 this->is_weighted = is_weighted;
251 };
252
253 private:
254
/// feature name for terminals or splitting nodes; accessed through
/// set_feature() / get_feature()
256 string feature;
257};
258
259template <NodeType... T>
260inline auto Is(NodeType nt) -> bool { return ((nt == T) || ...); }
261
262template <NodeType... T>
263inline auto Isnt(NodeType nt) -> bool { return !((nt == T) || ...); }
264
268
/// @brief Reports whether a node type is treated as commutative.
/// @param nt the node type to check.
/// @return true when nt is one of the node types named in the Is<...> list
///         below (NodeType::Add among them).
269inline auto IsCommutative(NodeType nt) noexcept -> bool {
270 return Is<NodeType::Add,
274 >(nt);
275}
276
/// @brief Reports whether a node type is treated as differentiable.
/// @param nt the node type to check.
/// @return true when nt is none of the node types named in the Isnt<...>
///         list below, i.e. the list enumerates the exclusions.
277inline auto IsDifferentiable(NodeType nt) noexcept -> bool {
278 return Isnt<
288 >(nt);
289}
/// @brief Reports whether a node type, given as a template argument, can be weighted.
/// @tparam NT the node type to check.
/// @return true when NT is none of the node types named in the Isnt<...>
///         list below, i.e. the list enumerates the exclusions.
290template<NodeType NT>
291inline auto IsWeighable() noexcept -> bool {
292 return Isnt<
304 >(NT);
305}
321
/// Stream insertion for a Node; declared here, defined elsewhere.
322ostream& operator<<(ostream& os, const Node& n);
/// Stream insertion for a NodeType; declared here, defined elsewhere.
323ostream& operator<<(ostream& os, const NodeType& nt);
324
325
326
/// Fills a Node from a json value; declared here, defined elsewhere.
/// NOTE(review): presumably the hook the json library finds for Node — confirm.
327void from_json(const json &j, Node& p);
/// Writes a Node into a json value; declared here, defined elsewhere.
/// NOTE(review): presumably the hook the json library finds for Node — confirm.
328void to_json(json& j, const Node& p);
329} // namespace Brush
330
331// format overload for Nodes
/// fmt formatter specialization for Brush::Node, built on the string_view
/// formatter so that format specs are handled by the base class.
332template <> struct fmt::formatter<Brush::Node>: formatter<string_view> {
333 // parse is inherited from formatter<string_view>.
// format(): renders the node by forwarding x.get_name() (called with its
// default argument) to the inherited string_view formatter.
334 template <typename FormatContext>
336 return formatter<string_view>::format(x.get_name(), ctx);
337 }
338};
339
340
341
342#endif
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
< nsga2 selection operator for getting the front
Definition data.cpp:12
ostream & operator<<(ostream &os, DataType n)
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:112
NodeType
Definition nodetype.h:31
auto Isnt(DataType dt) -> bool
Definition node.h:43
auto IsWeighable() noexcept -> bool
Definition node.h:46
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:24
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:265
auto IsCommutative(NodeType nt) noexcept -> bool
Definition node.h:269
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:260
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
NodeType NT
Definition node.cpp:129
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
DataType DT
Definition types.h:188
auto IsDifferentiable(NodeType nt) noexcept -> bool
Definition node.h:277
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:123
class holding the data for a node in a tree.
Definition node.h:84
float get_prob_keep() const
Definition node.h:241
bool center_op
whether to center the operator in pretty printing
Definition node.h:89
size_t get_node_hash() const
Definition node.h:190
std::vector< DataType > arg_types
argument data types
Definition node.h:103
void set_signature()
Definition node.h:145
DataType get_ret_type() const
Definition node.h:173
bool fixed
whether node is modifiable
Definition node.h:93
float get_prob_change() const
Definition node.h:239
NodeType node_type
the node type
Definition node.h:95
auto operator<=(const Node &rhs) const noexcept -> bool
Definition node.h:221
Node(NodeType type, S signature, bool weighted=false, string feature_name="") noexcept
Constructor used by search space.
Definition node.h:131
DataType ret_type
return data type
Definition node.h:101
auto operator>(const Node &rhs) const noexcept -> bool
Definition node.h:226
void set_is_weighted(bool is_weighted)
Definition node.h:247
auto get_arg_types() const
Definition node.h:175
auto operator==(const Node &rhs) const noexcept -> bool
Definition node.h:202
void set_feature(string f)
Definition node.h:243
string get_feature() const
Definition node.h:244
std::size_t sig_hash
a hash of the signature
Definition node.h:97
bool get_is_weighted() const
Definition node.h:246
void init()
Definition node.h:154
std::size_t args_type() const
Definition node.h:174
float prob_change
chance of node being selected for variation
Definition node.h:91
auto operator>=(const Node &rhs) const noexcept -> bool
Definition node.h:231
string get_name(bool include_weight=true) const noexcept
gets a string version of the node for printing.
Definition node.cpp:20
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:107
bool is_weighted
whether this node is weighted
Definition node.h:105
std::tuple< UnderlyingNodeType, size_t, bool, string, bool, int > HashTuple
tuple type for hashing
Definition node.h:111
size_t get_arg_count() const
Definition node.h:176
string feature
feature name for terminals or splitting nodes
Definition node.h:256
auto operator!=(const Node &rhs) const noexcept -> bool
Definition node.h:209
string name
full name of the node, with types
Definition node.h:87
void set_prob_change(float w)
Definition node.h:240
Node()=default
auto operator<(const Node &rhs) const noexcept -> bool
Definition node.h:214
std::size_t sig_dual_hash
a hash of the dual of the signature (for NLS)
Definition node.h:99
string get_model(const vector< string > &) const noexcept
Definition node.cpp:50
std::size_t operator()(std::vector< Brush::DataType > const &vec) const
Definition node.h:71
std::size_t operator()(std::vector< uint32_t > const &vec) const
Definition node.h:64
auto format(Brush::Node x, FormatContext &ctx) const
Definition node.h:335