Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
signatures.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5*/
6#ifndef SIGNATURES_H
7#define SIGNATURES_H
8
9namespace Brush {
11// refs:
12// https://stackoverflow.com/questions/25958259/how-do-i-find-out-if-a-tuple-contains-a-type
13// https://stackoverflow.com/questions/34111060/c-check-if-the-template-type-is-one-of-the-variadic-template-types
14
15
16// TODO: potentially improve this with something like
17// template <typename T, typename S> struct Jetify;
18// template<typename T,
19// typename S=T::Scalar,
20// typename R=T::RowsAtCompileTime,
21// typename C=T::ColsAtCompileTime,
22// >
23// struct Jetify<Eigen::ArrayBase<T>> {
24// using Scalar = std::conditional_t<is_same_v<S,int>, iJet,
25// conditional_t<is_same_v<S,bool>,bJet,
26// conditional_t<is_same_v<S,float>,fJet, void>>>;
27// using type = Array<Scalar,R,C>;
28// };
29
/// Upper bound on generated operator arity: passed as the MaxArgCount
/// parameter of NarySignatures_t when the n-ary signature tuples are built.
// `inline constexpr` (C++17) yields a single definition shared by every
// translation unit that includes this header, rather than one
// internal-linkage copy per TU as `static` would.
inline constexpr std::size_t MAX_ARGS = 5;
31
32template <typename T> struct Jetify { using type = T;};
33template<> struct Jetify<ArrayXf> { using type = ArrayXfJet;};
34template<> struct Jetify<ArrayXi> { using type = ArrayXiJet;};
35template<> struct Jetify<ArrayXb> { using type = ArrayXbJet;};
36template<> struct Jetify<ArrayXXf> { using type = ArrayXXfJet;};
37template<> struct Jetify<ArrayXXi> { using type = ArrayXXiJet;};
38template<> struct Jetify<ArrayXXb> { using type = ArrayXXbJet;};
39template<> struct Jetify<Data::TimeSeriesf> { using type = Data::TimeSeriesfJet;};
40template<> struct Jetify<Data::TimeSeriesi> { using type = Data::TimeSeriesiJet;};
41template<> struct Jetify<Data::TimeSeriesb> { using type = Data::TimeSeriesbJet;};
42template <typename T>
43using Jetify_t = typename Jetify<T>::type;
44
45template <typename T> struct UnJetify { using type = T;};
46template<> struct UnJetify<ArrayXfJet> { using type = ArrayXf;};
47template<> struct UnJetify<ArrayXiJet> { using type = ArrayXi;};
48template<> struct UnJetify<ArrayXbJet> { using type = ArrayXb;};
49template<> struct UnJetify<ArrayXXfJet> { using type = ArrayXXf;};
50template<> struct UnJetify<ArrayXXiJet> { using type = ArrayXXi;};
51template<> struct UnJetify<ArrayXXbJet> { using type = ArrayXXb;};
52template<> struct UnJetify<Data::TimeSeriesfJet> { using type = Data::TimeSeriesf;};
53template<> struct UnJetify<Data::TimeSeriesiJet> { using type = Data::TimeSeriesi;};
54template<> struct UnJetify<Data::TimeSeriesbJet> { using type = Data::TimeSeriesb;};
55template <typename T>
57
58
59template<typename R, typename... Args>
60struct SigBase
61{
62 using RetType = R;
63 static constexpr std::size_t ArgCount = sizeof...(Args);
64
65 using FirstArg = std::tuple_element_t<0, std::tuple<Args...>>;
66
68 // (using std::array allows begin() and end() ops like transform to be applied)
69 // TODO: add an option to have argtypes be an ArrayX<T,-1,ArgCount> if the nodetype
70 // is associative and the ArgCount is greater than the operator's arg count
71 // (i.e., add is a BinaryOp and associative, so for Args>2 make the argtype an ArrayXX<T> )
72 using ArgTypes = conditional_t<(std::is_same_v<FirstArg,Args> && ...),
73 std::array<FirstArg,ArgCount>,
74 std::tuple<Args...>
75 >;
76 template <std::size_t N>
77 using NthType = conditional_t<!is_tuple<ArgTypes>::value,
79 typename std::tuple_element<N, ArgTypes>::type
80 >;
82 // currently unused
83 using Function = std::function<R(Args...)>;
84
85 template<size_t... Is>
86 static constexpr auto get_arg_types(std::index_sequence<Is...>)
87 {
88 return vector<DataType>{(DataTypeEnum<NthType<Is>>::value) ...};
89 }
90
91 static constexpr auto get_arg_types() {
92 return get_arg_types(make_index_sequence<ArgCount>());
93 }
94 static constexpr auto get_args_type() {
95 if constexpr (!is_tuple<ArgTypes>::value)
96 return "Array";
97 else
98 return "Tuple";
99 };
100
101 static constexpr auto get_ret_type() {return DataTypeEnum<RetType>::value;};
102
103 template<typename T>
104 static constexpr bool contains() { return is_any_v<T, Args...>; }
105
106 static constexpr std::size_t hash_args(){ return typeid(ArgTypes).hash_code();}
107
108 static constexpr std::size_t hash(){ return typeid(tuple<R,Args...>).hash_code();};
109};
110
111template<typename R>
112struct SigBase<R>
113{
114 using RetType = R;
115 using ArgTypes = void;
116 using FirstArg = void;
118 static constexpr std::size_t ArgCount = 0;
119 static constexpr auto get_ret_type() {return DataTypeEnum<RetType>::value;};
120 static constexpr auto get_arg_types() { return vector<DataType>{}; };
121 static constexpr auto get_args_type() { return "None"; };
122 static constexpr std::size_t hash(){ return typeid(R).hash_code(); };
123};
124
/// Primary template is only declared here. The usable form is the partial
/// specialization on a function type, Signature<R(Args...)>, that follows.
template <typename FnType>
struct Signature;
126template<typename R, typename... Args>
127struct Signature<R(Args...)> : SigBase<R, Args...>
128{
129 using base = SigBase<R, Args...>;
134 static constexpr auto ArgCount = base::ArgCount;
135
138};
139
140template<typename R, typename Arg, size_t ArgCount,
141 typename Indices = std::make_index_sequence<ArgCount> >
143{
144 template <std::size_t N>
145 using NthType = Arg;
146
147 template<size_t ...Is>
148 static constexpr auto make_signature(std::index_sequence<Is...>)
149 {
150 return Signature<R(NthType<Is>...)>{};
151 }
152
153 using type = decltype(make_signature(Indices{}));
154
155};
156template<typename R, typename Arg, size_t ArgCount>
158
168template<typename R, typename Arg, size_t MaxArgCount>
170{
171 template <std::size_t N>
172 using NthType = Arg;
173 static constexpr size_t Min = 2;
174 static constexpr size_t Max = MaxArgCount-2;
175 static constexpr auto Indices = std::make_index_sequence<Max>();
176
177 template<size_t ...Is>
178 static constexpr auto make_signatures(std::index_sequence<Is...>)
179 {
180 return std::tuple<NarySignature_t<R,Arg,(Is+Min)> ...>();
181 }
182
183 using type = decltype(make_signatures(Indices));
184
185};
186template<typename R, typename Arg, size_t MaxArgCount>
188
190// Signatures
191// - store the signatures that each Node can handle
192//
193template<NodeType N, typename T = void> struct Signatures;
194
195template<NodeType N>
197 using type = std::tuple<
198 Signature<ArrayXf()>,
201 >;
202};
203
204template<>
206 // meanlabel is based on y, so it is always a float ret_type
207 using type = std::tuple<
208 Signature<ArrayXf()>,
211 >;
212};
213
214template<NodeType N>
215struct Signatures<N, enable_if_t<is_in_v<N,
216 NodeType::Add,
221 >>>{
222 using type = std::tuple<
223 Signature<ArrayXf(ArrayXf,ArrayXf)>,
225 // Signature<ArrayXf(ArrayXi,ArrayXf)>, // this will cast the integer to float. TODO: make this work (or figure out a better way of casting)
226 Signature<ArrayXXf(ArrayXXf,ArrayXXf)>,
228 >;
229 };
230
231template<NodeType N>
232struct Signatures<N, enable_if_t<is_in_v<N,
233 NodeType::And,
235 >>>{
236 using type = std::tuple<
238 >;
239 };
240
241template<>
243 using type = std::tuple<
245 >;
246 };
247
248template<NodeType N>
249struct Signatures<N, enable_if_t<is_in_v<N,
250 NodeType::Abs,
270 >>>{
271 // using type = std::tuple<
272 // Signature<ArrayXf(ArrayXf)>,
273 // Signature<ArrayXXf(ArrayXXf)>
274 // >;
275 using unaryTuple = std::tuple<
276 Signature<ArrayXf(ArrayXf)>,
277 Signature<ArrayXXf(ArrayXXf)>
278 >;
279
281
282 using type = decltype(std::tuple_cat(unaryTuple(), naryTuple()));
283
284 // using default = tuple_element<0,type>;
285 };
286
287
288template<NodeType N>
289struct Signatures<N, enable_if_t<is_in_v<N,
291 >>>{
292 using unaryTuple = std::tuple<
293 Signature<ArrayXf(ArrayXf)>
294 // Signature<ArrayXf(ArrayXi)>,
295 // Signature<ArrayXf(ArrayXb)>
296 >;
297
299
300 using type = decltype(std::tuple_cat(unaryTuple(), naryTuple()));
301
302 // using default = tuple_element<0,type>;
303 };
304
305template<NodeType N>
306struct Signatures<N, enable_if_t<is_in_v<N,
310 >>>{
311 //TODO: Fix
312 using type = std::tuple<
316 >;
317 };
318
319template<NodeType N>
320struct Signatures<N, enable_if_t<is_in_v<N,
321 NodeType::Sum,
327 >>>{
328 using unaryTuple = std::tuple<
329 Signature<ArrayXf(ArrayXXf)>,
330 Signature<ArrayXf(TimeSeriesf)>
331 >;// TODO: should I implement compatibility with integers?
332
334
335 using type = decltype(std::tuple_cat(unaryTuple(), naryTuple()));
336 };
337
338template<>
340 using type = std::tuple<
341 /* Signature<ArrayXf(ArrayXXb)>, */
342 Signature<ArrayXf(TimeSeriesf)>,
343 Signature<ArrayXf(TimeSeriesi)>,
344 Signature<ArrayXf(TimeSeriesb)>
345 >;
346};
347
348template<NodeType N>
349struct Signatures<N, enable_if_t<is_in_v<N,
351 >>>{
352 using type = std::tuple<
353 Signature<ArrayXi(ArrayXXf)>,
356 >;
357 };
358
359/* template<NodeType N> */
360/* struct Signatures<N, enable_if_t<is_in_v<N, */
361/* NodeType::Equals, */
362/* NodeType::LessThan, */
363/* NodeType::GreaterThan, */
364/* NodeType::Leq, */
365/* NodeType::Geq, */
366/* NodeType::CustomBinaryOp */
367/* >>>{ */
368/* using type = std::tuple< */
369/* Signature<ArrayXb(ArrayXf,ArrayXf)> */
370/* /1* Signature<ArrayXXf(ArrayXXf,ArrayXXf)>, *1/ */
371/* >; */
372/* }; */
373
374template<NodeType N>
376 {
377 using type = std::tuple<
378 Signature<ArrayXf(ArrayXf,ArrayXf)>,
381 // TODO
382 /* Signature<ArrayXXf,ArrayXXf,,ArrayXXf,, */
383 /* Signature<ArrayXXf,ArrayXXf,,ArrayXXf, */
384 >;
385 };
386
387template<>
389 // spliton and splitbest will always compare to the weight if is a number, otherwise will use the boolean value
390 // TODO: idea: if we have the LEQ or GEQ, we can have splitOn with all different data types without
391 // having to make it explicit in the signature. I think there is too many types of splitOn that makes it hard to actually be used
392 using type = std::tuple<
393 // Signature<ArrayXf(ArrayXf,ArrayXf,ArrayXf)>,
394 // Signature<ArrayXf(ArrayXi,ArrayXf,ArrayXf)>,
395 Signature<ArrayXf(ArrayXb,ArrayXf,ArrayXf)>,
396
397 // Signature<ArrayXi(ArrayXf,ArrayXi,ArrayXi)>,
398 // Signature<ArrayXi(ArrayXi,ArrayXi,ArrayXi)>,
400
401 // Signature<ArrayXb(ArrayXf,ArrayXb,ArrayXb)>,
402 // Signature<ArrayXb(ArrayXi,ArrayXb,ArrayXb)>,
404 >;
405 };
406
407template <>
409 {
410 using unaryTuple = std::tuple< Signature<ArrayXXf(ArrayXXf)> >;
412
413 using type = decltype(std::tuple_cat(unaryTuple(), naryTuple()));
414 };
415
416} // namespace Brush
417#endif
STL class.
TimeSeries< float > TimeSeriesf
Definition types.h:112
namespace containing Data structures used in Brush
Definition data.cpp:49
TimeSeries< fJet > TimeSeriesfJet
Definition types.h:115
TimeSeries< iJet > TimeSeriesiJet
Definition types.h:114
TimeSeries< bool > TimeSeriesb
TimeSeries convenience typedefs.
Definition types.h:110
TimeSeries< float > TimeSeriesf
Definition types.h:112
TimeSeries< bJet > TimeSeriesbJet
Definition types.h:113
TimeSeries< int > TimeSeriesi
Definition types.h:111
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NodeType
Definition nodetype.h:31
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
Eigen::Array< int, Eigen::Dynamic, Eigen::Dynamic > ArrayXXi
Definition types.h:42
Eigen::Array< bJet, Eigen::Dynamic, Eigen::Dynamic > ArrayXXbJet
Definition types.h:54
auto Is(NodeType nt) -> bool
Definition node.h:272
typename NarySignature< R, Arg, ArgCount >::type NarySignature_t
Definition signatures.h:157
static constexpr size_t MAX_ARGS
Definition signatures.h:30
typename Jetify< T >::type Jetify_t
Definition signatures.h:43
typename UnJetify< T >::type UnJetify_t
Definition signatures.h:56
Eigen::Array< bool, Eigen::Dynamic, Eigen::Dynamic > ArrayXXb
Definition types.h:41
typename NarySignatures< R, Arg, MaxArgCount >::type NarySignatures_t
Definition signatures.h:187
Eigen::Array< fJet, Eigen::Dynamic, 1 > ArrayXfJet
Definition types.h:49
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Eigen::Array< bJet, Eigen::Dynamic, 1 > ArrayXbJet
Definition types.h:51
Eigen::Array< fJet, Eigen::Dynamic, Eigen::Dynamic > ArrayXXfJet
Definition types.h:52
Eigen::Array< iJet, Eigen::Dynamic, 1 > ArrayXiJet
Definition types.h:50
Eigen::Array< iJet, Eigen::Dynamic, Eigen::Dynamic > ArrayXXiJet
Definition types.h:53
TimeSeries< bool > TimeSeriesb
TimeSeries convenience typedefs.
Definition types.h:110
static constexpr bool is_any_v
Definition nodetype.h:259
static constexpr bool is_in_v
Definition nodetype.h:268
TimeSeries< int > TimeSeriesi
Definition types.h:111
static constexpr auto make_signature(std::index_sequence< Is... >)
Definition signatures.h:148
decltype(make_signature(Indices{})) type
Definition signatures.h:153
Makes a tuple of signatures with increasing arity, from 2 up to (but not including) MaxArgCount.
Definition signatures.h:170
static constexpr auto make_signatures(std::index_sequence< Is... >)
Definition signatures.h:178
static constexpr size_t Max
Definition signatures.h:174
static constexpr auto Indices
Definition signatures.h:175
decltype(make_signatures(Indices)) type
Definition signatures.h:183
static constexpr size_t Min
Definition signatures.h:173
static constexpr auto get_ret_type()
Definition signatures.h:119
static constexpr auto get_arg_types()
Definition signatures.h:120
static constexpr auto get_args_type()
Definition signatures.h:121
static constexpr std::size_t hash()
Definition signatures.h:122
typename WeightType< R >::type WeightType
Definition signatures.h:117
static constexpr std::size_t ArgCount
Definition signatures.h:118
static constexpr auto get_arg_types()
Definition signatures.h:91
conditional_t<(std::is_same_v< FirstArg, Args > &&...), std::array< FirstArg, ArgCount >, std::tuple< Args... > > ArgTypes
ArgTypes is a std::array if the types are shared, otherwise it is a tuple.
Definition signatures.h:72
static constexpr std::size_t ArgCount
Definition signatures.h:63
static constexpr std::size_t hash()
Definition signatures.h:108
std::function< R(Args...)> Function
Definition signatures.h:83
std::tuple_element_t< 0, std::tuple< Args... > > FirstArg
Definition signatures.h:65
static constexpr auto get_arg_types(std::index_sequence< Is... >)
Definition signatures.h:86
static constexpr auto get_args_type()
Definition signatures.h:94
conditional_t<!is_tuple< ArgTypes >::value, FirstArg, typename std::tuple_element< N, ArgTypes >::type > NthType
Definition signatures.h:77
static constexpr auto get_ret_type()
Definition signatures.h:101
static constexpr std::size_t hash_args()
Definition signatures.h:106
typename WeightType< FirstArg >::type WeightType
Definition signatures.h:81
static constexpr bool contains()
Definition signatures.h:104
base::WeightType WeightType
Definition signatures.h:133
static constexpr auto ArgCount
Definition signatures.h:134
SigBase< R, Args... > base
Definition signatures.h:129
SigBase< Jetify_t< RetType >, Jetify_t< Args >... > Dual
Definition signatures.h:136
SigBase< RetType, Jetify_t< Args >... > DualArgs
Definition signatures.h:137
std::tuple< Signature< ArrayXf(ArrayXf, ArrayXf)>, Signature< ArrayXi(ArrayXi, ArrayXi)>, Signature< ArrayXXf(ArrayXXf, ArrayXXf)>, Signature< ArrayXXi(ArrayXXi, ArrayXXi)> > type
Definition signatures.h:222
std::tuple< Signature< ArrayXi(ArrayXXf)>, Signature< ArrayXi(ArrayXXi)>, Signature< ArrayXi(ArrayXXb)> > type
Definition signatures.h:352
std::tuple< Signature< TimeSeriesf(TimeSeriesf, TimeSeriesf)>, Signature< TimeSeriesi(TimeSeriesi, TimeSeriesi)>, Signature< TimeSeriesb(TimeSeriesb, TimeSeriesb)> > type
Definition signatures.h:312
std::tuple< Signature< ArrayXf()>, Signature< ArrayXi()>, Signature< ArrayXb()> > type
Definition signatures.h:197
decltype(std::tuple_cat(unaryTuple(), naryTuple())) type
Definition signatures.h:300
NarySignatures_t< ArrayXXf, ArrayXf, MAX_ARGS > naryTuple
Definition signatures.h:298
std::tuple< Signature< ArrayXf(ArrayXf, ArrayXf)>, Signature< ArrayXi(ArrayXi, ArrayXi)>, Signature< ArrayXb(ArrayXb, ArrayXb)> > type
Definition signatures.h:377
std::tuple< Signature< ArrayXf(TimeSeriesf)>, Signature< ArrayXf(TimeSeriesi)>, Signature< ArrayXf(TimeSeriesb)> > type
Definition signatures.h:340
std::tuple< Signature< ArrayXf()>, Signature< ArrayXi()>, Signature< ArrayXb()> > type
Definition signatures.h:207
std::tuple< Signature< ArrayXb(ArrayXb)> > type
Definition signatures.h:243
std::tuple< Signature< ArrayXXf(ArrayXXf)> > unaryTuple
Definition signatures.h:410
NarySignatures_t< ArrayXXf, ArrayXf, MAX_ARGS > naryTuple
Definition signatures.h:411
decltype(std::tuple_cat(unaryTuple(), naryTuple())) type
Definition signatures.h:413
std::tuple< Signature< ArrayXf(ArrayXb, ArrayXf, ArrayXf)>, Signature< ArrayXi(ArrayXb, ArrayXi, ArrayXi)>, Signature< ArrayXb(ArrayXb, ArrayXb, ArrayXb)> > type
Definition signatures.h:392
std::conditional_t< is_one_of_v< typename T::Scalar, fJet, iJet, bJet >, fJet, float > type
Definition types.h:64