Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
lexicase.cpp
Go to the documentation of this file.
1#include "lexicase.h"
2
3namespace Brush {
4namespace Sel {
5
6using namespace Brush;
7using namespace Pop;
8using namespace Sel;
9
10template<ProgramType T>
12{
13 this->name = "lexicase";
14 this->survival = surv;
15}
16
17template<ProgramType T>
18vector<size_t> Lexicase<T>::select(Population<T>& pop, int island,
19 const Parameters& params)
20{
21 // this one can be executed in parallel because it is just reading the errors. This
22 // method assumes that the expressions have been fitted previously, and their respective
23 // error vectors are filled
24
25 auto island_pool = pop.get_island_indexes(island);
26
27 // if this is first generation, just return indices to pop
28 if (params.current_gen==0)
29 return island_pool;
30
31 //< number of samples
32 unsigned int N = pop.individuals.at(island_pool.at(0))->error.size();
33
34 //< number of individuals
35 unsigned int P = island_pool.size();
36
37 // define epsilon
38 ArrayXf epsilon = ArrayXf::Zero(N);
39
40 // if output is continuous, use epsilon lexicase
41 if (!params.classification || params.scorer.compare("log")==0
42 || params.scorer.compare("multi_log")==0
43 || params.scorer.compare("average_precision_score")==0 )
44 {
45 // for each sample, calculate epsilon
46 for (int i = 0; i<epsilon.size(); ++i)
47 {
48 VectorXf case_errors(island_pool.size());
49 for (int j = 0; j<island_pool.size(); ++j)
50 {
51 case_errors(j) = pop.individuals.at(island_pool[j])->error(i);
52 }
53
54 // notice that metric used to calculate the error must be a
55 // minimization problem in order for lexicase to work
56 epsilon(i) = mad(case_errors);
57 }
58 }
59 assert(epsilon.size() == N);
60
61 // selection pool
62 vector<size_t> starting_pool;
63 for (int i = 0; i < island_pool.size(); ++i)
64 {
65 starting_pool.push_back(island_pool[i]);
66 }
67 assert(starting_pool.size() == P);
68
69 vector<size_t> selected(P,0); // selected individuals
70
71 for (unsigned int i = 0; i<P; ++i) // selection loop
72 {
73 vector<size_t> cases; // cases (samples)
74 if (params.classification && !params.class_weights.empty())
75 {
76 // NOTE: when calling lexicase, make sure `errors` is from training
77 // data, and not from validation data. This is because the sample
78 // weights indexes are based on train partition
79
80 // for classification problems, weight case selection
81 // by class weights
82 cases.resize(0);
83 vector<size_t> choices(N);
84 std::iota(choices.begin(), choices.end(),0);
85
86 vector<float> sample_weights = params.sample_weights;
87
88 for (unsigned i = 0; i<N; ++i)
89 {
90 vector<size_t> choice_indices(N-i);
91 std::iota(choice_indices.begin(),choice_indices.end(),0);
92
93 size_t idx = *r.select_randomly(
94 choice_indices.begin(), choice_indices.end(),
95 sample_weights.begin(), sample_weights.end());
96
97 cases.push_back(choices.at(idx));
98 choices.erase(choices.begin() + idx);
99
100 sample_weights.erase(sample_weights.begin() + idx);
101 }
102 }
103 else
104 { // otherwise, choose cases randomly
105 cases.resize(N);
106 std::iota(cases.begin(),cases.end(),0);
107 r.shuffle(cases.begin(),cases.end()); // shuffle cases
108 }
109 vector<size_t> pool = starting_pool; // initial pool
110 vector<size_t> winner; // winners
111
112 bool pass = true; // checks pool size and number of cases
113 unsigned int h = 0; // case count
114
115 float epsilon_threshold;
116
117 while(pass){ // main loop
118 epsilon_threshold = 0;
119
120 winner.resize(0); // winners
121 // minimum error on case
122 float minfit = std::numeric_limits<float>::max();
123
124 // get minimum (assuming minization of indiviual errors)
125 for (size_t j = 0; j<pool.size(); ++j)
126 if (pop.individuals.at(pool[j])->error(cases[h]) < minfit)
127 minfit = pop.individuals.at(pool[j])->error(cases[h]);
128
129 // criteria to stay in pool
130 epsilon_threshold = minfit+epsilon[cases[h]];
131
132 // select best
133 for (size_t j = 0; j<pool.size(); ++j)
134 {
135 if (pop.individuals.at(pool[j])->error(cases[h])
136 <= epsilon_threshold)
137 winner.push_back(pool[j]);
138 }
139
140 ++h; // next case
141 // only keep going if needed
142 pass = (winner.size()>1 && h<cases.size());
143
144 if(winner.size() == 0)
145 {
146 if(h >= cases.size())
147 winner.push_back(*r.select_randomly(
148 pool.begin(), pool.end()) );
149 else
150 pass = true;
151 }
152 else
153 pool = winner; // reduce pool to remaining individuals
154 }
155
156 assert(winner.size()>0);
157
158 //if more than one winner, pick randomly
159 selected.at(i) = *r.select_randomly(
160 winner.begin(), winner.end() );
161 }
162
163 if (selected.size() != island_pool.size())
164 {
165 HANDLE_ERROR_THROW("Lexicase did not select correct number of \
166 parents");
167 }
168
169 return selected;
170}
171
172template<ProgramType T>
173vector<size_t> Lexicase<T>::survive(Population<T>& pop, int island,
174 const Parameters& params)
175{
176 /* Lexicase survival */
177 HANDLE_ERROR_THROW("Lexicase survival not implemented");
178 return vector<size_t>();
179}
180
181}
182}
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
Lexicase(bool surv=false)
Definition lexicase.cpp:11
vector< size_t > survive(Population< T > &pop, int island, const Parameters &p)
lexicase survival
Definition lexicase.cpp:173
vector< size_t > select(Population< T > &pop, int island, const Parameters &p)
function returns a set of selected indices from pop
Definition lexicase.cpp:18
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
float mad(const ArrayXf &x)
median absolute deviation
Definition utils.cpp:373
static Rnd & r
Definition rnd.h:174
nsga2 selection operator for getting the front
Definition bandit.cpp:4
vector< float > sample_weights
weights for each sample
Definition params.h:73
bool classification
Definition params.h:77
vector< float > class_weights
weights for each class
Definition params.h:72
string scorer
actual loss function used, determined by error
Definition params.h:69
unsigned int current_gen
Definition params.h:29