Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
nodetype.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5This section adapted heavily from Operon: https://github.com/heal-research/operon/
6*/
7#ifndef NODEMAP_H
8#define NODEMAP_H
9// external includes
10#include <bitset>
11#include <type_traits>
12#include <utility>
13//internal includes
14#include "../init.h"
15#include "../data/data.h"
16#include "../util/utils.h"
17#include "../util/rnd.h"
18using std::vector;
19using namespace Brush;
20/* using Brush::NodeType; */
21/* using Brush::ExecType; */
22using std::tuple;
23using std::array;
24using Brush::DataType;
28
29namespace Brush {
30
/// Bit-flag identifiers for every node type. Each enumerator owns exactly one
/// bit of a 64-bit word so that sets of node types can be expressed as masks.
///
/// The flags are built from `1ULL` rather than `1UL`: `unsigned long` is only
/// 32 bits wide on LLP64 targets (e.g. 64-bit Windows), where shifting it by
/// 32 or more is undefined. `unsigned long long` is at least 64 bits everywhere.
enum class NodeType : uint64_t { // Each node type must have a complexity
                                 // in operator_complexities@tree_node.cpp
    // Unary
    Abs             = 1ULL << 0,
    Acos            = 1ULL << 1,
    Asin            = 1ULL << 2,
    Atan            = 1ULL << 3,
    Cos             = 1ULL << 4,
    Cosh            = 1ULL << 5,
    Sin             = 1ULL << 6,
    Sinh            = 1ULL << 7,
    Tan             = 1ULL << 8,
    Tanh            = 1ULL << 9,
    Ceil            = 1ULL << 10,
    Floor           = 1ULL << 11,
    Exp             = 1ULL << 12,
    Log             = 1ULL << 13,
    Logabs          = 1ULL << 14,
    Log1p           = 1ULL << 15,
    Sqrt            = 1ULL << 16,
    Sqrtabs         = 1ULL << 17,
    Square          = 1ULL << 18,
    Logistic        = 1ULL << 19, // used as root for classification trees

    // timing masks
    Before          = 1ULL << 20,
    After           = 1ULL << 21,
    During          = 1ULL << 22,

    // Reducers
    Min             = 1ULL << 23,
    Max             = 1ULL << 24,
    Mean            = 1ULL << 25,
    Median          = 1ULL << 26,
    Prod            = 1ULL << 27,
    Sum             = 1ULL << 28,
    OffsetSum       = 1ULL << 29, // Sum with weight as one of its arguments

    // Transformers
    Softmax         = 1ULL << 30, // used as root for multiclf trees

    // Binary
    Add             = 1ULL << 31,
    Sub             = 1ULL << 32,
    Mul             = 1ULL << 33,
    Div             = 1ULL << 34,
    Pow             = 1ULL << 35,

    // split
    SplitBest       = 1ULL << 36,
    SplitOn         = 1ULL << 37,

    // boolean
    And             = 1ULL << 38,
    Or              = 1ULL << 39,
    Not             = 1ULL << 40,
    // Xor          = 1ULL << 39,

    // comparison (not implemented; bit positions below are placeholders)
    // Equals       = 1ULL << 41,
    // Geq          = 1ULL << 42,
    /* GreaterThan  = 1ULL << 41, */
    /* Leq          = 1ULL << 42, */
    /* LessThan     = 1ULL << 43, */

    // leaves: these must stay the last entries among the node types that are
    // enumerated up to NodeTypes::Count (everything after them is excluded)
    MeanLabel       = 1ULL << 41,
    Constant        = 1ULL << 42,
    Terminal        = 1ULL << 43,

    // TODO: implement operators below and move them before leaves
    ArgMax          = 1ULL << 44,
    // Counts the number of elements in an array.
    // NOTE(review): this was documented as "the last element in the enum",
    // but the custom operators below come after it.
    Count           = 1ULL << 45,

    // custom
    CustomUnaryOp   = 1ULL << 46,
    CustomBinaryOp  = 1ULL << 47,
    CustomSplit     = 1ULL << 48

};
112
113
114using UnderlyingNodeType = std::underlying_type_t<NodeType>;
115struct NodeTypes {
116 // magic number keeping track of the number of different node types
117
118 // index of last available node visible to search_space.
119 // It must match the highest bit used in the enum.
120 //notice that we will create the nodetypes until this count
121 static constexpr size_t Count = 44;
122
123 // subtracting leaves (leaving just the ops into this)
124 static constexpr size_t OpCount = Count-3;
125
126 // returns the index of the given type in the NodeType enum
127 static auto GetIndex(NodeType type) -> size_t
128 {
129 // Chad G. Pete did this
130 UnderlyingNodeType utype = static_cast<UnderlyingNodeType>(type);
131 size_t result = 0;
132 while (utype >>= 1) ++result;
133
134 return utype ? result + 1 : 0;
135 }
136};
137
138
139inline constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) & static_cast<UnderlyingNodeType>(rhs)); }
140inline constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) | static_cast<UnderlyingNodeType>(rhs)); }
141inline constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType { return static_cast<NodeType>(static_cast<UnderlyingNodeType>(lhs) ^ static_cast<UnderlyingNodeType>(rhs)); }
142inline constexpr auto operator~(NodeType lhs) -> NodeType { return static_cast<NodeType>(~static_cast<UnderlyingNodeType>(lhs)); }
143inline auto operator&=(NodeType& lhs, NodeType rhs) -> NodeType&
144{
145 lhs = lhs & rhs;
146 return lhs;
147}
148inline auto operator|=(NodeType& lhs, NodeType rhs) -> NodeType&
149{
150 lhs = lhs | rhs;
151 return lhs;
152}
153inline auto operator^=(NodeType& lhs, NodeType rhs) -> NodeType&
154{
155 lhs = lhs ^ rhs;
156 return lhs;
157}
158
159
160
161extern std::map<std::string, NodeType> NodeNameType;
162extern std::map<NodeType,std::string> NodeTypeName;
163
164#ifndef DOXYGEN_SKIP
165// Map NodeType values to and from their JSON string names.
// NOTE(review): with NLOHMANN_JSON_SERIALIZE_ENUM the first pair in the table
// is used as the fallback for values/strings that are not listed, so the
// order of the entries is significant (the fallback here is Add) — confirm
// that this is intended before reordering anything.
166NLOHMANN_JSON_SERIALIZE_ENUM( NodeType, {
167 // arithmetic and unary math operators
168 {NodeType::Add,"Add" },
169 {NodeType::Sub,"Sub" },
170 {NodeType::Mul,"Mul" },
171 {NodeType::Div,"Div" },
172 /* {NodeType::Aq,"Aq" }, */
173 {NodeType::Abs,"Abs" },
174 {NodeType::Acos,"Acos" },
175 {NodeType::Asin,"Asin" },
176 {NodeType::Atan,"Atan" },
177 {NodeType::Cos,"Cos" },
178 {NodeType::Cosh,"Cosh" },
179 {NodeType::Sin,"Sin" },
180 {NodeType::Sinh,"Sinh" },
181 {NodeType::Tan,"Tan" },
182 {NodeType::Tanh,"Tanh" },
183 {NodeType::Ceil,"Ceil" },
184 {NodeType::Floor,"Floor" },
185 {NodeType::Exp,"Exp" },
186 {NodeType::Log,"Log" },
187 {NodeType::Logabs,"Logabs" },
188 {NodeType::Log1p,"Log1p" },
189 {NodeType::Sqrt,"Sqrt" },
190 {NodeType::Sqrtabs,"Sqrtabs" },
191 {NodeType::Square,"Square" },
192 {NodeType::Pow,"Pow" },
193 {NodeType::Logistic,"Logistic" },
194
195 // logic
196 {NodeType::And,"And" },
197 {NodeType::Or,"Or" },
198 {NodeType::Not,"Not" },
199 // {NodeType::Xor,"Xor" },
200
201 // decision / comparison (all currently disabled)
202 // {NodeType::Equals,"Equals" },
203 // {NodeType::Geq,"Geq" },
204 /* {NodeType::LessThan,"LessThan" }, */
205 /* {NodeType::Leq,"Leq" }, */
206 /* {NodeType::Geq,"Geq" }, */
207
208 // reductions
209 {NodeType::Min,"Min" },
210 {NodeType::Max,"Max" },
211 {NodeType::Mean,"Mean" },
212 {NodeType::Median,"Median" },
213 {NodeType::Count,"Count" },
214 {NodeType::Sum,"Sum" },
215 {NodeType::OffsetSum,"OffsetSum" },
216 {NodeType::Prod,"Prod" },
217 {NodeType::ArgMax,"ArgMax" },
218
219 // transforms
220 {NodeType::Softmax,"Softmax" },
221
222 // timing masks
223 {NodeType::Before,"Before" },
224 {NodeType::After,"After" },
225 {NodeType::During,"During" },
226
227 // split
228 {NodeType::SplitBest,"SplitBest" },
229 {NodeType::SplitOn,"SplitOn" },
230
231 // leaves
232 {NodeType::MeanLabel,"MeanLabel" },
233 {NodeType::Constant,"Constant" },
234 {NodeType::Terminal,"Terminal" },
235
236 // custom
237 {NodeType::CustomUnaryOp,"CustomUnaryOp" },
238 {NodeType::CustomBinaryOp,"CustomBinaryOp" },
239 {NodeType::CustomSplit,"CustomSplit" }
240})
241#endif
242
243} // namespace Brush
244
245// format overload for NodeTypes
246template <> struct fmt::formatter<Brush::NodeType>: formatter<string_view> {
247 // parse is inherited from formatter<string_view>.
248 template <typename FormatContext>
249 auto format(Brush::NodeType x, FormatContext& ctx) const {
250 return formatter<string_view>::format(Brush::NodeTypeName.at(x), ctx);
251 }
252};
253
/// Type trait: `value` is true when T is exactly the same type as at least one
/// of Ts..., and false otherwise (including when Ts is empty).
template <typename T, typename... Ts>
struct is_any {
    static constexpr bool value = std::disjunction_v<std::is_same<T, Ts>...>;
};

