Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
split.cpp
Go to the documentation of this file.
1#include "operator.h"
2#include <utility>
4
5namespace Brush::Split{
6
7tuple<string,float> get_best_variable_and_threshold(const Dataset& d, TreeNode& tn)
8{
9 /* loops thru variables in d and picks the best threshold
10 * and feature to split at.
11 */
12 using FeatTypes = tuple<ArrayXf,ArrayXi,ArrayXb>;
13 constexpr auto size = std::tuple_size<FeatTypes>::value;
14 auto [feature, threshold, best_score] = get_best_thresholds<FeatTypes>(d, std::make_index_sequence<size>{});
15 return std::make_tuple(feature, threshold);
16}
17
18template<> vector<float> get_thresholds<ArrayXb>(const ArrayXb& x){ return vector<float>{1.0}; }
19template<> vector<float> get_thresholds<ArrayXbJet>(const ArrayXbJet& x){ return vector<float>{1.0}; }
20template<> vector<float> get_thresholds<ArrayXi>(const ArrayXi& x){
21 vector<float> thresholds;
22 for (const auto& val : unique(x))
23 thresholds.push_back(val);
24 return thresholds;
25}
26template<> vector<float> get_thresholds<ArrayXiJet>(const ArrayXiJet& x){
27 vector<float> thresholds;
28 for (const auto& val : unique(x))
29 thresholds.push_back(val.a);
30 return thresholds;
31}
32
33template<> vector<float> get_thresholds<ArrayXf>(const ArrayXf& x){
34 vector<float> thresholds;
35 auto s = unique(x);
36 for (unsigned i =0; i<s.size()-1; ++i)
37 {
38 // thresholds.push_back((s.at(i) + s.at(i+1))/2.0);
39 thresholds.push_back(s.at(i+1));
40 }
41 return thresholds;
42}
43template<> vector<float> get_thresholds<ArrayXfJet>(const ArrayXfJet& x){
44 vector<float> thresholds;
45 auto s = unique(x);
46 for (unsigned i =0; i<s.size()-1; ++i)
47 {
48 // thresholds.push_back((s.at(i).a + s.at(i+1).a)/float(2.0));
49 thresholds.push_back(s.at(i+1).a);
50 }
51 return thresholds;
52}
53
54
55template<>
56ArrayXb threshold_mask<State>(const State& x, const float& threshold) {
57 return std::visit(
58 [&](const auto& arg) -> ArrayXb {
59 using T = std::decay_t<decltype(arg)>;
60 if constexpr (T::NumDimensions == 1)
61 return threshold_mask(arg, threshold);
62 else
63 return ArrayXb::Constant(arg.size(), true);
64 },
65 x
66 );
67}
68float gain(const ArrayXf& lsplit,
69 const ArrayXf& rsplit,
70 bool classification, vector<float> unique_classes)
71 {
72 float lscore, rscore, score;
73 if (classification)
74 {
75 lscore = gini_impurity_index(lsplit, unique_classes);
76 rscore = gini_impurity_index(rsplit, unique_classes);
77 /* cout << "lscore: " << lscore << "\n"; */
78 /* cout << "rscore: " << rscore << "\n"; */
79 score = (lscore*float(lsplit.size()) +
80 rscore*float(rsplit.size()))
81 /(float(lsplit.size()) + float(rsplit.size()));
82 }
83 else
84 {
85 lscore = variance(lsplit)/float(lsplit.size());
86 rscore = variance(rsplit)/float(rsplit.size());
87 /* cout << "lscore: " << lscore << "\n"; */
88 /* cout << "rscore: " << rscore << "\n"; */
89 score = lscore + rscore;
90 }
91
92 return score;
93 }
94
95float gini_impurity_index(const ArrayXf& classes,
96 const vector<float>& unique_classes)
97{
98 vector<float> class_weights;
99 for (auto c : unique_classes){
100 class_weights.push_back(
101 float( (classes.cast<int>() == int(c)).count())/classes.size()
102 );
103 }
104 /* float total_weight=class_weights.sum(); */
105 auto cw = VectorXf::Map(class_weights.data(), class_weights.size());
106 float gini = 1 - cw.dot(cw);
107
108 return gini;
109}
110
111} //Brush::Split
Holds variable-type data.
Definition data.h:51
class tree_node_< Node > TreeNode
std::variant< ArrayXb, ArrayXi, ArrayXf, ArrayXXb, ArrayXXi, ArrayXXf, TimeSeriesb, TimeSeriesi, TimeSeriesf, ArrayXbJet, ArrayXiJet, ArrayXfJet, ArrayXXbJet, ArrayXXiJet, ArrayXXfJet, TimeSeriesbJet, TimeSeriesiJet, TimeSeriesfJet > State
Defines the possible types of data flowing through nodes.
Definition types.h:140
vector< float > get_thresholds< ArrayXb >(const ArrayXb &x)
Definition split.cpp:18
tuple< string, float > get_best_variable_and_threshold(const Dataset &d, TreeNode &tn)
Definition split.cpp:7
vector< float > get_thresholds< ArrayXfJet >(const ArrayXfJet &x)
Definition split.cpp:43
ArrayXb threshold_mask< State >(const State &x, const float &threshold)
Definition split.cpp:56
float gain(const ArrayXf &lsplit, const ArrayXf &rsplit, bool classification, vector< float > unique_classes)
Definition split.cpp:68
float gini_impurity_index(const ArrayXf &classes, const vector< float > &uc)
Definition split.cpp:95
vector< float > get_thresholds< ArrayXi >(const ArrayXi &x)
Definition split.cpp:20
vector< float > get_thresholds< ArrayXf >(const ArrayXf &x)
Definition split.cpp:33
vector< float > get_thresholds< ArrayXiJet >(const ArrayXiJet &x)
Definition split.cpp:26
auto get_best_thresholds(const Dataset &d, std::index_sequence< Is... >)
Definition operator.h:157
ArrayXb threshold_mask(const T &x, const float &threshold)
Applies a learned threshold to a feature, returning a mask.
Definition operator.h:16
vector< float > get_thresholds< ArrayXbJet >(const ArrayXbJet &x)
Definition split.cpp:19
vector< T > unique(vector< T > w)
Returns the unique elements of a vector.
Definition utils.h:334
float variance(const ArrayXf &v)
Calculates the variance.
Definition utils.cpp:317
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
Eigen::Array< fJet, Eigen::Dynamic, 1 > ArrayXfJet
Definition types.h:49
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Eigen::Array< bJet, Eigen::Dynamic, 1 > ArrayXbJet
Definition types.h:51
Eigen::Array< iJet, Eigen::Dynamic, 1 > ArrayXiJet
Definition types.h:50