Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1/* Brush
2copyright 2024 William La Cava
3license: GNU/GPL v3
4*/
5
6#ifndef PARAMS_H
7#define PARAMS_H
8
9#include "util/logger.h"
10#include "util/utils.h"
11
12namespace ns = nlohmann;
13
14namespace Brush
15{
16
17using namespace Util;
18
20{
21public:
22 // by default, the rng generator will use any random seed if random_state is zero
23 int random_state = 0;
24 int verbosity = 0;
25
26 // Evolutionary algorithm settings
27 string mode="regression";
28
29 unsigned int current_gen = 1;
30
31 // termination criteria
32 int pop_size = 100;
33 int max_gens = 100;
34 int max_stall = 0;
35 int max_time = -1;
36
37 unsigned int max_depth = 5;
38 unsigned int max_size = 50;
39
40 vector<string> objectives{"scorer","linear_complexity"}; // "scorer" is a generic placeholder; get_objectives() replaces it with the value of `scorer`
41 string bandit = "thompson"; // TODO: should I rename dummy?
42 string sel = "lexicase"; //selection method
43 string surv = "nsga2"; //survival method
44 std::unordered_map<string, float> functions;
46
47 // Different simplification strategies
50
51 // variation
52 std::map<std::string, float> mutation_probs = {
53 {"point", 0.167},
54 {"insert", 0.167},
55 {"delete", 0.167},
56 {"subtree", 0.167},
57 {"toggle_weight_on", 0.167},
58 {"toggle_weight_off", 0.167}
59 };
60 float cx_prob=0.2;
61 float mig_prob = 0.05;
62
63 string scorer="mse";
64
65 vector<int> classes = vector<int>();
66 vector<float> class_weights = vector<float>();
67 vector<float> sample_weights = vector<float>();
68 string class_weights_type = "unbalanced"; // user_defined, unbalanced, support
69
70 // for creating dataset from X and y in Engine<T>::fit. Ignored if
71 // the user provides a dataset
72 bool classification = false;
73 unsigned int n_classes = 0;
74
75 // validation partition
76 bool shuffle_split = false;
77 float validation_size = 0.2;
78 vector<string> feature_names = {};
79 vector<string> feature_types = {};
80 float batch_size = 0.0;
81 bool weights_init=true;
82
83 string load_population = "";
84 string save_population = "";
85
86 string logfile = "";
87
88 int n_jobs = 1;
89
92
93 // TODO: use logger to log information. Make getters const
94 void set_verbosity(int new_verbosity){ Brush::Util::logger.set_log_level(new_verbosity);
95 verbosity = new_verbosity; };
96 int get_verbosity(){ return verbosity; };
97
98 void set_random_state(int new_random_state){random_state = new_random_state; };
100
101 void set_pop_size(int new_pop_size){ pop_size = new_pop_size; };
102 int get_pop_size(){ return pop_size; };
103
104 void set_max_gens(int new_max_gens){ max_gens = new_max_gens; };
105 int get_max_gens(){ return max_gens; };
106
107 void set_bandit(string new_bandit){ bandit = new_bandit; };
108 string get_bandit(){ return bandit; };
109
110 void set_max_stall(int new_max_stall){ max_stall = new_max_stall; };
111 int get_max_stall(){ return max_stall; };
112
113 void set_max_time(int new_max_time){ max_time = new_max_time; };
114 int get_max_time(){ return max_time; };
115
116 void set_scorer(string new_scorer){ scorer = new_scorer; };
117 string get_scorer(){ return scorer; };
118
119 void set_load_population(string new_load_population){ load_population = new_load_population; };
121
122 void set_save_population(string new_save_population){ save_population = new_save_population; };
124
125 string get_logfile(){ return logfile; };
126 void set_logfile(string s){ logfile=s; };
127
128 void set_current_gen(unsigned int gen){ current_gen = gen; };
129 unsigned int get_current_gen(){ return current_gen; };
130
131 // TODO: improve vary_and_update to have island working in parallel
132 void set_num_islands(int new_num_islands){ num_islands = new_num_islands; };
133 int get_num_islands(){ return num_islands; };
134
135 void set_max_depth(unsigned new_max_depth){ max_depth = new_max_depth; };
136 unsigned get_max_depth() const { return max_depth; };
137
138 void set_n_jobs(int new_n_jobs){ n_jobs = new_n_jobs; };
139 int get_n_jobs(){ return n_jobs; };
140
141 void set_max_size(unsigned new_max_size){ max_size = new_max_size; };
142 unsigned get_max_size() const { return max_size; };
143
144 void set_objectives(vector<string> new_objectives){ objectives = new_objectives; };
145 vector<string> get_objectives() const {
146 // return objectives;
147
148 // properly replace scorer with the specified scorer
149 vector<string> aux_objectives(0);
150
151 for (auto& objective : objectives) {
152 if (objective.compare("scorer")==0)
153 aux_objectives.push_back(scorer);
154 else
155 aux_objectives.push_back(objective);
156 }
157
158 return aux_objectives;
159 };
160
161 void set_sel(string new_sel){ sel = new_sel; };
162 string get_sel(){ return sel; };
163
164 void set_surv(string new_surv){ surv = new_surv; };
165 string get_surv(){ return surv; };
166
167 void set_cx_prob(float new_cx_prob){ cx_prob = new_cx_prob; };
168 float get_cx_prob(){ return cx_prob; };
169
170 void set_mig_prob(float new_mig_prob){ mig_prob = new_mig_prob; };
171 float get_mig_prob(){ return mig_prob; };
172
173 void set_classification(bool c){ classification = c; };
175
176 void set_shuffle_split(bool shuff){ shuffle_split = shuff; };
178
181
184
185 void set_weights_init(bool init){ weights_init = init; };
187
188 void set_n_classes(const ArrayXf& y){
189 if (classification)
190 {
191 vector<int> uc = unique( ArrayXi(y.cast<int>()) );
192
193 if (int(uc.at(0)) != 0)
194 HANDLE_ERROR_THROW("Class labels must start at 0");
195
196 vector<int> cont_classes(uc.size());
197 iota(cont_classes.begin(), cont_classes.end(), 0);
198 for (int i = 0; i < cont_classes.size(); ++i)
199 {
200 if ( int(uc.at(i)) != cont_classes.at(i))
201 HANDLE_ERROR_THROW("Class labels must be contiguous");
202 }
203 n_classes = uc.size();
204 // classes = uc;
205 }
206 };
207 void set_class_weights(const vector<float>& weights){
208 if (weights.size() != n_classes)
209 HANDLE_ERROR_THROW("Length of class_weights does not match expected number of classes");
210
211 class_weights.clear();
212 for (unsigned int i = 0; i < n_classes; ++i) {
213 class_weights.push_back(weights[i]);
214 }
215 };
216
217 void set_sample_weights(const ArrayXf& y){
218 sample_weights.resize(0);
219
220 // one if for each case, so the default is unbalanced or user defined
221 if (class_weights_type == "support")
222 { // ignores everything and calculate the weights here.
223 class_weights.resize(n_classes); // set_n_classes must be called first
224 for (unsigned i = 0; i < n_classes; ++i){
225 // weighting by support
226 int support = (y.cast<int>().array() == i).count();
227
228 if (support==0)
229 class_weights.at(i) = 0.0;
230 else
231 class_weights.at(i) = float(y.size()) / float(n_classes * support);
232 }
233 } // else it is either unbalanced or user_defined
234
235 if (!class_weights.empty())
236 for (unsigned i = 0; i < y.size(); ++i)
237 sample_weights.push_back(class_weights.at(int(y(i))));
238 };
239
240 unsigned int get_n_classes(){ return n_classes; };
241 vector<float> get_class_weights(){ return class_weights; };
242 vector<float> get_sample_weights(){ return sample_weights; };
243
245 void set_class_weights_type(string cwt){ class_weights_type = cwt; };
246
249
250 void set_feature_names(vector<string> vn){ feature_names = vn; };
251 vector<string> get_feature_names(){ return feature_names; };
252
253 void set_feature_types(vector<string> ft){ feature_types = ft; };
254 vector<string> get_feature_types(){ return feature_types; };
255
256 void set_batch_size(float c){ batch_size = c; };
257 float get_batch_size(){ return batch_size; };
258
259 void set_mutation_probs(std::map<std::string, float> new_mutation_probs){ mutation_probs = new_mutation_probs; };
260 std::map<std::string, float> get_mutation_probs(){ return mutation_probs; };
261
262 void set_functions(std::unordered_map<std::string, float> new_functions){ functions = new_functions; };
263 std::unordered_map<std::string, float> get_functions(){ return functions; };
264};
265
267 random_state,
268 verbosity,
269
270 mode,
271
272 current_gen,
273
274 pop_size,
275 max_gens,
276 max_stall,
277 max_time,
278
279 max_depth,
280 max_size,
281
282 objectives,
283 bandit,
284 sel,
285 surv,
286 functions,
287 num_islands,
288
289 constants_simplification,
290 inexact_simplification,
291
292 mutation_probs,
293 cx_prob,
294 mig_prob,
295
296 scorer,
297
298 classes, // TODO: get rid of this parameter? for some reason, when i remove it (or set it to any value) the load population starts to fail with regression
299 class_weights,
300 sample_weights,
301
302 classification,
303 n_classes,
304
305 shuffle_split,
306 validation_size,
307 feature_names,
308 feature_types,
309 batch_size,
310 weights_init,
311
312 load_population,
313 save_population,
314
315 logfile,
316
317 n_jobs
318);
319
320} // Brush
321
322#endif
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing various utility functions
Definition error.cpp:11
static Logger & logger
Definition logger.h:60
vector< T > unique(vector< T > w)
returns unique elements in vector
Definition utils.h:334
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine< PT::Regressor >, params, best_ind, archive, pop, ss, is_fitted)
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
void set_save_population(string new_save_population)
Definition params.h:122
bool inexact_simplification
Definition params.h:49
bool get_classification()
Definition params.h:174
void set_max_depth(unsigned new_max_depth)
Definition params.h:135
int get_verbosity()
Definition params.h:96
vector< float > sample_weights
weights for each sample
Definition params.h:67
void set_max_stall(int new_max_stall)
Definition params.h:110
vector< string > objectives
Definition params.h:40
void set_max_gens(int new_max_gens)
Definition params.h:104
unsigned get_max_size() const
Definition params.h:142
string get_class_weights_type()
Definition params.h:244
void set_shuffle_split(bool shuff)
Definition params.h:176
void set_max_time(int new_max_time)
Definition params.h:113
bool get_inexact_simplification()
Definition params.h:183
void set_feature_types(vector< string > ft)
Definition params.h:253
float cx_prob
cross rate for variation
Definition params.h:60
void set_functions(std::unordered_map< std::string, float > new_functions)
Definition params.h:262
void set_mutation_probs(std::map< std::string, float > new_mutation_probs)
Definition params.h:259
unsigned int max_depth
Definition params.h:37
void set_n_classes(const ArrayXf &y)
Definition params.h:188
void set_inexact_simplification(bool is)
Definition params.h:182
int get_max_time()
Definition params.h:114
float batch_size
Definition params.h:80
bool get_weights_init()
Definition params.h:186
int get_max_gens()
Definition params.h:105
vector< string > get_objectives() const
Definition params.h:145
void set_batch_size(float c)
Definition params.h:256
bool shuffle_split
Definition params.h:76
void set_constants_simplification(bool cs)
Definition params.h:179
std::map< std::string, float > mutation_probs
Definition params.h:52
void set_num_islands(int new_num_islands)
Definition params.h:132
void set_pop_size(int new_pop_size)
Definition params.h:101
vector< string > get_feature_names()
Definition params.h:251
vector< float > get_class_weights()
Definition params.h:241
void set_random_state(int new_random_state)
Definition params.h:98
void set_logfile(string s)
Definition params.h:126
string get_save_population()
Definition params.h:123
float get_batch_size()
Definition params.h:257
float validation_size
Definition params.h:77
unsigned int get_n_classes()
Definition params.h:240
bool get_constants_simplification()
Definition params.h:180
string get_bandit()
Definition params.h:108
void set_weights_init(bool init)
Definition params.h:185
int get_max_stall()
Definition params.h:111
float get_mig_prob()
Definition params.h:171
int get_random_state()
Definition params.h:99
std::unordered_map< string, float > functions
Definition params.h:44
string get_scorer()
Definition params.h:117
string bandit
Definition params.h:41
string get_load_population()
Definition params.h:120
bool classification
Definition params.h:72
unsigned get_max_depth() const
Definition params.h:136
void set_load_population(string new_load_population)
Definition params.h:119
int get_pop_size()
Definition params.h:102
vector< int > classes
class labels
Definition params.h:65
int n_jobs
number of parallel jobs -1 use all threads; 0 use same as number of islands; positive number specify ...
Definition params.h:88
void set_current_gen(unsigned int gen)
Definition params.h:128
vector< float > class_weights
weights for each class
Definition params.h:66
int get_num_islands()
Definition params.h:133
void set_sel(string new_sel)
Definition params.h:161
void set_scorer(string new_scorer)
Definition params.h:116
vector< string > feature_names
Definition params.h:78
string get_logfile()
Definition params.h:125
string scorer
actual loss function used, determined by error
Definition params.h:63
void set_surv(string new_surv)
Definition params.h:164
void set_classification(bool c)
Definition params.h:173
vector< float > get_sample_weights()
Definition params.h:242
string logfile
Definition params.h:86
void set_objectives(vector< string > new_objectives)
Definition params.h:144
string get_surv()
Definition params.h:165
void set_sample_weights(const ArrayXf &y)
Definition params.h:217
bool constants_simplification
Definition params.h:48
bool get_shuffle_split()
Definition params.h:177
unsigned int current_gen
Definition params.h:29
void set_class_weights_type(string cwt)
Definition params.h:245
string load_population
Definition params.h:83
float get_cx_prob()
Definition params.h:168
string save_population
Definition params.h:84
void set_max_size(unsigned new_max_size)
Definition params.h:141
void set_n_jobs(int new_n_jobs)
Definition params.h:138
bool weights_init
Definition params.h:81
void set_feature_names(vector< string > vn)
Definition params.h:250
void set_validation_size(float s)
Definition params.h:247
vector< string > feature_types
Definition params.h:79
void set_cx_prob(float new_cx_prob)
Definition params.h:167
unsigned int get_current_gen()
Definition params.h:129
std::map< std::string, float > get_mutation_probs()
Definition params.h:260
vector< string > get_feature_types()
Definition params.h:254
float get_validation_size()
Definition params.h:248
void set_class_weights(const vector< float > &weights)
Definition params.h:207
std::unordered_map< std::string, float > get_functions()
Definition params.h:263
unsigned int max_size
Definition params.h:38
void set_verbosity(int new_verbosity)
Definition params.h:94
unsigned int n_classes
Definition params.h:73
void set_mig_prob(float new_mig_prob)
Definition params.h:170
void set_bandit(string new_bandit)
Definition params.h:107
string class_weights_type
Definition params.h:68
string get_sel()
Definition params.h:162