/// Variable-template shorthand for is_any<T, Ts...>::value.
template <typename T, typename... Ts>
static constexpr bool is_any_v = is_any<T, Ts...>::value;
261
262template <NodeType T, NodeType... Ts>
263struct is_in
264{
265 static constexpr bool value = ((T == Ts) || ...);
266};
267
268template<NodeType T, NodeType... Ts>
269static constexpr bool is_in_v = is_in<T, Ts...>::value;
270
271using NT = NodeType;
272// NodeType concepts
273template<NT nt>
274static constexpr bool UnaryOp = is_in_v<nt,
275 NT::Abs,
276 NT::Acos,
277 NT::Asin,
278 NT::Atan,
279 NT::Cos,
280 NT::Cosh,
281 NT::Sin,
282 NT::Sinh,
283 NT::Tan,
284 NT::Tanh,
285 NT::Ceil,
286 NT::Floor,
287 NT::Exp,
288 NT::Log,
290 NT::Log1p,
291 NT::Sqrt,
295 NT::Not
296>;
297
298template<NT nt>
299static constexpr bool BinaryOp = is_in_v<nt,
300 NT::Add,
301 NT::Sub,
302 NT::Mul,
303 NT::Div,
304 NT::Pow,
305 NT::And,
306 NT::Or
307 // NT::Equals
308>;
309
310template<NT nt>
311static constexpr bool AssociativeBinaryOp = is_in_v<nt,
312 NT::Add,
313 NT::Mul,
314 NT::And,
315 NT::Or
316 // NT::Equals
317>;
318
319template<NT nt>
320static constexpr bool NaryOp = is_in_v<nt,
321 NT::Min,
322 NT::Max,
323 NT::Mean,
325 NT::Sum,
327 NT::Prod,
329>;
330
331// // TODO: make this work
332// template<typename NT, size_t ArgCount>
333// concept Transformer = requires(NT n, size_t ArgCount)
334// {
335// UnaryOp<n> && ArgCount > 1;
336// }
337
338// template<typename NT, size_t ArgCount>
339// concept Reducer = requires(NT n, size_t ArgCount)
340// {
341// BinaryOp<n> && ArgCount > 2;
342// }
343
344
345
346#include "signatures.h"
347#endif
TimeSeries< bool > TimeSeriesb
TimeSeries convenience typedefs.
Definition types.h:110
TimeSeries< float > TimeSeriesf
Definition types.h:112
TimeSeries< int > TimeSeriesi
Definition types.h:111
nsga2 selection operator for getting the front
Definition bandit.cpp:4
std::underlying_type_t< NodeType > UnderlyingNodeType
Definition nodetype.h:114
constexpr auto operator|(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:140
NodeType
Definition nodetype.h:31
constexpr auto operator~(NodeType lhs) -> NodeType
Definition nodetype.h:142
auto operator|=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:148
constexpr auto operator^(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:141
DataType
data types.
Definition types.h:143
auto operator&=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:143
std::map< std::string, NodeType > NodeNameType
Definition nodetype.cpp:5
auto operator^=(NodeType &lhs, NodeType rhs) -> NodeType &
Definition nodetype.h:153
constexpr auto operator&(NodeType lhs, NodeType rhs) -> NodeType
Definition nodetype.h:139
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
static constexpr bool is_any_v
Definition nodetype.h:260
static constexpr bool is_in_v
Definition nodetype.h:269
static constexpr bool NaryOp
Definition nodetype.h:320
static constexpr bool UnaryOp
Definition nodetype.h:274
static constexpr bool AssociativeBinaryOp
Definition nodetype.h:311
static constexpr bool BinaryOp
Definition nodetype.h:299
static constexpr size_t Count
Definition nodetype.h:121
static auto GetIndex(NodeType type) -> size_t
Definition nodetype.h:127
static constexpr size_t OpCount
Definition nodetype.h:124
auto format(Brush::NodeType x, FormatContext &ctx) const
Definition nodetype.h:249
static constexpr bool value
Definition nodetype.h:256
static constexpr bool value
Definition nodetype.h:265