18 template<
typename T,
typename Scalar,
typename W>
24 if (weights ==
nullptr)
31 if constexpr (is_same_v<Scalar, W>)
35 else if constexpr (is_same_v<Scalar, iJet> && is_same_v<W, fJet>) {
36 using WScalar =
typename Scalar::Scalar;
37 WScalar tmp = WScalar((**weights).a);
43 *weights = *weights+1;
48 template<
typename T,
typename Scalar,
typename W>
56 if (tn.data.get_is_weighted())
74template<NodeType NT,
typename S,
bool Fit,
typename E=
void>
88 Array<typename S::FirstArg::Scalar, -1, S::ArgCount>,
89 typename S::ArgTypes>;
95 static constexpr size_t ArgCount = S::ArgCount;
98 template <std::
size_t N>
102 using W =
typename S::WeightType;
105 static constexpr auto F = [](
const auto& ...args) {
119 using arg_type = std::conditional_t<is_std_array_v<T>,
120 typename T::value_type, Array<
typename S::FirstArg::Scalar, -1, 1>>;
131 child_outputs.at(i) = sib->fit<arg_type>(d);
133 child_outputs.col(i) = sib->fit<arg_type>(d);
137 child_outputs.at(i) = sib->predict<arg_type>(d, weights);
139 child_outputs.col(i) = sib->predict<arg_type>(d, weights);
141 sib = sib->next_sibling;
143 return child_outputs;
150 auto sib = tree<TreeNode>::sibling_iterator(tn.first_child) ;
170 return std::make_tuple(
get_kid<Is>(d,tn,weights)...);
182 return get_kids_seq<T>(d, tn, weights, std::make_index_sequence<ArgCount>{});
194 return std::apply(
F, inputs);
214 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
217 auto inputs =
get_kids(d, tn, weights);
220 if (tn.data.get_is_weighted())
223 return this->
apply(inputs)*w;
230 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
234 auto inputs =
get_kids(d, tn, weights);
237 if (tn.data.get_is_weighted())
240 return this->
apply(inputs) + w;
243 return this->
apply(inputs);
249template<
typename S,
bool Fit>
253 using W =
typename S::WeightType;
256 template<
typename T=RetType,
typename Scalar=
typename T::Scalar>
262 if (tn.data.get_is_weighted())
272 template <
typename T = RetType,
typename Scalar=
typename T::Scalar>
279 if (tn.data.get_is_weighted())
282 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>()*w;
285 return this->get<nonJetType>(d, tn.data.get_feature()).template cast<Scalar>();
292 if (std::holds_alternative<T>(d[feature]))
293 return std::get<T>(d[feature]);
306template<
typename S,
bool Fit>
310 using W =
typename S::WeightType;
312 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
317 if constexpr (N == 1)
327template<
typename S,
bool Fit>
331 using W =
typename S::WeightType;
334 tn.data.W = d.
y.mean();
338 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
342 if constexpr (N == 1)
362template<
typename R, NodeType NT,
typename S,
bool Fit,
typename W>
366 return op.
eval(d, tn, weights);
369template<
typename R, NodeType NT,
typename S,
bool Fit>
373 return op.
eval(d, tn);
holds variable type data.
int get_n_samples() const
int get_n_features() const
ArrayXf y
length N array, the target label
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
< nsga2 selection operator for getting the front
static constexpr bool is_tuple_v
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
static constexpr bool is_eigen_array_v
static constexpr bool is_one_of_v
typename UnJetify< T >::type UnJetify_t
static constexpr bool is_std_array_v
static constexpr bool is_in_v
static constexpr bool NaryOp
static constexpr bool UnaryOp
static constexpr Function< NT > F
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType fit(const Dataset &d, TreeNode &tn) const
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
auto get(const Dataset &d, const string &feature) const
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Core computation of a node's function to data.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::WeightType W
set weight type
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
static constexpr size_t ArgCount
stores the argument count of the operator
typename S::NthType< N > NthType
utility for returning the type of the Nth argument
typename S::RetType RetType
return type of the operator
static constexpr auto F
wrapper function for the node function
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
set argument types to those of the signature unless:
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or eigen array of kids
RetType apply(const T &inputs) const
Apply node function in a functional style.
RetType apply(const T &inputs) const
Apply the node function like a function.