14 template<
typename T>
requires same_as<typename T::Scalar, bool>
18 template<
typename T>
requires same_as<typename T::Scalar, bJet>
21 for (
int i = 0; i< x.size(); ++i)
25 template<
typename T>
requires same_as<typename T::Scalar, float>
27 return (x > threshold);
29 template<
typename T>
requires same_as<typename T::Scalar, fJet>
33 x.begin(), x.end(), ret.begin(),
34 [&](
const auto& e){return e > threshold;}
38 template<
typename T>
requires same_as<typename T::Scalar, int>
40 return (x == threshold);
43 template<
typename T>
requires same_as<typename T::Scalar, iJet>
48 x.begin(), x.end(), ret.begin(),
49 [&](
const auto& e){return e == threshold;}
54 float gain(
const ArrayXf& lsplit,
const ArrayXf& rsplit,
bool classification,
61 tuple<float,float>
best_threshold(
const T& x,
const ArrayXf& y,
bool classification)
75 float best_thresh, best_score =
MAX_FLT;
79 unique_classes = unique(y);
81 for (
const auto thresh: all_thresholds)
86 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
89 const ArrayXf& lhs = y(L_idx);
90 const ArrayXf& rhs = y(R_idx);
92 if (lhs.size() == 0 || rhs.size() == 0)
96 float score =
gain(lhs, rhs, classification, unique_classes);
98 if (score < best_score || i == 0)
101 best_thresh = thresh;
106 best_thresh = std::isinf(best_thresh)?
107 0 : std::isnan(best_thresh)?
110 return make_tuple(best_thresh, best_score);
117 DataType DT = DataTypeEnum<T>::value;
134 for (
const auto& key : keys)
136 float tmp_thresh, score;
140 if (score < best_score | i == 0)
144 threshold = tmp_thresh;
148 auto tmp = std::make_tuple(feature, threshold, best_score);
149 results.push_back(std::make_tuple(feature, threshold, best_score));
152 template<
typename Ts, std::size_t... Is>
156 using entry = tuple<string, float, float>;
157 auto compare = [](
const entry& a,
const entry& b){
158 return (std::get<2>(a) < std::get<2>(b));
165 auto best = std::ranges::min_element(results, compare);
174 T result(mask.size());
177 tie (L_idx, R_idx) = Util::mask_to_indices(mask);
178 result(L_idx) = child_outputs.at(0);
179 result(R_idx) = child_outputs.at(1);
187template<NodeType NT,
typename S,
bool Fit>
188struct Operator<NT, S, Fit, enable_if_t<
is_in_v<NT,
NodeType::SplitOn, NodeType::SplitBest>>>
193 using W =
typename S::WeightType;
196 template <std::
size_t N>
200 static constexpr Function<NT>
F{};
208 if constexpr (NT==NodeType::SplitOn)
209 sib = sib->next_sibling;
211 for (
int i = 0; i < 2; ++i)
213 if (d.at(i).get_n_samples() > 0)
216 child_outputs.at(i) = sib->fit<arg_type>(d.at(i));
218 child_outputs.at(i) = sib->predict<arg_type>(d.at(i), weights);
220 sib = sib->next_sibling;
222 return child_outputs;
226 auto& threshold = tn.data.W;
229 if constexpr (NT == NodeType::SplitOn)
240 tn.data.set_feature(feature);
248 const auto& threshold = tn.data.W;
249 const auto& feature = tn.data.get_feature();
258 else if constexpr (NT==NodeType::SplitBest)
261 auto split_feature = tn.first_child->predict<
FirstArg>(d, weights);
267 auto child_outputs =
get_kids(data_splits, tn, weights);
holds variable type data.
bool classification
whether this is a classification problem
int get_n_samples() const
std::unordered_map< DataType, vector< string > > features_of_type
map from data types to features having that type.
ArrayXf y
length N array, the target label
std::array< Dataset, 2 > split(const ArrayXb &mask) const
class tree_node_< Node > TreeNode
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
ArrayXb threshold_mask(const T &x, const float &threshold)
Applies a learned threshold to a feature, returning a mask.
auto get_best_thresholds(const Dataset &d, std::index_sequence< Is... >)
vector< float > get_thresholds(const T &x)
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
tuple< float, float > best_threshold(const T &x, const ArrayXf &y, bool classification)
void get_best_threshold_by_type(const Dataset &d, auto &results)
float gini_impurity_index(const ArrayXf &classes, const vector< float > &uc)
T stitch(array< T, 2 > &child_outputs, const ArrayXb &mask)
Stitches together outputs from left or right child based on threshold.
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
static constexpr bool is_in_v
typename S::RetType RetType
typename S::ArgTypes ArgTypes
static constexpr Function< NT > F
typename S::NthType< N > NthType
RetType fit(const Dataset &d, TreeNode &tn) const
static constexpr size_t ArgCount
RetType eval(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const
array< RetType, 2 > get_kids(const array< Dataset, 2 > &d, TreeNode &tn, const W **weights=nullptr) const
typename S::FirstArg FirstArg
RetType predict(const Dataset &d, TreeNode &tn, const W **weights=nullptr) const