9#ifndef DISPATCH_TABLE_H
10#define DISPATCH_TABLE_H
28template<
typename R, NodeType NT,
typename S,
bool Fit,
typename W>
30template<
typename R, NodeType NT,
typename S,
bool Fit>
66 using SigMap = std::unordered_map<std::size_t,CallVariant>;
68 using DTMap = std::unordered_map<NodeType, SigMap>;
73 template<std::size_t...
Is>
78 auto nt = [](
auto i) {
return static_cast<NodeType>(1UL <<
i); };
87 std::make_index_sequence<std::tuple_size_v<signatures>>()
95 (
sm.insert({std::tuple_element_t<Is, Sigs>::hash(),
99 (
sm.insert({std::tuple_element_t<Is, Sigs>::DualArgs::hash(),
103 (
sm.insert({std::tuple_element_t<Is, Sigs>::Dual::hash(),
110 template<NodeType N,
typename S>
113 using R =
typename S::RetType;
114 using W =
typename S::WeightType;
124 InitMap(std::make_index_sequence<NodeTypes::Count>{});
129 fmt::print(
"================== \n");
130 fmt::print(
"dispatch table map_: \n");
134 fmt::print(
"{} : {} : DispatchFit\n",
nt,
sig);
136 fmt::print(
"{} : {} : DispatchPredict\n",
nt,
sig);
139 fmt::print(
"================== \n");
145 if (
this != &
other) {
165 if (
map_.at(
n).find(sig_hash) ==
map_.at(
n).end())
168 err += fmt::format(
"sig_hash={} not in map_.at({})\n",sig_hash,
n);
169 err += fmt::format(
"options:\n");
171 err+= fmt::format(
"{}\n",
k);
177 return std::get<Callable<T>>(
map_.at(
n).at(sig_hash));
191 auto msg = fmt::format(
"Tried get<Callable<{}>> for {} with hash {}; failed"
192 " because map holds index {}\n",
197 return std::get<Callable<T>>(
map_.at(
n).at(sig_hash));
void bind_engine(py::module &m, string name)
holds variable type data.
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
< nsga2 selection operator for getting the front
DispatchTable< false > dtable_predict
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
auto Is(NodeType nt) -> bool
DispatchTable< true > dtable_fit
static constexpr auto MakeCallable()
std::variant< Callable< ArrayXb >, Callable< ArrayXi >, Callable< ArrayXf >, Callable< ArrayXXb >, Callable< ArrayXXi >, Callable< ArrayXXf >, Callable< TimeSeriesb >, Callable< TimeSeriesi >, Callable< TimeSeriesf >, Callable< ArrayXbJet >, Callable< ArrayXiJet >, Callable< ArrayXfJet >, Callable< ArrayXXbJet >, Callable< ArrayXXiJet >, Callable< ArrayXXfJet >, Callable< Data::TimeSeriesbJet >, Callable< Data::TimeSeriesiJet >, Callable< Data::TimeSeriesfJet > > CallVariant
std::unordered_map< std::size_t, CallVariant > SigMap
maps Signature hashes -> Dispatch Operator
auto operator=(DispatchTable &&other) noexcept -> DispatchTable &
auto Get(NodeType n, std::size_t sig_hash) const -> Callable< T > const &
typename std::conditional_t< Fit, std::function< T(const Data::Dataset &, TreeNode &)>, std::function< T(const Data::Dataset &, TreeNode &, const typename WeightType< T >::type **)> > Callable
std::unordered_map< NodeType, SigMap > DTMap
maps NodeTypes -> Signature hash -> Dispatch Operator
DispatchTable(DispatchTable &&other) noexcept
static constexpr auto AddOperator(std::index_sequence< Is... >)
auto operator=(DispatchTable const &other) -> DispatchTable &
void InitMap(std::index_sequence< Is... >)
DispatchTable(DispatchTable const &other)
class holding the data for a node in a tree.
std::conditional_t< is_one_of_v< typename T::Scalar, fJet, iJet, bJet >, fJet, float > type