Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
nodetype.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5This section adapted heavily from Operon: https://github.com/heal-research/operon/
6*/
7#ifndef NODEMAP_H
8#define NODEMAP_H
9// external includes
10#include <bitset>
11#include <type_traits>
12#include <utility>
13//internal includes
14#include "../init.h"
15#include "../data/data.h"
16#include "../util/utils.h"
17#include "../util/rnd.h"
18using std::vector;
19using namespace Brush;
20/* using Brush::NodeType; */
21/* using Brush::ExecType; */
22using std::tuple;
23using std::array;
24using Brush::DataType;
28
29namespace Brush {
30
enum class NodeType : uint64_t { // Each node type must have a complexity
                                 // in operator_complexities@tree_node.cpp
    // Every enumerator is a distinct single-bit flag in a 64-bit value.
    // The shifted literal is 1ULL (at least 64 bits wide) instead of 1UL:
    // where `unsigned long` is only 32 bits wide, shifting it by 32 or more
    // is out of range, and bits 32..48 are used below.

    // Unary
    Abs             = 1ULL << 0,
    Acos            = 1ULL << 1,
    Asin            = 1ULL << 2,
    Atan            = 1ULL << 3,
    Cos             = 1ULL << 4,
    Cosh            = 1ULL << 5,
    Sin             = 1ULL << 6,
    Sinh            = 1ULL << 7,
    Tan             = 1ULL << 8,
    Tanh            = 1ULL << 9,
    Ceil            = 1ULL << 10,
    Floor           = 1ULL << 11,
    Exp             = 1ULL << 12,
    Log             = 1ULL << 13,
    Logabs          = 1ULL << 14,
    Log1p           = 1ULL << 15,
    Sqrt            = 1ULL << 16,
    Sqrtabs         = 1ULL << 17,
    Square          = 1ULL << 18,
    Logistic        = 1ULL << 19, // used as root for classification trees

    // timing masks
    Before          = 1ULL << 20,
    After           = 1ULL << 21,
    During          = 1ULL << 22,

    // Reducers
    Min             = 1ULL << 23,
    Max             = 1ULL << 24,
    Mean            = 1ULL << 25,
    Median          = 1ULL << 26,
    Prod            = 1ULL << 27,
    Sum             = 1ULL << 28,
    OffsetSum       = 1ULL << 29, // Sum with weight as one of its arguments

    // Transformers
    Softmax         = 1ULL << 30, // used as root for multiclf trees

    // Binary
    Add             = 1ULL << 31,
    Sub             = 1ULL << 32,
    Mul             = 1ULL << 33,
    Div             = 1ULL << 34,
    Pow             = 1ULL << 35,

    // split
    SplitBest       = 1ULL << 36,
    SplitOn         = 1ULL << 37,

    // these ones change type
    /* Equals       = 1ULL << 39, */
    /* LessThan     = 1ULL << 40, */
    /* GreaterThan  = 1ULL << 41, */
    /* Leq          = 1ULL << 42, */
    /* Geq          = 1ULL << 43, */

    // boolean
    And             = 1ULL << 38,
    Or              = 1ULL << 39,
    Not             = 1ULL << 40,
    // Xor          = 1ULL << 39,

    // leaves (must be the last ones among the node types counted by
    // NodeTypes::Count, which is 44, i.e. one past the Terminal bit)
    MeanLabel       = 1ULL << 41,
    Constant        = 1ULL << 42,
    Terminal        = 1ULL << 43,

    // TODO: implement operators below and move them before leaves
    ArgMax          = 1ULL << 44,
    // counts the number of elements in an array. Noted as meant to be the
    // last element in the enum, although the custom ops below follow it.
    Count           = 1ULL << 45,

    // custom
    CustomUnaryOp   = 1ULL << 46,
    CustomBinaryOp  = 1ULL << 47,
    CustomSplit     = 1ULL << 48

};
112
113
114using UnderlyingNodeType = std::underlying_type_t<NodeType>;
115struct NodeTypes {
116 // magic number keeping track of the number of different node types
117
118 // index of last available node visible to search_space.
119 // It must match the highest bit used in the enum
120 static constexpr size_t Count = 44;
121
122 // subtracting leaves (leaving just the ops into this)
123 static constexpr size_t OpCount = Count-3;
124
125 // returns the index of the given type in the NodeType enum
126 static auto GetIndex(NodeType type) -> size_t
127 {
128 // Chad G. Pete did this
129 UnderlyingNodeType utype = static_cast<UnderlyingNodeType>(type);
130 size_t result = 0;
131 while (utype >>= 1) ++result;
132
133 return utype ? result + 1 : 0;
134 }
135};
136
137
138inline constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) & static_cast<UnderlyingNodeType>(rhs)); }
139inline constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) | static_cast<UnderlyingNodeType>(rhs)); }
140inline constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) ^ static_cast<UnderlyingNodeType>(rhs)); }
141inline constexpr auto operator~(NodeType lhs) -> NodeType { return static_cast<NodeType>(~static_cast<UnderlyingNodeType>(lhs)); }
142inline auto operator&=(NodeType& lhs, NodeType rhs) -> NodeType&
143{
144 lhs = lhs & rhs;
145 return lhs;
146}
147inline auto operator|=(NodeType& lhs, NodeType rhs) -> NodeType&
148{
149 lhs = lhs | rhs;
150 return lhs;
151}
152inline auto operator^=(NodeType& lhs, NodeType rhs) -> NodeType&
153{
154 lhs = lhs ^ rhs;
155 return lhs;
156}
157
158
159
160extern std::map<std::string, NodeType> NodeNameType;
161extern std::map<NodeType,std::string> NodeTypeName;
162
163#ifndef DOXYGEN_SKIP
164// map NodeType values to JSON as strings
165NLOHMANN_JSON_SERIALIZE_ENUM( NodeType, {
166 //arithmetic
167 {NodeType::Add,"Add" },
168 {NodeType::Sub,"Sub" },
169 {NodeType::Mul,"Mul" },
170 {NodeType::Div,"Div" },
171 /* {NodeType::Aq,"Aq" }, */
172 {NodeType::Abs,"Abs" },
173 {NodeType::Acos,"Acos" },
174 {NodeType::Asin,"Asin" },
175 {NodeType::Atan,"Atan" },
176 {NodeType::Cos,"Cos" },
177 {NodeType::Cosh,"Cosh" },
178 {NodeType::Sin,"Sin" },
179 {NodeType::Sinh,"Sinh" },
180 {NodeType::Tan,"Tan" },
181 {NodeType::Tanh,"Tanh" },
182 {NodeType::Ceil,"Ceil" },
183 {NodeType::Floor,"Floor" },
184 {NodeType::Exp,"Exp" },
185 {NodeType::Log,"Log" },
186 {NodeType::Logabs,"Logabs" },
187 {NodeType::Log1p,"Log1p" },
188 {NodeType::Sqrt,"Sqrt" },
189 {NodeType::Sqrtabs,"Sqrtabs" },
190 {NodeType::Square,"Square" },
191 {NodeType::Pow,"Pow" },
192 {NodeType::Logistic,"Logistic" },
193
194 // logic
195 {NodeType::And,"And" },
196 {NodeType::Or,"Or" },
197 {NodeType::Not,"Not" },
198 // {NodeType::Xor,"Xor" },
199
200 // decision (same)
201 /* {NodeType::Equals,"Equals" }, */
202 /* {NodeType::LessThan,"LessThan" }, */
203 /* {NodeType::GreaterThan,"GreaterThan" }, */
204 /* {NodeType::Leq,"Leq" }, */
205 /* {NodeType::Geq,"Geq" }, */
206
207 // reductions
208 {NodeType::Min,"Min" },
209 {NodeType::Max,"Max" },
210 {NodeType::Mean,"Mean" },
211 {NodeType::Median,"Median" },
212 {NodeType::Count,"Count" },
213 {NodeType::Sum,"Sum" },
214 {NodeType::OffsetSum,"OffsetSum" },
215 {NodeType::Prod,"Prod" },
216 {NodeType::ArgMax,"ArgMax" },
217
218 // transforms
219 {NodeType::Softmax,"Softmax" },
220
221 // timing masks
222 {NodeType::Before,"Before" },
223 {NodeType::After,"After" },
224 {NodeType::During,"During" },
225
226 //split
227 {NodeType::SplitBest,"SplitBest" },
228 {NodeType::SplitOn,"SplitOn" },
229
230 // leaves
231 {NodeType::MeanLabel,"MeanLabel" },
232 {NodeType::Constant,"Constant" },
233 {NodeType::Terminal,"Terminal" },
234
235 // custom
236 {NodeType::CustomUnaryOp,"CustomUnaryOp" },
237 {NodeType::CustomBinaryOp,"CustomBinaryOp" },
238 {NodeType::CustomSplit,"CustomSplit" }
239})
240#endif
241
242} // namespace Brush
243
244// format overload for NodeTypes
245template <> struct fmt::formatter<Brush::NodeType>: formatter<string_view> {
246 // parse is inherited from formatter<string_view>.
247 template <typename FormatContext>
248 auto format(Brush::NodeType x, FormatContext& ctx) const {
249 return formatter<string_view>::format(Brush::NodeTypeName.at(x), ctx);
250 }
251};
252
// Type trait: `value` is true when T is exactly the same type as at least
// one of Ts, and false otherwise (including when Ts is empty).
template <typename T, typename... Ts>
struct is_any {
    static constexpr bool value = (... || std::is_same<T, Ts>::value);
};
257
258template <typename T, typename... Ts>
259static constexpr bool is_any_v = is_any<T, Ts...>::value;
260
261template <NodeType T, NodeType... Ts>
262struct is_in
263{
264 static constexpr bool value = ((T == Ts) || ...);
265};
266
267template<NodeType T, NodeType... Ts>
268static constexpr bool is_in_v = is_in<T, Ts...>::value;
269
270using NT = NodeType;
271// NodeType concepts
272template<NT nt>
273static constexpr bool UnaryOp = is_in_v<nt,
274 NT::Abs,
275 NT::Acos,
276 NT::Asin,
277 NT::Atan,
278 NT::Cos,
279 NT::Cosh,
280 NT::Sin,
281 NT::Sinh,
282 NT::Tan,
283 NT::Tanh,
284 NT::Ceil,
285 NT::Floor,
286 NT::Exp,
287 NT::Log,
289 NT::Log1p,
290 NT::Sqrt,
294 // NT::Not
295>;
296
297template<NT nt>
298static constexpr bool BinaryOp = is_in_v<nt,
299 NT::Add,
300 NT::Sub,
301 NT::Mul,
302 NT::Div,
303 NT::Pow
304>;
305
306template<NT nt>
307static constexpr bool AssociativeBinaryOp = is_in_v<nt,
308 NT::Add,
309 NT::Mul
310>;
311
312template<NT nt>
313static constexpr bool NaryOp = is_in_v<nt,
314 NT::Min,
315 NT::Max,
316 NT::Mean,
318 NT::Sum,
320 NT::Prod,
322>;
323
324// // TODO: make this work
325// template<typename NT, size_t ArgCount>
326// concept Transformer = requires(NT n, size_t ArgCount)
327// {
328// UnaryOp<n> && ArgCount > 1;
329// }
330
331// template<typename NT, size_t ArgCount>
332// concept Reducer = requires(NT n, size_t ArgCount)
333// {
334// BinaryOp<n> && ArgCount > 2;
335// }
336
337
338
339#include "signatures.h"
340#endif
TimeSeries< bool > TimeSeriesb
TimeSeries convenience typedefs.
Definition types.h:110
TimeSeries< float > TimeSeriesf
Definition types.h:112
TimeSeries< int > TimeSeriesi
Definition types.h:111
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:114
constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:139
NodeType
Definition nodetype.h:31
constexpr auto operator~(NodeType lhs) -> NodeType
Definition nodetype.h:141
auto operator|=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:147
constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:140
DataType
data types.
Definition types.h:143
auto operator&=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:142
std::map< std::string, NodeType > NodeNameType
Definition nodetype.cpp:5
auto operator^=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:152
constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:138
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
static constexpr bool is_any_v
Definition nodetype.h:259
static constexpr bool is_in_v
Definition nodetype.h:268
static constexpr bool NaryOp
Definition nodetype.h:313
static constexpr bool UnaryOp
Definition nodetype.h:273
static constexpr bool AssociativeBinaryOp
Definition nodetype.h:307
static constexpr bool BinaryOp
Definition nodetype.h:298
static constexpr size_t Count
Definition nodetype.h:120
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:126
static constexpr size_t OpCount
Definition nodetype.h:123
auto format(Brush::NodeType x, FormatContext &ctx) const
Definition nodetype.h:248
static constexpr bool value
Definition nodetype.h:255
static constexpr bool value
Definition nodetype.h:264