18 template<
typename T,
typename Scalar,
typename W>
24 if (weights ==
nullptr)
26 if constexpr (std::is_floating_point_v<
decltype(tn.data.W)> || std::is_integral_v<
decltype(tn.data.W)>) {
27 if (std::isnan(tn.data.W) || tn.data.W == std::numeric_limits<
decltype(tn.data.W)>::lowest()) {
28 HANDLE_ERROR_THROW(
"TreeNode weight (W) is not set or is invalid for node: " + tn.data.name);
36 if (*weights ==
nullptr) {
37 std::string err_msg =
"Null pointer dereference: *weights is nullptr. "
38 "TreeNode ret_type: " + std::to_string(
static_cast<int>(tn.data.ret_type)) +
39 ", name: " + tn.data.name;
44 if constexpr (is_same_v<Scalar, W>)
49 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
50 using WScalar =
typename Scalar::Scalar;
51 WScalar tmp = WScalar((**weights).a);
58 *weights = *weights+1;
62 template<
typename T,
typename Scalar,
typename W>
70 if (tn.data.get_is_weighted())
75 "it should not be\n"));
89template<NodeType NT,
typename S,
bool Fit,
typename E=
void>
103 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
104 typename S::ArgTypes>;
113 template <std::
size_t N>
117 using W =
typename S::WeightType;
120 static constexpr auto F = [](
const auto& ...args) {
134 using arg_type = std::conditional_t<is_std_array_v<T>,
135 typename T::value_type, Array<
typename S::FirstArg::Scalar, -1, 1>>;
146 child_outputs.at(i) = sib->fit<arg_type>(d);
148 child_outputs.col(i) = sib->fit<arg_type>(d);
152 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
154 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
156 sib = sib->next_sibling;
158 return child_outputs;
166 for (
int i = 0; i < I; ++i)
168 sib= sib->next_sibling;
188 return std::make_tuple(
get_kid<Is>(d,tn,weights)...);
200 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
212 return std::apply(
F, inputs);
232 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
235 auto inputs =
get_kids(d, tn, weights);
238 if (tn.data.get_is_weighted())
241 return this->
apply(inputs)*w;
244 return this->
apply(inputs);
248 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
252 auto inputs =
get_kids(d, tn, weights);
255 if (tn.data.get_is_weighted())
258 return this->
apply(inputs) + w;
261 return this->
apply(inputs);
267template<
typename S,
bool Fit>
271 using W =
typename S::WeightType;
274 template<
typename T=RetType,
typename Scalar=
typename T::Scalar>
280 if (tn.data.get_is_weighted())
283 return this->get<RetType>(d, tn.data.get_feature())*w;
286 return this->get<RetType>(d,tn.data.get_feature());
290 template <
typename T = RetType,
typename Scalar=
typename T::Scalar>
297 if (tn.data.get_is_weighted())
300 return this->
get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
303 return this->
get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
308 auto get(
const Dataset& d,
const string& feature)
const
310 if (std::holds_alternative<T>(d[feature]))
311 return std::get<T>(d[feature]);
313 HANDLE_ERROR_THROW(fmt::format(
"Failed to return type {} for '{}'. The feature's original ret type is {}.\n",
325template<
typename S,
bool Fit>
329 using W =
typename S::WeightType;
331 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
342 if constexpr (N == 1)
352template<
typename S,
bool Fit>
356 using W =
typename S::WeightType;
362 std::unordered_map<float, int> counters;
363 for (
float val : d.
y) {
364 if (counters.find(val) != counters.end()) {
373 auto mode = std::max_element(
374 counters.begin(), counters.end(),
375 [](
const auto& a,
const auto& b) { return a.second < b.second; }
378 tn.data.W = mode->first;
382 tn.data.W = d.
y.mean();
388 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
392 if constexpr (N == 1)
412template<
typename R, NodeType NT,
typename S,
bool Fit,
typename W>
416 return op.
eval(d, tn, weights);
419template<
typename R, NodeType NT,
typename S,
bool Fit>
423 return op.
eval(d, tn);
holds variable type data.
bool classification
whether this is a classification problem
DataType get_feature_type(const string &name) const
int get_n_samples() const
int get_n_features() const
ArrayXf y
length N array, the target label
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
< nsga2 selection operator for getting the front
static constexpr bool is_tuple_v
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
static constexpr bool is_eigen_array_v
static constexpr bool is_one_of_v
typename UnJetify< T >::type UnJetify_t
map< DataType, string > DataTypeName
static constexpr bool is_std_array_v
static constexpr bool is_in_v
static constexpr bool NaryOp
static constexpr bool UnaryOp
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType fit(const Dataset &d, TreeNode &tn) const
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
auto get(const Dataset &d, const string &feature) const
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Core computation of a node's function to data.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::WeightType W
set weight type
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
static constexpr size_t ArgCount
stores the argument count of the operator
typename S::RetType RetType
return type of the operator
typename S::template NthType< N > NthType
utility for returning the type of the Nth argument
static constexpr auto F
wrapper function for the node function
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
set argument types to those of the signature unless:
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
RetType apply(const T &inputs) const
Apply node function in a functional style.
RetType apply(const T &inputs) const
Apply the node function like a function.