Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1/* Brush
2copyright 2024 William La Cava
3license: GNU/GPL v3
4*/
5
6#ifndef PARAMS_H
7#define PARAMS_H
8
9#include "util/logger.h"
10#include "util/utils.h"
11
12namespace ns = nlohmann;
13
14namespace Brush
15{
16
17using namespace Util;
18
20{
21public:
22 // by default, the rng generator will use any random seed if random_state is zero
23 int random_state = 0;
24 int verbosity = 0;
25
26 // Evolutionary algorithm settings
27 string mode="regression";
28
29 unsigned int current_gen = 1;
30
31 // termination criteria
32 int pop_size = 100;
33 int max_gens = 100;
34 int max_stall = 0;
35 int max_time = -1;
36
37 unsigned int max_depth = 5;
38 unsigned int max_size = 50;
39
40 vector<string> objectives{"scorer","linear_complexity"}; // scorer should be generic and deducted based on mode
41 string bandit = "thompson"; // TODO: should I rename dummy?
42 string sel = "lexicase"; //selection method
43 string surv = "nsga2"; //survival method
44 std::unordered_map<string, float> functions;
46
47 // Different simplification strategies
50
51 // population initialization
53
54 // variation
55 std::map<std::string, float> mutation_probs = {
56 {"point", 0.167},
57 {"insert", 0.167},
58 {"delete", 0.167},
59 {"subtree", 0.167},
60 {"toggle_weight_on", 0.167},
61 {"toggle_weight_off", 0.167}
62 };
63 float cx_prob=0.2;
64 float mig_prob = 0.05;
65
66 string scorer="mse";
67
68 vector<int> classes = vector<int>();
69 vector<float> class_weights = vector<float>();
70 vector<float> sample_weights = vector<float>();
71 string class_weights_type = "unbalanced"; // user_defined, unbalanced, support
72
73 // used when creating a dataset from X and y in Engine<T>::fit. Ignored if
74 // the user provides a dataset
75 bool classification = false;
76 unsigned int n_classes = 0;
77
78 // validation partition
79 bool shuffle_split = false;
80 float validation_size = 0.2;
81 vector<string> feature_names = {};
82 vector<string> feature_types = {};
83 float batch_size = 0.0;
84 bool weights_init=true;
85
86 string load_population = "";
87 string save_population = "";
88
89 string logfile = "";
90
91 int n_jobs = 1;
92
95
96 // TODO: use logger to log information. Make getters const
97 void set_verbosity(int new_verbosity){ Brush::Util::logger.set_log_level(new_verbosity);
98 verbosity = new_verbosity; };
99 int get_verbosity(){ return verbosity; };
100
101 void set_random_state(int new_random_state){random_state = new_random_state; };
103
104 void set_pop_size(int new_pop_size){ pop_size = new_pop_size; };
105 int get_pop_size(){ return pop_size; };
106
107 void set_max_gens(int new_max_gens){ max_gens = new_max_gens; };
108 int get_max_gens(){ return max_gens; };
109
110 void set_bandit(string new_bandit){ bandit = new_bandit; };
111 string get_bandit(){ return bandit; };
112
113 void set_max_stall(int new_max_stall){ max_stall = new_max_stall; };
114 int get_max_stall(){ return max_stall; };
115
116 void set_max_time(int new_max_time){ max_time = new_max_time; };
117 int get_max_time(){ return max_time; };
118
119 void set_scorer(string new_scorer){ scorer = new_scorer; };
120 string get_scorer(){ return scorer; };
121
122 void set_load_population(string new_load_population){ load_population = new_load_population; };
124
125 void set_save_population(string new_save_population){ save_population = new_save_population; };
127
128 string get_logfile(){ return logfile; };
129 void set_logfile(string s){ logfile=s; };
130
131 void set_current_gen(unsigned int gen){ current_gen = gen; };
132 unsigned int get_current_gen(){ return current_gen; };
133
134 // TODO: improve vary_and_update to have island working in parallel
135 void set_num_islands(int new_num_islands){ num_islands = new_num_islands; };
136 int get_num_islands(){ return num_islands; };
137
138 void set_max_depth(unsigned new_max_depth){ max_depth = new_max_depth; };
139 unsigned get_max_depth() const { return max_depth; };
140
141 void set_n_jobs(int new_n_jobs){ n_jobs = new_n_jobs; };
142 int get_n_jobs(){ return n_jobs; };
143
144 void set_max_size(unsigned new_max_size){ max_size = new_max_size; };
145 unsigned get_max_size() const { return max_size; };
146
147 void set_objectives(vector<string> new_objectives){ objectives = new_objectives; };
148 vector<string> get_objectives() const {
149 // return objectives;
150
151 // properly replace scorer with the specified scorer
152 vector<string> aux_objectives(0);
153
154 for (auto& objective : objectives) {
155 if (objective.compare("scorer")==0)
156 aux_objectives.push_back(scorer);
157 else
158 aux_objectives.push_back(objective);
159 }
160
161 return aux_objectives;
162 };
163
164 void set_sel(string new_sel){ sel = new_sel; };
165 string get_sel(){ return sel; };
166
167 void set_surv(string new_surv){ surv = new_surv; };
168 string get_surv(){ return surv; };
169
170 void set_cx_prob(float new_cx_prob){ cx_prob = new_cx_prob; };
171 float get_cx_prob(){ return cx_prob; };
172
173 void set_mig_prob(float new_mig_prob){ mig_prob = new_mig_prob; };
174 float get_mig_prob(){ return mig_prob; };
175
176 void set_classification(bool c){ classification = c; };
178
179 void set_shuffle_split(bool shuff){ shuffle_split = shuff; };
181
184
187
188 void set_weights_init(bool init){ weights_init = init; };
190
191 void set_n_classes(const ArrayXf& y){
192 if (classification)
193 {
194 vector<int> uc = unique( ArrayXi(y.cast<int>()) );
195
196 if (int(uc.at(0)) != 0)
197 HANDLE_ERROR_THROW("Class labels must start at 0");
198
199 vector<int> cont_classes(uc.size());
200 iota(cont_classes.begin(), cont_classes.end(), 0);
201 for (int i = 0; i < cont_classes.size(); ++i)
202 {
203 if ( int(uc.at(i)) != cont_classes.at(i))
204 HANDLE_ERROR_THROW("Class labels must be contiguous");
205 }
206 n_classes = uc.size();
207 // classes = uc;
208 }
209 };
210 void set_class_weights(const vector<float>& weights){
211 if (weights.size() != n_classes)
212 HANDLE_ERROR_THROW("Length of class_weights does not match expected number of classes");
213
214 class_weights.clear();
215 for (unsigned int i = 0; i < n_classes; ++i) {
216 class_weights.push_back(weights[i]);
217 }
218 };
219
220 void set_sample_weights(const ArrayXf& y){
221 sample_weights.resize(0);
222
223 // one if for each case, so the default is unbalanced or user defined
224 if (class_weights_type == "support")
225 { // ignores everything and calculate the weights here.
226 class_weights.resize(n_classes); // set_n_classes must be called first
227 for (unsigned i = 0; i < n_classes; ++i){
228 // weighting by support
229 int support = (y.cast<int>().array() == i).count();
230
231 if (support==0)
232 class_weights.at(i) = 0.0;
233 else
234 class_weights.at(i) = float(y.size()) / float(n_classes * support);
235 }
236 } // else it is either unbalanced or user_defined
237
238 if (!class_weights.empty())
239 for (unsigned i = 0; i < y.size(); ++i)
240 sample_weights.push_back(class_weights.at(int(y(i))));
241 };
242
243 unsigned int get_n_classes(){ return n_classes; };
244 vector<float> get_class_weights(){ return class_weights; };
245 vector<float> get_sample_weights(){ return sample_weights; };
246
248 void set_class_weights_type(string cwt){ class_weights_type = cwt; };
249
252
255
256 void set_feature_names(vector<string> vn){ feature_names = vn; };
257 vector<string> get_feature_names(){ return feature_names; };
258
259 void set_feature_types(vector<string> ft){ feature_types = ft; };
260 vector<string> get_feature_types(){ return feature_types; };
261
262 void set_batch_size(float c){ batch_size = c; };
263 float get_batch_size(){ return batch_size; };
264
265 void set_mutation_probs(std::map<std::string, float> new_mutation_probs){ mutation_probs = new_mutation_probs; };
266 std::map<std::string, float> get_mutation_probs(){ return mutation_probs; };
267
268 void set_functions(std::unordered_map<std::string, float> new_functions){ functions = new_functions; };
269 std::unordered_map<std::string, float> get_functions(){ return functions; };
270};
271
273 random_state,
274 verbosity,
275
276 mode,
277
278 current_gen,
279
280 pop_size,
281 max_gens,
282 max_stall,
283 max_time,
284
285 max_depth,
286 max_size,
287
288 objectives,
289 bandit,
290 sel,
291 surv,
292 functions,
293 num_islands,
294
295 constants_simplification,
296 inexact_simplification,
297
298 mutation_probs,
299 cx_prob,
300 mig_prob,
301
302 scorer,
303
304 classes, // TODO: get rid of this parameter? for some reason, when i remove it (or set it to any value) the load population starts to fail with regression
305 class_weights,
306 sample_weights,
307
308 classification,
309 n_classes,
310
311 start_from_decision_trees,
312
313 shuffle_split,
314 validation_size,
315 feature_names,
316 feature_types,
317 batch_size,
318 weights_init,
319
320 load_population,
321 save_population,
322
323 logfile,
324
325 n_jobs
326);
327
328} // Brush
329
330#endif
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing various utility functions
Definition error.cpp:11
static Logger & logger
Definition logger.h:60
vector< T > unique(vector< T > w)
returns unique elements in vector
Definition utils.h:334
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine< PT::Regressor >, params, best_ind, archive, pop, ss, is_fitted)
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
void set_save_population(string new_save_population)
Definition params.h:125
bool start_from_decision_trees
Definition params.h:52
bool inexact_simplification
Definition params.h:49
bool get_classification()
Definition params.h:177
void set_max_depth(unsigned new_max_depth)
Definition params.h:138
int get_verbosity()
Definition params.h:99
vector< float > sample_weights
weights for each sample
Definition params.h:70
void set_max_stall(int new_max_stall)
Definition params.h:113
vector< string > objectives
Definition params.h:40
void set_max_gens(int new_max_gens)
Definition params.h:107
unsigned get_max_size() const
Definition params.h:145
string get_class_weights_type()
Definition params.h:247
void set_shuffle_split(bool shuff)
Definition params.h:179
void set_max_time(int new_max_time)
Definition params.h:116
bool get_inexact_simplification()
Definition params.h:186
void set_feature_types(vector< string > ft)
Definition params.h:259
float cx_prob
cross rate for variation
Definition params.h:63
void set_functions(std::unordered_map< std::string, float > new_functions)
Definition params.h:268
void set_mutation_probs(std::map< std::string, float > new_mutation_probs)
Definition params.h:265
unsigned int max_depth
Definition params.h:37
void set_n_classes(const ArrayXf &y)
Definition params.h:191
void set_inexact_simplification(bool is)
Definition params.h:185
int get_max_time()
Definition params.h:117
float batch_size
Definition params.h:83
bool get_weights_init()
Definition params.h:189
int get_max_gens()
Definition params.h:108
vector< string > get_objectives() const
Definition params.h:148
void set_batch_size(float c)
Definition params.h:262
bool shuffle_split
Definition params.h:79
void set_start_from_decision_trees(bool start_dt)
Definition params.h:251
void set_constants_simplification(bool cs)
Definition params.h:182
std::map< std::string, float > mutation_probs
Definition params.h:55
void set_num_islands(int new_num_islands)
Definition params.h:135
void set_pop_size(int new_pop_size)
Definition params.h:104
vector< string > get_feature_names()
Definition params.h:257
vector< float > get_class_weights()
Definition params.h:244
void set_random_state(int new_random_state)
Definition params.h:101
void set_logfile(string s)
Definition params.h:129
string get_save_population()
Definition params.h:126
float get_batch_size()
Definition params.h:263
float validation_size
Definition params.h:80
unsigned int get_n_classes()
Definition params.h:243
bool get_constants_simplification()
Definition params.h:183
string get_bandit()
Definition params.h:111
void set_weights_init(bool init)
Definition params.h:188
int get_max_stall()
Definition params.h:114
float get_mig_prob()
Definition params.h:174
int get_random_state()
Definition params.h:102
std::unordered_map< string, float > functions
Definition params.h:44
string get_scorer()
Definition params.h:120
string bandit
Definition params.h:41
string get_load_population()
Definition params.h:123
bool classification
Definition params.h:75
bool get_start_from_decision_trees()
Definition params.h:250
unsigned get_max_depth() const
Definition params.h:139
void set_load_population(string new_load_population)
Definition params.h:122
int get_pop_size()
Definition params.h:105
vector< int > classes
class labels
Definition params.h:68
int n_jobs
number of parallel jobs -1 use all threads; 0 use same as number of islands; positive number specify ...
Definition params.h:91
void set_current_gen(unsigned int gen)
Definition params.h:131
vector< float > class_weights
weights for each class
Definition params.h:69
int get_num_islands()
Definition params.h:136
void set_sel(string new_sel)
Definition params.h:164
void set_scorer(string new_scorer)
Definition params.h:119
vector< string > feature_names
Definition params.h:81
string get_logfile()
Definition params.h:128
string scorer
actual loss function used, determined by error
Definition params.h:66
void set_surv(string new_surv)
Definition params.h:167
void set_classification(bool c)
Definition params.h:176
vector< float > get_sample_weights()
Definition params.h:245
string logfile
Definition params.h:89
void set_objectives(vector< string > new_objectives)
Definition params.h:147
string get_surv()
Definition params.h:168
void set_sample_weights(const ArrayXf &y)
Definition params.h:220
bool constants_simplification
Definition params.h:48
bool get_shuffle_split()
Definition params.h:180
unsigned int current_gen
Definition params.h:29
void set_class_weights_type(string cwt)
Definition params.h:248
string load_population
Definition params.h:86
float get_cx_prob()
Definition params.h:171
string save_population
Definition params.h:87
void set_max_size(unsigned new_max_size)
Definition params.h:144
void set_n_jobs(int new_n_jobs)
Definition params.h:141
bool weights_init
Definition params.h:84
void set_feature_names(vector< string > vn)
Definition params.h:256
void set_validation_size(float s)
Definition params.h:253
vector< string > feature_types
Definition params.h:82
void set_cx_prob(float new_cx_prob)
Definition params.h:170
unsigned int get_current_gen()
Definition params.h:132
std::map< std::string, float > get_mutation_probs()
Definition params.h:266
vector< string > get_feature_types()
Definition params.h:260
float get_validation_size()
Definition params.h:254
void set_class_weights(const vector< float > &weights)
Definition params.h:210
std::unordered_map< std::string, float > get_functions()
Definition params.h:269
unsigned int max_size
Definition params.h:38
void set_verbosity(int new_verbosity)
Definition params.h:97
unsigned int n_classes
Definition params.h:76
void set_mig_prob(float new_mig_prob)
Definition params.h:173
void set_bandit(string new_bandit)
Definition params.h:110
string class_weights_type
Definition params.h:71
string get_sel()
Definition params.h:165