Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
dispatch_table.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Dispatch class design heavily inspired by Operon, (c) Heal Research
6https://github.com/heal-research/operon/
7*/
8
9#ifndef DISPATCH_TABLE_H
10#define DISPATCH_TABLE_H
11
12#include "../init.h"
13#include "../data/data.h"
14#include "nodetype.h"
15#include "node.h"
16#include <optional>
17#include <cstddef>
18#include <tuple>
19
20
21// forward declarations
22template<typename T> class tree_node_;
23using TreeNode = class tree_node_<Node>;
24
25namespace Brush{
26
27// forward declarations
28template<typename R, NodeType NT, typename S, bool Fit, typename W>
29R DispatchOp( const Data::Dataset& d, TreeNode& tn, const W** weights);
30template<typename R, NodeType NT, typename S, bool Fit>
32
34// Dispatch Table
35template<bool Fit>
37{
38 template<typename T>
39 using Callable = typename std::conditional_t<Fit,
40 std::function<T(const Data::Dataset&, TreeNode&)>,
41 std::function<T(const Data::Dataset&, TreeNode&, const typename WeightType<T>::type**)>
42 >;
43
44 using CallVariant = std::variant<
54 // jet overloads
64 >;
66 using SigMap = std::unordered_map<std::size_t,CallVariant>;
68 using DTMap = std::unordered_map<NodeType, SigMap>;
69
70private:
72
73 template<std::size_t... Is>
74 void InitMap(std::index_sequence<Is...> /*unused*/)
75 {
76
77 //TODO: nt(Is) should be a hash, if want to register other functions
78 auto nt = [](auto i) { return static_cast<NodeType>(1UL << i); };
79 (map_.insert({ nt(Is), MakeOperators<nt(Is)>() }), ...);
80 }
81
82 template<NodeType NT>
84 {
85 using signatures = typename Signatures<NT>::type;
87 std::make_index_sequence<std::tuple_size_v<signatures>>()
88 );
89 }
90
91 template<NodeType NT, typename Sigs, std::size_t... Is>
92 static constexpr auto AddOperator(std::index_sequence<Is...>)
93 {
94 SigMap sm;
95 (sm.insert({std::tuple_element_t<Is, Sigs>::hash(),
97 // Add dual signatures that take Jet types
99 (sm.insert({std::tuple_element_t<Is, Sigs>::DualArgs::hash(),
101 }
102 else {
103 (sm.insert({std::tuple_element_t<Is, Sigs>::Dual::hash(),
105 }
106 return sm;
107 }
108
109
110 template<NodeType N, typename S>
111 static constexpr auto MakeCallable()
112 {
113 using R = typename S::RetType;
114 using W = typename S::WeightType;
115 if constexpr (Fit)
117 else
119 }
120
121public:
123 {
124 InitMap(std::make_index_sequence<NodeTypes::Count>{});
125 }
126
127 void print()
128 {
129 fmt::print("================== \n");
130 fmt::print("dispatch table map_: \n");
131 for (const auto& [nt, sigmap]: map_){
132 for (const auto& [sig, call]: sigmap){
133 if (Fit)
134 fmt::print("{} : {} : DispatchFit\n",nt, sig);
135 else
136 fmt::print("{} : {} : DispatchPredict\n",nt, sig);
137 }
138 }
139 fmt::print("================== \n");
140 }
141
142 ~DispatchTable() = default;
143
145 if (this != &other) {
146 map_ = other.map_;
147 }
148 return *this;
149 }
150
152 map_ = std::move(other.map_);
153 return *this;
154 }
155
158
159 template<typename T>
160 inline auto Get(NodeType n, std::size_t sig_hash) const -> Callable<T> const&
161 {
162 // fmt::print("get<Callable<{}>> for {} with hash {}\n",
163 // DataTypeEnum<T>::value, n, sig_hash
164 // );
165 if (map_.at(n).find(sig_hash) == map_.at(n).end())
166 {
167 string err;
168 err += fmt::format("sig_hash={} not in map_.at({})\n",sig_hash,n);
169 err += fmt::format("options:\n");
170 for (auto [k, v]: map_.at(n))
171 err+= fmt::format("{}\n", k);
173 }
174 // CallVariant callable = map_.at(n).at(sig_hash);
175 // try {
176 if (std::holds_alternative<Callable<T>>(map_.at(n).at(sig_hash)))
177 return std::get<Callable<T>>(map_.at(n).at(sig_hash));
178 // }
179 // catch(const std::bad_variant_access& e) {
180
181 // auto msg = fmt::format("{}\nTried to ",e.what()); HANDLE_ERROR_THROW(msg);
182 // }
183 else{
184 // if (map_.at(n).size() > 1){
185 // for (const auto & kv : map_.at(n))
186 // {
187 // if (std::holds_alternative<Callable<T>>(kv.second))
188 // return std::get<Callable<T>>(kv.second);
189 // }
190 // }
191 auto msg = fmt::format("Tried get<Callable<{}>> for {} with hash {}; failed"
192 " because map holds index {}\n",
193 DataTypeEnum<T>::value, n, sig_hash, map_.at(n).at(sig_hash).index()
194 );
196 }
197 return std::get<Callable<T>>(map_.at(n).at(sig_hash));
198 }
199
200};
201
204// // format overload
205// template <> struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
206// template <typename FormatContext>
207// auto format(const Brush::SearchSpace& SS, FormatContext& ctx) const {
208// string output = "Search Space\n===\n";
209// output += fmt::format("terminal_map: {}\n", SS.terminal_map);
210// output += fmt::format("terminal_weights: {}\n", SS.terminal_weights);
211// for (const auto& [ret_type, v] : SS.node_map) {
212// for (const auto& [args_type, v2] : v) {
213// for (const auto& [node_type, node] : v2) {
214// output += fmt::format("node_map[{}][{}][{}] = {}, weight = {}\n",
215// ret_type,
216// ArgsName[args_type],
217// node_type,
218// node,
219// SS.node_map_weights.at(ret_type).at(args_type).at(node_type)
220// );
221// }
222// }
223// }
224// output += "===";
225// return formatter<string_view>::format(output, ctx);
226// }
227// };
228} // namespace Brush
229#endif
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
nsga2 selection operator for getting the front
Definition data.cpp:12
NodeType
Definition nodetype.h:31
DispatchTable< false > dtable_predict
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
Definition operator.h:363
auto Is(NodeType nt) -> bool
Definition node.h:260
NodeType NT
Definition node.cpp:129
DispatchTable< true > dtable_fit
static constexpr auto MakeCallable()
std::variant< Callable< ArrayXb >, Callable< ArrayXi >, Callable< ArrayXf >, Callable< ArrayXXb >, Callable< ArrayXXi >, Callable< ArrayXXf >, Callable< TimeSeriesb >, Callable< TimeSeriesi >, Callable< TimeSeriesf >, Callable< ArrayXbJet >, Callable< ArrayXiJet >, Callable< ArrayXfJet >, Callable< ArrayXXbJet >, Callable< ArrayXXiJet >, Callable< ArrayXXfJet >, Callable< Data::TimeSeriesbJet >, Callable< Data::TimeSeriesiJet >, Callable< Data::TimeSeriesfJet > > CallVariant
std::unordered_map< std::size_t, CallVariant > SigMap
maps Signature hashes -> Dispatch Operator
auto operator=(DispatchTable &&other) noexcept -> DispatchTable &
auto Get(NodeType n, std::size_t sig_hash) const -> Callable< T > const &
typename std::conditional_t< Fit, std::function< T(const Data::Dataset &, TreeNode &)>, std::function< T(const Data::Dataset &, TreeNode &, const typename WeightType< T >::type **)> > Callable
std::unordered_map< NodeType, SigMap > DTMap
maps NodeTypes -> Signature hash -> Dispatch Operator
DispatchTable(DispatchTable &&other) noexcept
static constexpr auto AddOperator(std::index_sequence< Is... >)
auto operator=(DispatchTable const &other) -> DispatchTable &
void InitMap(std::index_sequence< Is... >)
DispatchTable(DispatchTable const &other)
~DispatchTable()=default
class holding the data for a node in a tree.
Definition node.h:84
std::conditional_t< is_one_of_v< typename T::Scalar, fJet, iJet, bJet >, fJet, float > type
Definition types.h:64