Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
params.h
Go to the documentation of this file.
1/* Brush
2copyright 2024 William La Cava
3license: GNU/GPL v3
4*/
5
6#ifndef PARAMS_H
7#define PARAMS_H
8
9#include "util/logger.h"
10#include "util/utils.h"
11
12namespace ns = nlohmann;
13
14namespace Brush
15{
16
17using namespace Util;
18
20{
21public:
22 // by default, the random number generator will use an arbitrary seed if random_state is zero
23 int random_state = 0;
24 int verbosity = 0;
25
26 // Evolutionary algorithm settings
27 string mode="regression";
28
29 unsigned int current_gen = 1;
30
31 // termination criteria
32 int pop_size = 100;
33 int max_gens = 100;
34 int max_stall = 0;
35 int max_time = -1;
36
37 unsigned int max_depth = 5;
38 unsigned int max_size = 50;
39
40 vector<string> objectives{"scorer","linear_complexity"}; // scorer should be generic and deduced based on mode
41 string bandit = "thompson"; // TODO: should I rename dummy?
42 string sel = "lexicase"; //selection method
43 string surv = "nsga2"; //survival method
44 std::unordered_map<string, float> functions;
46
47 // whether to save the Pareto front of the entire evolution (use_arch=true)
48 // or just the final population (use_arch=false)
49 bool use_arch=false;
50 bool val_from_arch=true;
51
52 // Different simplification strategies
55
56 // variation
57 std::map<std::string, float> mutation_probs = {
58 {"point", 0.167},
59 {"insert", 0.167},
60 {"delete", 0.167},
61 {"subtree", 0.167},
62 {"toggle_weight_on", 0.167},
63 {"toggle_weight_off", 0.167}
64 };
65
66 float cx_prob=0.2;
67 float mig_prob = 0.05;
68
69 string scorer="mse";
70
71 vector<int> classes = vector<int>();
72 vector<float> class_weights = vector<float>();
73 vector<float> sample_weights = vector<float>();
74
75 // used for creating a dataset from X and y in Engine<T>::fit. Ignored if
76 // the user provides a dataset
77 bool classification = false;
78 unsigned int n_classes = 0;
79
80 // validation partition
81 bool shuffle_split = false;
82 float validation_size = 0.75;
83 vector<string> feature_names = {};
84 vector<string> feature_types = {};
85 float batch_size = 0.0;
86 bool weights_init=true;
87
88 string load_population = "";
89 string save_population = "";
90
91 string logfile = "";
92
93 int n_jobs = 1;
94
97
98 // TODO: use logger to log information. Make getters const
99 void set_verbosity(int new_verbosity){ Brush::Util::logger.set_log_level(new_verbosity); // forward the level to the shared logger first
100 verbosity = new_verbosity; }; // then store it in the `verbosity` member read by get_verbosity()
101 int get_verbosity(){ return verbosity; };
102
103 void set_random_state(int new_random_state){random_state = new_random_state; };
105
106 void set_pop_size(int new_pop_size){ pop_size = new_pop_size; };
107 int get_pop_size(){ return pop_size; };
108
109 void set_max_gens(int new_max_gens){ max_gens = new_max_gens; };
110 int get_max_gens(){ return max_gens; };
111
112 void set_bandit(string new_bandit){ bandit = new_bandit; };
113 string get_bandit(){ return bandit; };
114
115 void set_max_stall(int new_max_stall){ max_stall = new_max_stall; };
116 int get_max_stall(){ return max_stall; };
117
118 void set_max_time(int new_max_time){ max_time = new_max_time; };
119 int get_max_time(){ return max_time; };
120
121 void set_scorer(string new_scorer){ scorer = new_scorer; };
122 string get_scorer(){ return scorer; };
123
124 void set_load_population(string new_load_population){ load_population = new_load_population; };
126
127 void set_save_population(string new_save_population){ save_population = new_save_population; };
129
130 string get_logfile(){ return logfile; };
131 void set_logfile(string s){ logfile=s; };
132
133 void set_current_gen(unsigned int gen){ current_gen = gen; };
134 unsigned int get_current_gen(){ return current_gen; };
135
136 // TODO: improve vary_and_update to have islands working in parallel
137 void set_num_islands(int new_num_islands){ num_islands = new_num_islands; };
138 int get_num_islands(){ return num_islands; };
139
140 void set_max_depth(unsigned new_max_depth){ max_depth = new_max_depth; };
141 unsigned get_max_depth() const { return max_depth; };
142
143 void set_n_jobs(int new_n_jobs){ n_jobs = new_n_jobs; };
144 int get_n_jobs(){ return n_jobs; };
145
146 void set_max_size(unsigned new_max_size){ max_size = new_max_size; };
147 unsigned get_max_size() const { return max_size; };
148
149 void set_objectives(vector<string> new_objectives){ objectives = new_objectives; };
150 vector<string> get_objectives() const { // returns a copy of `objectives` with each "scorer" placeholder resolved
151 // not simply `return objectives;` -- the placeholder has to be substituted first
152
153 // replace every "scorer" entry with the current value of the `scorer` member
154 vector<string> aux_objectives(0); // starts empty; filled in the same order as `objectives`
155
156 for (auto& objective : objectives) {
157 if (objective.compare("scorer")==0) // compare()==0 means the entry is exactly "scorer"
158 aux_objectives.push_back(scorer);
159 else
160 aux_objectives.push_back(objective); // any other objective is passed through unchanged
161 }
162
163 return aux_objectives; // returned by value; the `objectives` member itself is not modified
164 };
165
166 void set_sel(string new_sel){ sel = new_sel; };
167 string get_sel(){ return sel; };
168
169 void set_surv(string new_surv){ surv = new_surv; };
170 string get_surv(){ return surv; };
171
172 void set_cx_prob(float new_cx_prob){ cx_prob = new_cx_prob; };
173 float get_cx_prob(){ return cx_prob; };
174
175 void set_mig_prob(float new_mig_prob){ mig_prob = new_mig_prob; };
176 float get_mig_prob(){ return mig_prob; };
177
178 void set_use_arch(bool new_use_arch){ use_arch = new_use_arch; };
179 bool get_use_arch(){ return use_arch; };
180
181 void set_val_from_arch(bool new_val_from_arch){ val_from_arch = new_val_from_arch; };
183
184 void set_classification(bool c){ classification = c; };
186
187 void set_shuffle_split(bool shuff){ shuffle_split = shuff; };
189
192
195
196 void set_weights_init(bool init){ weights_init = init; };
198
199 void set_n_classes(const ArrayXf& y){ // derives n_classes from the labels in y; does nothing unless `classification` is true
200 if (classification)
201 {
202 vector<int> uc = unique( ArrayXi(y.cast<int>()) ); // labels are cast to int before de-duplication
203
204 if (int(uc.at(0)) != 0) // NOTE(review): at(0) throws std::out_of_range if unique() yields an empty vector (e.g. empty y) -- confirm callers never pass one
205 HANDLE_ERROR_THROW("Class labels must start at 0");
206
207 vector<int> cont_classes(uc.size());
208 iota(cont_classes.begin(), cont_classes.end(), 0); // expected labels: 0, 1, ..., uc.size()-1
209 for (int i = 0; i < cont_classes.size(); ++i) // NOTE(review): signed/unsigned comparison (int vs. size_t)
210 {
211 if ( int(uc.at(i)) != cont_classes.at(i)) // presumably relies on unique() returning labels in ascending order -- verify in utils.h
212 HANDLE_ERROR_THROW("Class labels must be contiguous");
213 }
214 n_classes = uc.size();
215 // classes = uc; (assignment is disabled, so the `classes` member is not updated here)
216 }
217 };
218 void set_class_weights(const ArrayXf& y){ // one weight per class: 1 - (fraction of samples in y carrying that label)
219 class_weights.resize(n_classes); // set_n_classes must be called first
220 for (unsigned i = 0; i < n_classes; ++i){
221 class_weights.at(i) = float((y.cast<int>().array() == i).count())/y.size(); // relative frequency of label i in y (labels cast to int)
222 class_weights.at(i) = (1.0 - class_weights.at(i)); // invert, so more frequent classes get smaller weights
223 }
224 };
225 void set_sample_weights(const ArrayXf& y){ // one weight per sample: the class weight of that sample's label
226 sample_weights.resize(0); // set_class_weights must be called first
227 if (!class_weights.empty()) // with no class weights, sample_weights is left empty
228 for (unsigned i = 0; i < y.size(); ++i)
229 sample_weights.push_back(class_weights.at(int(y(i)))); // at() throws std::out_of_range when the label is not a valid index into class_weights
230 };
231
232 unsigned int get_n_classes(){ return n_classes; };
233 vector<float> get_class_weights(){ return class_weights; };
234 vector<float> get_sample_weights(){ return sample_weights; };
235
238
239 void set_feature_names(vector<string> vn){ feature_names = vn; };
240 vector<string> get_feature_names(){ return feature_names; };
241
242 void set_feature_types(vector<string> ft){ feature_types = ft; };
243 vector<string> get_feature_types(){ return feature_types; };
244
245 void set_batch_size(float c){ batch_size = c; };
246 float get_batch_size(){ return batch_size; };
247
248 void set_mutation_probs(std::map<std::string, float> new_mutation_probs){ mutation_probs = new_mutation_probs; };
249 std::map<std::string, float> get_mutation_probs(){ return mutation_probs; };
250
251 void set_functions(std::unordered_map<std::string, float> new_functions){ functions = new_functions; };
252 std::unordered_map<std::string, float> get_functions(){ return functions; };
253};
254
256 verbosity,
257 random_state,
258 pop_size,
259 max_gens,
260 max_stall,
261 max_time,
262 scorer,
263 load_population,
264 save_population,
265 logfile,
266 current_gen,
267 num_islands,
268 max_depth,
269 n_jobs,
270 max_size,
271 objectives,
272 sel,
273 surv,
274 cx_prob,
275 mig_prob,
276 classification,
277 n_classes,
278 classes, // TODO: get rid of this parameter? for some reason, when i remove it (or set it to any value) the load population starts to fail with regression
279 class_weights,
280 sample_weights,
281 validation_size,
282 feature_names,
283 batch_size,
284 mutation_probs,
285 functions
286);
287
288} // Brush
289
290#endif
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing various utility functions
Definition error.cpp:11
static Logger & logger
Definition logger.h:60
vector< T > unique(vector< T > w)
returns unique elements in vector
Definition utils.h:334
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine< PT::Regressor >, params, best_ind, archive, pop, ss, is_fitted)
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
void set_save_population(string new_save_population)
Definition params.h:127
bool inexact_simplification
Definition params.h:54
bool get_classification()
Definition params.h:185
void set_max_depth(unsigned new_max_depth)
Definition params.h:140
int get_verbosity()
Definition params.h:101
vector< float > sample_weights
weights for each sample
Definition params.h:73
void set_max_stall(int new_max_stall)
Definition params.h:115
vector< string > objectives
Definition params.h:40
void set_max_gens(int new_max_gens)
Definition params.h:109
unsigned get_max_size() const
Definition params.h:147
void set_shuffle_split(bool shuff)
Definition params.h:187
bool get_use_arch()
Definition params.h:179
void set_max_time(int new_max_time)
Definition params.h:118
bool get_inexact_simplification()
Definition params.h:194
void set_feature_types(vector< string > ft)
Definition params.h:242
float cx_prob
cross rate for variation
Definition params.h:66
void set_functions(std::unordered_map< std::string, float > new_functions)
Definition params.h:251
void set_mutation_probs(std::map< std::string, float > new_mutation_probs)
Definition params.h:248
unsigned int max_depth
Definition params.h:37
bool get_val_from_arch()
Definition params.h:182
void set_n_classes(const ArrayXf &y)
Definition params.h:199
void set_inexact_simplification(bool is)
Definition params.h:193
int get_max_time()
Definition params.h:119
float batch_size
Definition params.h:85
bool get_weights_init()
Definition params.h:197
int get_max_gens()
Definition params.h:110
vector< string > get_objectives() const
Definition params.h:150
void set_batch_size(float c)
Definition params.h:245
bool shuffle_split
Definition params.h:81
void set_val_from_arch(bool new_val_from_arch)
Definition params.h:181
void set_class_weights(const ArrayXf &y)
Definition params.h:218
void set_constants_simplification(bool cs)
Definition params.h:190
std::map< std::string, float > mutation_probs
Definition params.h:57
void set_num_islands(int new_num_islands)
Definition params.h:137
void set_pop_size(int new_pop_size)
Definition params.h:106
vector< string > get_feature_names()
Definition params.h:240
vector< float > get_class_weights()
Definition params.h:233
void set_random_state(int new_random_state)
Definition params.h:103
void set_logfile(string s)
Definition params.h:131
string get_save_population()
Definition params.h:128
float get_batch_size()
Definition params.h:246
float validation_size
Definition params.h:82
unsigned int get_n_classes()
Definition params.h:232
bool get_constants_simplification()
Definition params.h:191
string get_bandit()
Definition params.h:113
void set_weights_init(bool init)
Definition params.h:196
int get_max_stall()
Definition params.h:116
float get_mig_prob()
Definition params.h:176
int get_random_state()
Definition params.h:104
std::unordered_map< string, float > functions
Definition params.h:44
string get_scorer()
Definition params.h:122
string bandit
Definition params.h:41
string get_load_population()
Definition params.h:125
bool classification
Definition params.h:77
unsigned get_max_depth() const
Definition params.h:141
void set_load_population(string new_load_population)
Definition params.h:124
int get_pop_size()
Definition params.h:107
vector< int > classes
class labels
Definition params.h:71
int n_jobs
number of parallel jobs -1 use all threads; 0 use same as number of islands; positive number specify ...
Definition params.h:93
void set_current_gen(unsigned int gen)
Definition params.h:133
vector< float > class_weights
weights for each class
Definition params.h:72
int get_num_islands()
Definition params.h:138
void set_sel(string new_sel)
Definition params.h:166
void set_scorer(string new_scorer)
Definition params.h:121
vector< string > feature_names
Definition params.h:83
bool val_from_arch
Definition params.h:50
string get_logfile()
Definition params.h:130
string scorer
actual loss function used, determined by error
Definition params.h:69
void set_surv(string new_surv)
Definition params.h:169
void set_use_arch(bool new_use_arch)
Definition params.h:178
void set_classification(bool c)
Definition params.h:184
vector< float > get_sample_weights()
Definition params.h:234
string logfile
Definition params.h:91
void set_objectives(vector< string > new_objectives)
Definition params.h:149
string get_surv()
Definition params.h:170
void set_sample_weights(const ArrayXf &y)
Definition params.h:225
bool constants_simplification
Definition params.h:53
bool get_shuffle_split()
Definition params.h:188
unsigned int current_gen
Definition params.h:29
string load_population
Definition params.h:88
float get_cx_prob()
Definition params.h:173
string save_population
Definition params.h:89
void set_max_size(unsigned new_max_size)
Definition params.h:146
void set_n_jobs(int new_n_jobs)
Definition params.h:143
bool weights_init
Definition params.h:86
void set_feature_names(vector< string > vn)
Definition params.h:239
void set_validation_size(float s)
Definition params.h:236
vector< string > feature_types
Definition params.h:84
void set_cx_prob(float new_cx_prob)
Definition params.h:172
unsigned int get_current_gen()
Definition params.h:134
std::map< std::string, float > get_mutation_probs()
Definition params.h:249
vector< string > get_feature_types()
Definition params.h:243
float get_validation_size()
Definition params.h:237
std::unordered_map< std::string, float > get_functions()
Definition params.h:252
unsigned int max_size
Definition params.h:38
void set_verbosity(int new_verbosity)
Definition params.h:99
unsigned int n_classes
Definition params.h:78
void set_mig_prob(float new_mig_prob)
Definition params.h:175
void set_bandit(string new_bandit)
Definition params.h:112
string get_sel()
Definition params.h:167