18 template<
typename T,
typename Scalar,
typename W>
24 if (weights ==
nullptr)
36 using WScalar =
typename Scalar::Scalar;
43 *weights = *weights+1;
48 template<
typename T,
typename Scalar,
typename W>
56 if (
tn.data.get_is_weighted())
74template<NodeType NT,
typename S,
bool Fit,
typename E=
void>
88 Array<
typename S::FirstArg::Scalar, -1, S::ArgCount>,
89 typename S::ArgTypes>;
95 static constexpr size_t ArgCount = S::ArgCount;
98 template <std::
size_t N>
102 using W =
typename S::WeightType;
105 static constexpr auto F = [](
const auto& ...args) {
119 using arg_type = std::conditional_t<is_std_array_v<T>,
120 typename T::value_type, Array<
typename S::FirstArg::Scalar, -1, 1>>;
150 auto sib = tree<TreeNode>::sibling_iterator(
tn.first_child) ;
194 return std::apply(
F, inputs);
214 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
220 if (
tn.data.get_is_weighted())
222 auto w = util::get_weight<RetType,Scalar,W>(
tn, weights);
223 return this->
apply(inputs)*
w;
226 return this->
apply(inputs);
230 template<
typename T=ArgTypes,
typename Scalar=RetType::Scalar>
237 if (
tn.data.get_is_weighted())
239 auto w = util::get_weight<RetType,Scalar,W>(
tn, weights);
240 return this->
apply(inputs) +
w;
243 return this->
apply(inputs);
249template<
typename S,
bool Fit>
253 using W =
typename S::WeightType;
256 template<
typename T=RetType,
typename Scalar=
typename T::Scalar>
262 if (
tn.data.get_is_weighted())
264 auto w = util::get_weight<RetType,Scalar,W>(
tn, weights);
272 template <
typename T = RetType,
typename Scalar=
typename T::Scalar>
279 if (
tn.data.get_is_weighted())
281 auto w = util::get_weight<RetType,Scalar,W>(
tn, weights);
292 if (std::holds_alternative<T>(
d[feature]))
293 return std::get<T>(
d[feature]);
306template<
typename S,
bool Fit>
310 using W =
typename S::WeightType;
312 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
316 Scalar w = util::get_weight<RetType,Scalar,W>(
tn, weights);
317 if constexpr (
N == 1)
318 return RetType::Constant(
d.get_n_samples(),
w);
320 return RetType::Constant(
d.get_n_samples(),
d.get_n_features(),
w);
327template<
typename S,
bool Fit>
331 using W =
typename S::WeightType;
334 tn.data.W =
d.y.mean();
335 return predict(
d,
tn);
338 template<
typename T=RetType,
typename Scalar=T::Scalar,
int N=T::NumDimensions>
341 Scalar w = util::get_weight<RetType,Scalar,W>(
tn, weights);
342 if constexpr (
N == 1)
343 return RetType::Constant(
d.get_n_samples(),
w);
345 return RetType::Constant(
d.get_n_samples(),
d.get_n_features(),
w);
352 return predict(
d,
tn,weights);
362template<
typename R, NodeType NT,
typename S,
bool Fit,
typename W>
369template<
typename R, NodeType NT,
typename S,
bool Fit>
void bind_engine(py::module &m, string name)
holds variable type data.
class tree_node_< Node > TreeNode
#define HANDLE_ERROR_THROW(err)
Scalar get_weight(const TreeNode &tn, const W **weights=nullptr)
get weight
nsga2 selection operator for getting the front
R DispatchOp(const Data::Dataset &d, TreeNode &tn, const W **weights)
typename UnJetify< T >::type UnJetify_t
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType fit(const Dataset &d, TreeNode &tn) const
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::RetType RetType
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
auto get(const Dataset &d, const string &feature) const
RetType eval(const Dataset &d, const TreeNode &tn, const W **weights=nullptr) const
Core application of a node's function to data.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
evaluate the operator on the data. main entry point.
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
typename S::WeightType W
set weight type
T get_kids_seq(const Dataset &d, TreeNode &tn, const W **weights, std::index_sequence< Is... >) const
Makes and returns a tuple of child outputs.
static constexpr size_t ArgCount
stores the argument count of the operator
typename S::NthType< N > NthType
utility for returning the type of the Nth argument
typename S::RetType RetType
return type of the operator
static constexpr auto F
wrapper function for the node function
conditional_t<((UnaryOp< NT >||NaryOp< NT >) &&S::ArgCount > 1), Array< typename S::FirstArg::Scalar, -1, S::ArgCount >, typename S::ArgTypes > ArgTypes
set argument types to those of the signature unless:
NthType< I > get_kid(const Dataset &d, TreeNode &tn, const W **weights) const
gets one kid for a tuple of kids
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::tuple of kids. Used when child arguments are different types.
T get_kids(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
get a std::array or Eigen array of kids
RetType apply(const T &inputs) const
Apply node function in a functional style.
RetType apply(const T &inputs) const
Apply the node function like a function.