21 for (
int i = 0;
i<
x.size(); ++
i)
33 x.begin(),
x.end(),
ret.begin(),
34 [&](
const auto&
e){return e > threshold;}
48 x.begin(),
x.end(),
ret.begin(),
49 [&](
const auto&
e){return e == threshold;}
61 tuple<float,float>
best_threshold(
const T&
x,
const ArrayXf& y,
bool classification)
92 if (
lhs.size() == 0 ||
rhs.size() == 0)
98 if (score < best_score ||
i == 0)
117 DataType DT = DataTypeEnum<T>::value;
126 if (
d.features_of_type.find(DT) !=
d.features_of_type.end())
127 keys =
d.features_of_type.at(DT);
134 for (
const auto&
key : keys)
140 if (score < best_score |
i == 0)
148 auto tmp = std::make_tuple(feature,
threshold, best_score);
152 template<
typename Ts, std::size_t... Is>
156 using entry = tuple<string, float, float>;
158 return (std::get<2>(
a) < std::get<2>(
b));
187template<NodeType NT,
typename S,
bool Fit>
193 using W =
typename S::WeightType;
194 static constexpr size_t ArgCount = S::ArgCount;
196 template <std::
size_t N>
208 if constexpr (
NT==NodeType::SplitOn)
211 for (
int i = 0;
i < 2; ++
i)
213 if (
d.at(
i).get_n_samples() > 0)
229 if constexpr (
NT == NodeType::SplitOn)
240 tn.data.set_feature(feature);
243 return predict(
d,
tn);
249 const auto& feature =
tn.data.get_feature();
255 mask.resize(
d.get_n_samples());
258 else if constexpr (
NT==NodeType::SplitBest)
281 return predict(
d,
tn,weights);
void bind_engine(py::module &m, string name)
class tree_node_< Node > TreeNode
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
ArrayXb threshold_mask(const T &x, const float &threshold)
Applies a learned threshold to a feature, returning a mask.
auto get_best_thresholds(const Dataset &d, std::index_sequence< Is... >)
vector< float > get_thresholds(const T &x)
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
tuple< float, float > best_threshold(const T &x, const ArrayXf &y, bool classification)
void get_best_threshold_by_type(const Dataset &d, auto &results)
float gini_impurity_index(const ArrayXf &classes, const vector< float > &uc)
T stitch(array< T, 2 > &child_outputs, const ArrayXb &mask)
Stitches together outputs from left or right child based on threshold.
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
static constexpr bool is_in_v
typename S::RetType RetType
typename S::ArgTypes ArgTypes
typename S::NthType< N > NthType
RetType fit(const Dataset &d, TreeNode &tn) const
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
array< RetType, 2 > get_kids(const array< Dataset, 2 > &d, TreeNode &tn, const W **weights=nullptr) const
typename S::FirstArg FirstArg
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const