Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
nodetype.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5This section adapted heavily from Operon: https://github.com/heal-research/operon/
6*/
7#ifndef NODEMAP_H
8#define NODEMAP_H
9// external includes
10#include <bitset>
11#include <type_traits>
12#include <utility>
13//internal includes
14#include "../init.h"
15#include "../data/data.h"
16#include "../util/utils.h"
17#include "../util/rnd.h"
18using std::vector;
19using namespace Brush;
20/* using Brush::NodeType; */
21/* using Brush::ExecType; */
22using std::tuple;
23using std::array;
24using Brush::DataType;
28
29namespace Brush {
30
/// @brief Identifiers for every node (operator or leaf) a Brush program tree can contain.
///
/// Each enumerator occupies exactly one bit of the 64-bit underlying word, so
/// sets of node types can be expressed as bit masks.
/// Each node type must have a complexity in operator_complexities@tree_node.cpp.
///
/// Bits are built from a 64-bit literal (`1ULL`) rather than `1UL`: on LLP64
/// platforms (e.g. MSVC/Windows) `unsigned long` is only 32 bits wide, so
/// `1UL << 32` and above would overflow.
enum class NodeType : uint64_t {
    // Unary
    Abs       = 1ULL << 0U,
    Acos      = 1ULL << 1U,
    Asin      = 1ULL << 2U,
    Atan      = 1ULL << 3U,
    Cos       = 1ULL << 4U,
    Cosh      = 1ULL << 5U,
    Sin       = 1ULL << 6U,
    Sinh      = 1ULL << 7U,
    Tan       = 1ULL << 8U,
    Tanh      = 1ULL << 9U,
    Ceil      = 1ULL << 10U,
    Floor     = 1ULL << 11U,
    Exp       = 1ULL << 12U,
    Log       = 1ULL << 13U,
    Logabs    = 1ULL << 14U,
    Log1p     = 1ULL << 15U,
    Sqrt      = 1ULL << 16U,
    Sqrtabs   = 1ULL << 17U,
    Square    = 1ULL << 18U,
    Logistic  = 1ULL << 19U, // used as root for classification trees

    // timing masks
    Before    = 1ULL << 20U,
    After     = 1ULL << 21U,
    During    = 1ULL << 22U,

    // Reducers
    Min       = 1ULL << 23U,
    Max       = 1ULL << 24U,
    Mean      = 1ULL << 25U,
    Median    = 1ULL << 26U,
    Prod      = 1ULL << 27U,
    Sum       = 1ULL << 28U,
    OffsetSum = 1ULL << 29U, // Sum with weight as one of its arguments

    // Transformers
    Softmax   = 1ULL << 30U, // used as root for multiclf trees

    // Binary
    Add       = 1ULL << 31U,
    Sub       = 1ULL << 32U,
    Mul       = 1ULL << 33U,
    Div       = 1ULL << 34U,
    Pow       = 1ULL << 35U,

    // split
    SplitBest = 1ULL << 36U,
    SplitOn   = 1ULL << 37U,

    // these ones change type (currently disabled).
    // NOTE: the bits listed here (39-43) are now taken by Or, Not, MeanLabel,
    // Constant and Terminal; give these fresh bits if they are re-enabled.
    /* Equals      = 1ULL << 39U, */
    /* LessThan    = 1ULL << 40U, */
    /* GreaterThan = 1ULL << 41U, */
    /* Leq         = 1ULL << 42U, */
    /* Geq         = 1ULL << 43U, */

    // boolean
    And       = 1ULL << 38U,
    Or        = 1ULL << 39U,
    Not       = 1ULL << 40U,
    // Xor    = 1ULL << 39U, (disabled; this bit is Or's)

    // leaves: must remain the last node types visible to the search space
    // (NodeTypes::Count covers bits 0..43 and NodeTypes::OpCount excludes these 3)
    MeanLabel = 1ULL << 41U,
    Constant  = 1ULL << 42U,
    Terminal  = 1ULL << 43U,

    // TODO: implement operators below and move them before leaves
    ArgMax    = 1ULL << 44U,
    Count     = 1ULL << 45U,

    // custom
    CustomUnaryOp  = 1ULL << 46U,
    CustomBinaryOp = 1ULL << 47U,
    CustomSplit    = 1ULL << 48U
};


/// Integer type backing NodeType (uint64_t).
using UnderlyingNodeType = std::underlying_type_t<NodeType>;

/// @brief Compile-time bookkeeping about the NodeType enum.
struct NodeTypes {
    // magic number keeping track of the number of different node types
    // visible to search_space: bits 0 (Abs) through 43 (Terminal).
    static constexpr size_t Count = 44;

    // subtracting leaves (MeanLabel, Constant, Terminal), leaving just the ops
    static constexpr size_t OpCount = Count-3;

    /// @brief Returns the zero-based bit position of @p type in the NodeType enum
    /// (Abs -> 0, Acos -> 1, ..., Terminal -> 43, ..., CustomSplit -> 48).
    ///
    /// For a single-bit value 2^k, the value 2^k - 1 has exactly k bits set, so
    /// its popcount is k. The bitset spans the full underlying width so that
    /// enumerators past `Count` (ArgMax, Count, Custom*) are not truncated.
    /// @param type a single NodeType enumerator; the result is meaningless for
    ///             combined masks.
    static auto GetIndex(NodeType type) -> size_t
    {
        constexpr size_t width = 8 * sizeof(UnderlyingNodeType);
        const auto bits = static_cast<UnderlyingNodeType>(type);
        return std::bitset<width>(bits - 1).count();
    }
};
128
129
130inline constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) & static_cast<UnderlyingNodeType>(rhs)); }
131inline constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) | static_cast<UnderlyingNodeType>(rhs)); }
132inline constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) ^ static_cast<UnderlyingNodeType>(rhs)); }
133inline constexpr auto operator~(NodeType lhs) -> NodeType { return static_cast<NodeType>(~static_cast<UnderlyingNodeType>(lhs)); }
135{
136 lhs = lhs & rhs;
137 return lhs;
138}
140{
141 lhs = lhs | rhs;
142 return lhs;
143}
145{
146 lhs = lhs ^ rhs;
147 return lhs;
148}
149
150
151
152extern std::map<std::string, NodeType> NodeNameType;
153extern std::map<NodeType,std::string> NodeTypeName;
154
155#ifndef DOXYGEN_SKIP
156// map NodeType values to JSON as strings
158 //arithmetic
159 {NodeType::Add,"Add" },
160 {NodeType::Sub,"Sub" },
161 {NodeType::Mul,"Mul" },
162 {NodeType::Div,"Div" },
163 /* {NodeType::Aq,"Aq" }, */
164 {NodeType::Abs,"Abs" },
165 {NodeType::Acos,"Acos" },
166 {NodeType::Asin,"Asin" },
167 {NodeType::Atan,"Atan" },
168 {NodeType::Cos,"Cos" },
169 {NodeType::Cosh,"Cosh" },
170 {NodeType::Sin,"Sin" },
171 {NodeType::Sinh,"Sinh" },
172 {NodeType::Tan,"Tan" },
173 {NodeType::Tanh,"Tanh" },
174 {NodeType::Ceil,"Ceil" },
175 {NodeType::Floor,"Floor" },
176 {NodeType::Exp,"Exp" },
177 {NodeType::Log,"Log" },
178 {NodeType::Logabs,"Logabs" },
179 {NodeType::Log1p,"Log1p" },
180 {NodeType::Sqrt,"Sqrt" },
181 {NodeType::Sqrtabs,"Sqrtabs" },
182 {NodeType::Square,"Square" },
183 {NodeType::Pow,"Pow" },
184 {NodeType::Logistic,"Logistic" },
185
186 // logic
187 {NodeType::And,"And" },
188 {NodeType::Or,"Or" },
189 {NodeType::Not,"Not" },
190 // {NodeType::Xor,"Xor" },
191
192 // decision (same)
193 /* {NodeType::Equals,"Equals" }, */
194 /* {NodeType::LessThan,"LessThan" }, */
195 /* {NodeType::GreaterThan,"GreaterThan" }, */
196 /* {NodeType::Leq,"Leq" }, */
197 /* {NodeType::Geq,"Geq" }, */
198
199 // reductions
200 {NodeType::Min,"Min" },
201 {NodeType::Max,"Max" },
202 {NodeType::Mean,"Mean" },
203 {NodeType::Median,"Median" },
204 {NodeType::Count,"Count" },
205 {NodeType::Sum,"Sum" },
206 {NodeType::OffsetSum,"OffsetSum" },
207 {NodeType::Prod,"Prod" },
208 {NodeType::ArgMax,"ArgMax" },
209
210 // transforms
211 {NodeType::Softmax,"Softmax" },
212
213 // timing masks
214 {NodeType::Before,"Before" },
215 {NodeType::After,"After" },
216 {NodeType::During,"During" },
217
218 //split
219 {NodeType::SplitBest,"SplitBest" },
220 {NodeType::SplitOn,"SplitOn" },
221
222 // leaves
223 {NodeType::MeanLabel,"MeanLabel" },
224 {NodeType::Constant,"Constant" },
225 {NodeType::Terminal,"Terminal" },
226
227 // custom
228 {NodeType::CustomUnaryOp,"CustomUnaryOp" },
229 {NodeType::CustomBinaryOp,"CustomBinaryOp" },
230 {NodeType::CustomSplit,"CustomSplit" }
231})
232#endif
233
234} // namespace Brush
235
236// format overload for NodeTypes
237template <> struct fmt::formatter<Brush::NodeType>: formatter<string_view> {
238 // parse is inherited from formatter<string_view>.
239 template <typename FormatContext>
241 return formatter<string_view>::format(Brush::NodeTypeName.at(x), ctx);
242 }
243};
245template <NodeType T, NodeType... Ts>
246struct is_in
247{
248 static constexpr bool value = ((T == Ts) || ...);
249};
250
251template<NodeType T, NodeType... Ts>
252static constexpr bool is_in_v = is_in<T, Ts...>::value;
253
254using NT = NodeType;
255// NodeType concepts
256template<NT nt>
257static constexpr bool UnaryOp = is_in_v<nt,
258 NT::Abs,
259 NT::Acos,
260 NT::Asin,
261 NT::Atan,
262 NT::Cos,
263 NT::Cosh,
264 NT::Sin,
265 NT::Sinh,
266 NT::Tan,
267 NT::Tanh,
268 NT::Ceil,
269 NT::Floor,
270 NT::Exp,
271 NT::Log,
272 NT::Logabs,
273 NT::Log1p,
274 NT::Sqrt,
275 NT::Sqrtabs,
276 NT::Square,
277 NT::Logistic
278 // NT::Not
279>;
280
281template<NT nt>
282static constexpr bool BinaryOp = is_in_v<nt,
283 NT::Add,
284 NT::Sub,
285 NT::Mul,
286 NT::Div,
287 NT::Pow
288>;
289
290template<NT nt>
291static constexpr bool AssociativeBinaryOp = is_in_v<nt,
292 NT::Add,
293 NT::Mul
294>;
295
296template<NT nt>
297static constexpr bool NaryOp = is_in_v<nt,
298 NT::Min,
299 NT::Max,
300 NT::Mean,
301 NT::Median,
302 NT::Sum,
303 NT::OffsetSum,
304 NT::Prod,
305 NT::Softmax
306>;
307
308// // TODO: make this work
309// template<typename NT, size_t ArgCount>
310// concept Transformer = requires(NT n, size_t ArgCount)
311// {
312// UnaryOp<n> && ArgCount > 1;
313// }
314
315// template<typename NT, size_t ArgCount>
316// concept Reducer = requires(NT n, size_t ArgCount)
317// {
318// BinaryOp<n> && ArgCount > 2;
319// }
320
321
322
323#include "signatures.h"
324#endif
void bind_engine(py::module &m, string name)
< nsga2 selection operator for getting the front
Definition data.cpp:12
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:112
constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:131
NodeType
Definition nodetype.h:31
constexpr auto operator~(NodeType lhs) -> NodeType
Definition nodetype.h:133
auto operator|=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:139
constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:132
DataType
data types.
Definition types.h:143
auto operator&=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:134
std::map< std::string, NodeType > NodeNameType
Definition nodetype.cpp:5
auto operator^=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:144
constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:130
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
static constexpr bool is_in_v
Definition nodetype.h:252
static constexpr bool NaryOp
Definition nodetype.h:297
static constexpr bool UnaryOp
Definition nodetype.h:257
static constexpr bool AssociativeBinaryOp
Definition nodetype.h:291
static constexpr bool BinaryOp
Definition nodetype.h:282
Stores time series data and implements operators over it.
Definition timeseries.h:26
static constexpr size_t Count
Definition nodetype.h:117
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:123
static constexpr size_t OpCount
Definition nodetype.h:120
auto format(Brush::NodeType x, FormatContext &ctx) const
Definition nodetype.h:240
static constexpr bool value
Definition nodetype.h:248