Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
split.cpp
Go to the documentation of this file.
1#include "operator.h"
2#include <utility>
4
5namespace Brush::Split{
6
7tuple<string,float> get_best_variable_and_threshold(const Dataset& d, TreeNode& tn)
8{
9 /* loops thru variables in d and picks the best threshold
10 * and feature to split at.
11 */
12 using FeatTypes = tuple<ArrayXf,ArrayXi,ArrayXb>;
13 constexpr auto size = std::tuple_size<FeatTypes>::value;
14 auto [feature, threshold, best_score] = get_best_thresholds<FeatTypes>(d, std::make_index_sequence<size>{});
15 return std::make_tuple(feature, threshold);
16}
17
18template<> vector<float> get_thresholds<ArrayXb>(const ArrayXb& x){ return vector<float>{0.0}; }
19template<> vector<float> get_thresholds<ArrayXbJet>(const ArrayXbJet& x){ return vector<float>{0.0}; }
20template<> vector<float> get_thresholds<ArrayXi>(const ArrayXi& x){
21 vector<float> thresholds;
22 for (const auto& val : unique(x))
24 return thresholds;
25}
26template<> vector<float> get_thresholds<ArrayXiJet>(const ArrayXiJet& x){
27 vector<float> thresholds;
28 for (const auto& val : unique(x))
30 return thresholds;
31}
32
33template<> vector<float> get_thresholds<ArrayXf>(const ArrayXf& x){
34 vector<float> thresholds;
35 auto s = unique(x);
36 for (unsigned i =0; i<s.size()-1; ++i)
37 {
38 thresholds.push_back((s.at(i) + s.at(i+1))/2.0);
39 }
40 return thresholds;
41}
42template<> vector<float> get_thresholds<ArrayXfJet>(const ArrayXfJet& x){
43 vector<float> thresholds;
44 auto s = unique(x);
45 for (unsigned i =0; i<s.size()-1; ++i)
46 {
47 thresholds.push_back((s.at(i).a + s.at(i+1).a)/float(2.0));
48 }
49 return thresholds;
50}
51
52
53template<>
54ArrayXb threshold_mask<State>(const State& x, const float& threshold) {
55 return std::visit(
56 [&](const auto& arg) -> ArrayXb {
57 using T = std::decay_t<decltype(arg)>;
58 if constexpr (T::NumDimensions == 1)
59 return threshold_mask(arg, threshold);
60 else
61 return ArrayXb::Constant(arg.size(), true);
62 },
63 x
64 );
65}
66float gain(const ArrayXf& lsplit,
67 const ArrayXf& rsplit,
68 bool classification, vector<float> unique_classes)
69 {
70 float lscore, rscore, score;
71 if (classification)
72 {
75 /* cout << "lscore: " << lscore << "\n"; */
76 /* cout << "rscore: " << rscore << "\n"; */
77 score = (lscore*float(lsplit.size()) +
78 rscore*float(rsplit.size()))
79 /(float(lsplit.size()) + float(rsplit.size()));
80 }
81 else
82 {
85 /* cout << "lscore: " << lscore << "\n"; */
86 /* cout << "rscore: " << rscore << "\n"; */
87 score = lscore + rscore;
88 }
89
90 return score;
91 }
92
93float gini_impurity_index(const ArrayXf& classes,
94 const vector<float>& unique_classes)
95{
96 vector<float> class_weights;
97 for (auto c : unique_classes){
98 class_weights.push_back(
99 float( (classes.cast<int>() == int(c)).count())/classes.size()
100 );
101 }
102 /* float total_weight=class_weights.sum(); */
103 auto cw = VectorXf::Map(class_weights.data(), class_weights.size());
104 float gini = 1 - cw.dot(cw);
105
106 return gini;
107}
108
109} //Brush::Split
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
class tree_node_< Node > TreeNode
std::variant< ArrayXb, ArrayXi, ArrayXf, ArrayXXb, ArrayXXi, ArrayXXf, TimeSeriesb, TimeSeriesi, TimeSeriesf, ArrayXbJet, ArrayXiJet, ArrayXfJet, ArrayXXbJet, ArrayXXiJet, ArrayXXfJet, TimeSeriesbJet, TimeSeriesiJet, TimeSeriesfJet > State
Defines the possible types of data flowing through nodes.
Definition types.h:140
vector< float > get_thresholds< ArrayXb >(const ArrayXb &x)
Definition split.cpp:18
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
Definition split.cpp:7
vector< float > get_thresholds< ArrayXfJet >(const ArrayXfJet &x)
Definition split.cpp:42
ArrayXb threshold_mask< State >(const State &x, const float &threshold)
Definition split.cpp:54
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
Definition split.cpp:66
vector< float > get_thresholds< ArrayXi >(const ArrayXi &x)
Definition split.cpp:20
vector< float > get_thresholds< ArrayXf >(const ArrayXf &x)
Definition split.cpp:33
float gini_impurity_index(const ArrayXf &classes, const vector< float > &unique_classes)
Definition split.cpp:93
vector< float > get_thresholds< ArrayXiJet >(const ArrayXiJet &x)
Definition split.cpp:26
vector< float > get_thresholds< ArrayXbJet >(const ArrayXbJet &x)
Definition split.cpp:19
vector< T > unique(vector< T > w)
returns unique elements in vector
Definition utils.h:334
float variance(const ArrayXf &v)
calculate variance
Definition utils.cpp:317
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
Eigen::Array< fJet, Eigen::Dynamic, 1 > ArrayXfJet
Definition types.h:49
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Eigen::Array< bJet, Eigen::Dynamic, 1 > ArrayXbJet
Definition types.h:51
Eigen::Array< iJet, Eigen::Dynamic, 1 > ArrayXiJet
Definition types.h:50