18 template<
typename T,
typename Scalar,
typename W>
24 if (weights ==
nullptr)
31 if constexpr (is_same_v<Scalar, W>)
35 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
36 using WScalar =
typename Scalar::Scalar;
37 WScalar tmp = WScalar((**weights).a);
43 *weights = *weights+1;
48 template<
typename T,
typename Scalar,
typename W>
56 if (tn.data.get_is_weighted())
74template<NodeType NT,
typename S,
bool Fit,
typename E=
void>
88 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
89 typename S::ArgTypes>;
95 static constexpr size_t ArgCount = S::ArgCount;
98 template <std::
size_t N>
102 using W =
typename S::WeightType;
105 static constexpr auto F = [](
const auto& ...args) {
119 using arg_type = std::conditional_t<is_std_array_v<T>,
120 typename T::value_type, Array<
typename S::FirstArg::Scalar, -1, 1>>;
131 child_outputs.at(i) = sib->fit<arg_type>(d);
133 child_outputs.col(i) = sib->fit<arg_type>(d);
137 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
139 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
141 sib = sib->next_sibling;
143 return child_outputs;
151 for (
int i = 0; i < I; ++i)
153 sib= sib->next_sibling;
185 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
197 return std::apply(
F, inputs);
217 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
220 auto inputs =
get_kids(d, tn, weights);
223 if (tn.data.get_is_weighted())
229 return this->
apply(inputs);
233 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
237 auto inputs =
get_kids(d, tn, weights);
240 if (tn.data.get_is_weighted())
243 return this->
apply(inputs) + w;
246 return this->
apply(inputs);
252template<
typename S,
bool Fit>
256 using W =
typename S::WeightType;
259 template<
typename T=RetType,
typename Scalar=
typename T::Scalar>
265 if (tn.data.get_is_weighted())
275 template <
typename T = RetType,
typename Scalar=
typename T::Scalar>
282 if (tn.data.get_is_weighted())
285 return this->
get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
288 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
295 if (std::holds_alternative<T>(d[feature]))
296 return std::get<T>(d[feature]);
309template<
typename S,
bool Fit>
313 using W =
typename S::WeightType;
315 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
320 if constexpr (N == 1)
330template<
typename S,
bool Fit>
334 using W =
typename S::WeightType;
340 std::unordered_map<float, int> counters;
341 for (
float val : d.
y) {
342 if (counters.find(val) != counters.end()) {
351 auto mode = std::max_element(
352 counters.begin(), counters.end(),
353 [](
const auto& a,
const auto& b) { return a.second < b.second; }
356 tn.data.W = mode->first;
360 tn.data.W = d.
y.mean();
366 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
370 if constexpr (N == 1)
390template<
typename R, NodeType NT,
typename S,
bool Fit,
typename W>
394 return op.
eval(d, tn, weights);
397template<
typename R, NodeType NT,
typename S,
bool Fit>
401 return op.
eval(d, tn);
holds variable type data.
bool classification
whether this is a classification problem
int get_n_samples() const
int get_n_features() const
ArrayXf y
length N array, the target label
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
< nsga2 selection operator for getting the front
static constexpr bool is_tuple_v
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
static constexpr bool is_eigen_array_v
static constexpr bool is_one_of_v
typename UnJetify< T >::type UnJetify_t
static constexpr bool is_std_array_v
static constexpr bool is_in_v
static constexpr bool NaryOp
static constexpr bool UnaryOp
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType fit(const Dataset &d, TreeNode &tn) const
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
auto get(const Dataset &d, const string &feature) const
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Core computation of a node's function to data.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::WeightType W
set weight type
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
static constexpr size_t ArgCount
stores the argument count of the operator
typename S::RetType RetType
return type of the operator
typename S::template NthType< N > NthType
utility for returning the type of the Nth argument
static constexpr auto F
wrapper function for the node function
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
set argument types to those of the signature unless:
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
RetType apply(const T &inputs) const
Apply node function in a functional style.
RetType apply(const T &inputs) const
Apply the node function like a function.