Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
lexicase.cpp
Go to the documentation of this file.
1#include "lexicase.h"
2
3namespace Brush {
4namespace Sel {
5
6using namespace Brush;
7using namespace Pop;
8using namespace Sel;
9
10template<ProgramType T>
12{
13 this->name = "lexicase";
14 this->survival = surv;
15}
16
17template<ProgramType T>
18vector<size_t> Lexicase<T>::select(Population<T>& pop, int island,
19 const Parameters& params)
20{
21 // this one can be executed in parallel because it is just reading the errors. This
22 // method assumes that the expressions have been fitted previously, and their respective
23 // error vectors are filled
24
25 auto island_pool = pop.get_island_indexes(island);
26
27 // Filter out nullptr individuals (offspring slots)
28 island_pool.erase(
29 std::remove_if(island_pool.begin(), island_pool.end(),
30 [&pop](size_t idx) { return pop.individuals.at(idx) == nullptr; }),
31 island_pool.end()
32 );
33
34 // If all individuals were nullptr, return empty selection
35 if (island_pool.empty())
36 return island_pool;
37
38 // if this is first generation, just return indices to pop
39 if (params.current_gen==0)
40 return island_pool;
41
42 //< number of samples
43 unsigned int N = pop.individuals.at(island_pool.at(0))->error.size();
44
45 //< number of individuals
46 unsigned int P = island_pool.size();
47
48 // define epsilon
49 ArrayXf epsilon = ArrayXf::Zero(N);
50
51 // if output is continuous, use epsilon lexicase
52 if (!params.classification || params.scorer.compare("log")==0
53 || params.scorer.compare("multi_log")==0
54 || params.scorer.compare("average_precision_score")==0 )
55 {
56 // for each sample, calculate epsilon
57 for (int i = 0; i<epsilon.size(); ++i)
58 {
59 VectorXf case_errors(island_pool.size());
60 for (int j = 0; j<island_pool.size(); ++j)
61 {
62 case_errors(j) = pop.individuals.at(island_pool[j])->error(i);
63 }
64
65 // notice that metric used to calculate the error must be a
66 // minimization problem in order for lexicase to work
67 epsilon(i) = mad(case_errors);
68 }
69 }
70 assert(epsilon.size() == N);
71
72 // selection pool
73 vector<size_t> starting_pool;
74 for (int i = 0; i < island_pool.size(); ++i)
75 {
76 starting_pool.push_back(island_pool[i]);
77 }
78 assert(starting_pool.size() == P);
79
80 vector<size_t> selected(P,0); // selected individuals
81
82 for (unsigned int i = 0; i<P; ++i) // selection loop
83 {
84 vector<size_t> cases; // cases (samples)
85 if (params.classification && !params.class_weights.empty())
86 {
87 // NOTE: when calling lexicase, make sure `errors` is from training
88 // data, and not from validation data. This is because the sample
89 // weights indexes are based on train partition
90
91 // for classification problems, weight case selection
92 // by class weights
93 cases.resize(0);
94 vector<size_t> choices(N);
95 std::iota(choices.begin(), choices.end(),0);
96
97 vector<float> sample_weights = params.sample_weights;
98
99 for (unsigned i = 0; i<N; ++i)
100 {
101 vector<size_t> choice_indices(N-i);
102 std::iota(choice_indices.begin(),choice_indices.end(),0);
103
104 size_t idx = *r.select_randomly(
105 choice_indices.begin(), choice_indices.end(),
106 sample_weights.begin(), sample_weights.end());
107
108 cases.push_back(choices.at(idx));
109 choices.erase(choices.begin() + idx);
110
111 sample_weights.erase(sample_weights.begin() + idx);
112 }
113 }
114 else
115 { // otherwise, choose cases randomly
116 cases.resize(N);
117 std::iota(cases.begin(),cases.end(),0);
118 r.shuffle(cases.begin(),cases.end()); // shuffle cases
119 }
120 vector<size_t> pool = starting_pool; // initial pool
121 vector<size_t> winner; // winners
122
123 bool pass = true; // checks pool size and number of cases
124 unsigned int h = 0; // case count
125
126 float epsilon_threshold;
127
128 while(pass){ // main loop
129 epsilon_threshold = 0;
130
131 winner.resize(0); // winners
132 // minimum error on case
133 float minfit = std::numeric_limits<float>::max();
134
135 // get minimum (assuming minization of indiviual errors)
136 for (size_t j = 0; j<pool.size(); ++j)
137 if (pop.individuals.at(pool[j])->error(cases[h]) < minfit)
138 minfit = pop.individuals.at(pool[j])->error(cases[h]);
139
140 // criteria to stay in pool
141 epsilon_threshold = minfit+epsilon[cases[h]];
142
143 // select best
144 for (size_t j = 0; j<pool.size(); ++j)
145 {
146 if (pop.individuals.at(pool[j])->error(cases[h])
147 <= epsilon_threshold)
148 winner.push_back(pool[j]);
149 }
150
151 ++h; // next case
152 // only keep going if needed
153 pass = (winner.size()>1 && h<cases.size());
154
155 if(winner.size() == 0)
156 {
157 if(h >= cases.size())
158 winner.push_back(*r.select_randomly(
159 pool.begin(), pool.end()) );
160 else
161 pass = true;
162 }
163 else
164 pool = winner; // reduce pool to remaining individuals
165 }
166
167 assert(winner.size()>0);
168
169 //if more than one winner, pick randomly
170 selected.at(i) = *r.select_randomly(
171 winner.begin(), winner.end() );
172 }
173
174 if (selected.size() != island_pool.size())
175 {
176 HANDLE_ERROR_THROW("Lexicase did not select correct number of \
177 parents");
178 }
179
180 return selected;
181}
182
183template<ProgramType T>
184vector<size_t> Lexicase<T>::survive(Population<T>& pop, int island,
185 const Parameters& params)
186{
187 /* Lexicase survival */
188 HANDLE_ERROR_THROW("Lexicase survival not implemented");
189 return vector<size_t>();
190}
191
192}
193}
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
Lexicase(bool surv=false)
Definition lexicase.cpp:11
vector< size_t > survive(Population< T > &pop, int island, const Parameters &p)
lexicase survival
Definition lexicase.cpp:184
vector< size_t > select(Population< T > &pop, int island, const Parameters &p)
function returns a set of selected indices from pop
Definition lexicase.cpp:18
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
float mad(const ArrayXf &x)
median absolute deviation
Definition utils.cpp:373
static Rnd & r
Definition rnd.h:176
NSGA-II selection operator for getting the front
Definition bandit.cpp:4
vector< float > sample_weights
weights for each sample
Definition params.h:70
bool classification
Definition params.h:75
vector< float > class_weights
weights for each class
Definition params.h:69
string scorer
actual loss function used, determined by error
Definition params.h:66
unsigned int current_gen
Definition params.h:29