Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
variation.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef VARIATION_H
7#define VARIATION_H
8
9#include "../ind/individual.h"
10
11#include "../bandit/bandit.h"
13#include "../bandit/dummy.h"
14
15#include "../pop/population.h"
16#include "../eval/evaluation.h"
19
20#include <map>
21#include <optional>
22
23using namespace Brush::Pop;
24using namespace Brush::MAB;
25using namespace Brush::Eval;
26using namespace Brush::Simpl;
27
32namespace Brush {
33namespace Var {
34
43template<ProgramType T>
44class Variation {
45public:
// Default constructor.
// NOTE(review): the class has a reference member (`Dataset& data`) with no
// default member initializer, so the language defines this defaulted default
// constructor as deleted; it cannot actually be used to create a Variation.
49 Variation() = default;
50
58 : parameters(params)
59 , search_space(ss)
60 , data(d)
61 {
62 init();
63 };
64
69
76 void init(){
77 // initializing variation bandit with the probabilities of non-null variation
78 map<string, float> variation_probs;
79 for (const auto& mutation : parameters.get_mutation_probs())
80 // TODO: I need to figure out a better way of avoiding options that were not positive. Make sure that this does not break the code
81 if (mutation.second > 0.0)
82 variation_probs[mutation.first] = mutation.second;
83
84 if (parameters.cx_prob > 0.0)
85 variation_probs["cx"] = parameters.cx_prob;
86
87 this->variation_bandit = Bandit(parameters.bandit, variation_probs);
88
89 // TODO: should I set C parameter based on pop size or leave it fixed?
90 // TODO: update string comparisons to use .compare method
91 // if (parameters.bandit.compare("dynamic_thompson")==0)
92 // this->variation_bandit.pbandit.set_C(parameters.pop_size);
93
94 // initializing one bandit for each terminal type
95 for (const auto& entry : this->search_space.terminal_weights) {
96 // entry is a tuple <dataType, vector<float>> where the vector is the weights
97
98 // TODO: I dont think we need this find here
99 if (terminal_bandits.find(entry.first) == terminal_bandits.end())
100 {
101 map<string, float> terminal_probs;
102 for (int i = 0; i < entry.second.size(); i++)
103 // TODO: do not insert if coefficient is smaller than zero. Fix bandits to work with that
104 if (entry.second[i] > 0.0)
105 {
106 auto node_name = search_space.terminal_map.at(entry.first).at(i).get_feature();
107 terminal_probs[node_name] = entry.second[i];
108 }
109
110 if (terminal_probs.size()>0)
111 terminal_bandits[entry.first] = Bandit(parameters.bandit,
112 terminal_probs);
113 }
114 }
115
116 // one bandit for each return type. If we look at implementation of
117 // sample op, the thing that matters is the most nested probabilities, so we will
118 // learn only that
119 for (auto& [ret_type, arg_w_map]: search_space.node_map)
120 {
121 // if (op_bandits.find(ret_type) == op_bandits.end())
122 // op_bandits.at(ret_type) = map<size_t, Bandit>();
123
124 for (const auto& [args_type, node_map] : arg_w_map)
125 {
126 // if (op_bandits.at(ret_type).find(args_type) != op_bandits.at(ret_type).end())
127 // continue
128
129 // TODO: this could be made much easier using user_ops
130 map<string, float> node_probs;
131
132 for (const auto& [node_type, node]: node_map)
133 {
134 auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type);
135
136 if (weight > 0.0)
137 {
138 // Attempt to emplace; if the key exists, do nothing
139 auto [it, inserted] = node_probs.try_emplace(node.name, weight);
140
141 // If the key already existed, update its value
142 if (!inserted) {
143 it->second = weight;
144 }
145 }
146 }
147
148 if (node_probs.size() > 0)
149 op_bandits[ret_type][args_type] = Bandit(parameters.bandit,
150 node_probs);
151 }
152 }
153
154 // ensuring all terminals exists as a simplification option
155 inexact_simplifier.init(256, data, 1);
156 for (const auto& entry : this->search_space.terminal_weights) {
157 map<string, float> terminal_probs;
158 for (int i = 0; i < entry.second.size(); i++)
159 if (entry.second[i] > 0.0)
160 {
161 Node node = search_space.terminal_map.at(entry.first).at(i);
162
163 tree<Node> dummy_tree;
164 dummy_tree.insert(dummy_tree.begin(), node);
165 auto it = dummy_tree.begin();
166 inexact_simplifier.index<T>(it, data.get_training_data());
167 }
168 }
169 };
170
179 std::optional<Individual<T>> cross(
180 const Individual<T>& mom, const Individual<T>& dad);
181
189 std::optional<Individual<T>> mutate(
190 const Individual<T>& parent, string choice="");
191
200 void vary(Population<T>& pop, int island, const vector<size_t>& parents);
201
/// Updates the search space.
/// NOTE(review): presumably propagates what the bandits learned back into the
/// search space (as described on the bandit members); confirm against the
/// implementation, which is not defined in this listing.
void update_ss();
209
/**
 * @brief Varies a population and updates the selection strategy based on rewards.
 *
 * Every slot of `island` whose individual is nullptr is filled with a new
 * child; non-null slots are skipped. For each empty slot, the parent `mom` is
 * taken from `parents` by cycling over it. The child is one of:
 *  - a "born" individual carrying a copy of mom's program, when every node of
 *    `mom` is locked;
 *  - the result of the crossover/mutation chosen by `variation_bandit`;
 *  - a fresh "born" individual, when the chosen variation returns no value.
 *
 * The child is then fitted on the training data, optionally simplified and
 * evaluated with `evaluator`. A 0/1 reward computed from the weighted
 * per-objective deltas is fed back to the bandits.
 *
 * @param pop               population whose empty (nullptr) slots are filled in place.
 * @param island            island whose indexes are processed.
 * @param parents           population indexes used as parents. Must be non-empty:
 *                          its size is used as a modulo divisor.
 * @param data              dataset used for fitting, simplification and fitness
 *                          (this parameter shadows the `data` member).
 * @param evaluator         evaluator that assigns the child's fitness.
 * @param do_simplification when true, runs the constants simplifier (if enabled
 *                          in `parameters`) and lets the inexact simplifier
 *                          simplify. When false, the inexact simplifier only
 *                          analyzes the tree.
 */
220 void vary_and_update(Population<T>& pop, int island, const vector<size_t>& parents,
221 const Dataset& data, Evaluation<T>& evaluator, bool do_simplification) {
222
223 // TODO: move implementation to cpp file and keep only declarations here
224 // TODO: rewrite this entire function to avoid repetition (this is a frankenstein)
225 auto indices = pop.get_island_indexes(island);
226
// NOTE(review): aux_individuals is never filled or read (only in commented-out
// code below); candidate for removal.
227 vector<std::shared_ptr<Individual<T>>> aux_individuals;
228 for (unsigned i = 0; i < indices.size(); ++i)
229 {
230 if (pop.individuals.at(indices.at(i)) != nullptr)
231 {
232
233 continue; // skipping if it is an individual --- we just want to fill invalid positions
234 }
235
236 // pass check for children undergoing variation
237 std::optional<Individual<T>> opt = std::nullopt; // new individual
238
239 // TODO: should this be randomly selected, or should I use each parent sequentially?
240 // auto idx = *r.select_randomly(parents.begin(), parents.end());
// NOTE(review): if parents is empty this is a modulo by zero (undefined behavior).
241 auto idx = parents.at(i % parents.size()); // use modulo to cycle through parents
242
243 const Individual<T>& mom = pop[idx];
244
// NOTE(review): the "fully locked" check is performed further below (std::all_of),
// not before this point.
245 // if we got here, then the individual is not fully locked and we can proceed with mutation
246 vector<Individual<T>> ind_parents = {mom};
// stays empty unless the bandit is consulted in the else-branch below
247 string choice;
248
249 // this assumes that islands do not share indexes before doing variation
250 unsigned id = parameters.current_gen * parameters.pop_size + indices.at(i);
251
252 Individual<T> ind; // the new individual
253
254 // fully locked individuals should not be replaced by random ones. returning
255 // a copy
// "locked" here means every node has get_prob_change() <= 0.0
256 if (std::all_of(mom.program.Tree.begin(), mom.program.Tree.end(),
257 [](const auto& n) { return n.get_prob_change()<=0.0; }))
258 {
259 // Notice that if everything is locked then the entire population
260 // may be replaced (if the new random individuals dominates the old
261 // fixed ones)
262 ind = Individual<T>();
263 ind.variation = "born";
264
266
267 // Alternative: keep it as it is, and just re-fit the constants
268 // (comment out just the line below to disable, but keep ind.init)
269 Program<T> copy(mom.program);
270 ind.program = copy;
271 }
272 else
273 {
274 choice = this->variation_bandit.choose();
275
276 if (choice.compare("cx") == 0)
277 {
278 // const Individual<T>& dad = pop[
279 // *r.select_randomly(parents.begin(), parents.end())];
280 const Individual<T>& dad = pop[parents.at((i+1) % parents.size())]; // use modulo to cycle through parents
281
282 opt = cross(mom, dad);
283 ind_parents.push_back(dad);
284 }
285 else
286 {
287 opt = mutate(mom, choice);
288 }
289
290 if (opt) // variation worked, lets keep this
291 {
292 ind = opt.value();
293 ind.set_parents(ind_parents);
294 }
295 else { // no optional value was returned. creating a new random individual
296 ind = Individual<T>();
298 ind.variation = "born";
299 }
300 }
301
302 // ind.set_objectives(mom.get_objectives()); // it will have an invalid fitness
303
304 ind.set_id(id);
305
// carry the parent's fitness values over to the child; presumably the
// get_prev_*() getters used below then compare against them -- confirm in Fitness
306 ind.fitness.set_loss(mom.fitness.get_loss());
308 ind.fitness.set_size(mom.fitness.get_size());
312
313
314 assert(ind.program.size() > 0);
315 assert(ind.fitness.valid() == false);
316
317 ind.program.fit(data.get_training_data());
318
319 // simplify before calculating fitness (order matters, as they are not refitted and constants simplifier does not replace with the right value.)
320 // TODO: constants_simplifier should set the correct value for the constant (so we dont have to refit).
321 // simplify constants first to avoid letting the lsh simplifier to visit redundant branches
322
323 if (parameters.constants_simplification && do_simplification)
324 {
325 constants_simplifier.simplify_tree<T>(ind.program, search_space, data.get_training_data());
326 }
327
328 if (parameters.inexact_simplification)
329 {
// use at most inexact_simplifier.inputDim samples, taken from the start of `data`
330 auto inputDim = std::min(inexact_simplifier.inputDim, data.get_training_data().get_n_samples());
331
// NOTE: this `idx` shadows the parent index `idx` declared above
332 vector<size_t> idx(inputDim);
333 std::iota(idx.begin(), idx.end(), 0);
334 Dataset data_simp = data(idx);
335
336 if (do_simplification)
337 {
338 // string prg_str = ind.program.get_model();
339
340 inexact_simplifier.simplify_tree<T>(ind.program, search_space, data_simp);
341
342 // if (ind.program.get_model().compare(prg_str)!= 0)
343 // cout << prg_str << endl << ind.program.get_model() << endl << "=====" << endl;
344 }
345 else
346 {
347 inexact_simplifier.analyze_tree<T>(ind.program, search_space, data_simp);
348 }
349 }
350
351 evaluator.assign_fit(ind, data, parameters, false);
352
353 // vector<float> deltas(ind.get_objectives().size(), 0.0f);
354 vector<float> deltas;
355
356 float delta = 0.0f;
357 float weight = 0.0f;
358
// one weighted delta per objective; the weight comes from Individual<T>::weightsMap
359 for (const auto& obj : ind.get_objectives())
360 {
361 // some objectives are unsigned int, which can have weird values if we
362 // do subtractions. Instead, for these cases, we calculate a placeholder
363 // value indicating only if it was greater or not, so we can deal with
364 // this issue.
365
// NOTE(review): as listed, this branch has no statement (listing line 367 is
// absent), so `delta` is not recomputed for the scorer objective -- confirm
// against the actual header.
366 if (obj.compare(parameters.scorer) == 0) {
368 }
369 else if (obj.compare("complexity") == 0) {
370 delta = ind.fitness.get_complexity() > ind.fitness.get_prev_complexity() ? 1.0 : -1.0 ;
371 }
// NOTE(review): same as above -- as listed, this branch is empty (listing line
// 373 is absent); confirm against the actual header.
372 else if (obj.compare("linear_complexity") == 0) {
374 }
375 else if (obj.compare("size") == 0) {
376 delta = ind.fitness.get_size() > ind.fitness.get_prev_size() ? 1.0 : -1.0;
377 }
378 else if (obj.compare("depth") == 0) {
379 delta = ind.fitness.get_depth() > ind.fitness.get_prev_depth() ? 1.0 : -1.0;
380 }
381 else {
382 HANDLE_ERROR_THROW(obj + " is not a known objective");
383 }
384
385 auto it = Individual<T>::weightsMap.find(obj);
386 if (it == Individual<T>::weightsMap.end()) {
387 HANDLE_ERROR_THROW("Weight not found for objective: " + obj);
388 }
389
390 weight = it->second;
391 float weighted_delta = delta * weight;
392 deltas.push_back(weighted_delta);
393 }
394
// reward is 1 only when no weighted delta is negative and at least one is positive
395 bool allPositive = true;
396 bool allNegative = true;
397 for (float d : deltas) {
398 if (d < 0)
399 allPositive = false;
400 if (d > 0)
401 allNegative = false;
402 }
403
// NOTE: this local `r` shadows the global random generator `r` for the rest of the loop body
404 float r = 0.0;
405 if (allPositive && !allNegative)
406 r = 1.0;
407
// NOTE(review): BUG -- `!s.compare(x)` is true only when s == x, so this condition
// requires the variation to equal "born", "cx" AND "subtree" at the same time.
// It can never hold: the reward branch below is unreachable and every child
// falls through to the zero-reward `else`. The intent (excluding those three,
// per the TODO) appears to be `compare(...) != 0` for each term.
408 if (!ind.get_variation().compare("born")
409 && !ind.get_variation().compare("cx")
410 && !ind.get_variation().compare("subtree") // TODO: handle subtree
411 )
412 {
413 this->variation_bandit.update(ind.get_variation(), r);
414
// also reward the arms (terminals / operators) whose nodes the variation sampled
415 if (ind.get_sampled_nodes().size() > 0) {
416 const auto& changed_nodes = ind.get_sampled_nodes();
417 for (auto& node : changed_nodes) {
418 if (node.get_arg_count() == 0) {
419 auto datatype = node.get_ret_type();
420
421 this->terminal_bandits[datatype].update(node.get_feature(), r);
422 }
423 else {
424 auto ret_type = node.get_ret_type();
425 auto args_type = node.args_type();
426 auto name = node.name;
427
428 this->op_bandits[ret_type][args_type].update(name, r);
429 }
430 }
431 }
432 }
433 else
// NOTE(review): `choice` is still "" here when the all-locked branch was taken;
// how Bandit::update treats an unknown arm is not visible in this file.
434 { // giving zero reward if the variation failed
435 this->variation_bandit.update(choice, 0.0);
436 }
437
438 // aux_individuals.push_back(std::make_shared<Individual<T>>(ind));
// fill the empty slot with the new child
439 pop.individuals.at(indices.at(i)) = std::make_shared<Individual<T>>(ind);
440
441 }
442
443 // updating the population with the new individual
444 // int aux_index = 0;
445 // for (unsigned i = 0; i < indices.size(); ++i)
446 // {
447 // if (pop.individuals.at(indices.at(i)) != nullptr)
448 // {
449 // // the nullptrs should be at the end of the vector
450 // pop.individuals.at(indices.at(i)) = aux_individuals.at(aux_index);
451 // aux_index++;
452 // }
453 // }
454 }
455
456 // these functions below will extract context and use it to choose the nodes to replace
457 // bandit_sample_terminal
458 std::optional<Node> bandit_sample_terminal(DataType R)
459 {
460 if (terminal_bandits.find(R) == terminal_bandits.end()) {
461
462 return std::nullopt;
463 }
464
465 auto& bandit = terminal_bandits.at(R);
466 string terminal_name = bandit.choose();
467
468 auto it = std::find_if(
469 search_space.terminal_map.at(R).begin(),
470 search_space.terminal_map.at(R).end(),
471 [&](auto& node) { return node.get_feature() == terminal_name; });
472
473 if (it != search_space.terminal_map.at(R).end()) {
474 auto index = std::distance(search_space.terminal_map.at(R).begin(), it);
475
476 return search_space.terminal_map.at(R).at(index);
477 }
478
479 return std::nullopt;
480 };
481
482 // bandit_get_node_like
483 std::optional<Node> bandit_get_node_like(Node node)
484 {
485 // TODO: use search_space.terminal_types here (and in search_space get_node_like as well)
487
488 return bandit_sample_terminal(node.ret_type);
489 }
490
491 if (op_bandits.find(node.ret_type) == op_bandits.end()) {
492
493 return std::nullopt;
494 }
495 if (op_bandits.at(node.ret_type).find(node.args_type()) == op_bandits.at(node.ret_type).end()) {
496
497 return std::nullopt;
498 }
499
500 auto& bandit = op_bandits[node.ret_type][node.args_type()];
501 string node_name = bandit.choose();
502
503 auto entries = search_space.node_map[node.ret_type][node.args_type()];
504
505 for (const auto& [node_type, node_value]: entries)
506 {
507 if (node_value.name == node_name) {
508 return node_value;
509 }
510 }
511
512 return std::nullopt;
513 };
514
515 // bandit_sample_op_with_arg
516 std::optional<Node> bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
517 {
518 auto args_map = search_space.node_map.at(ret);
519 vector<size_t> matches;
520 vector<float> weights;
521
522 for (const auto& [args_type, name_map]: args_map) {
523 for (const auto& [name, node]: name_map) {
524 auto node_arg_types = node.get_arg_types();
525
526 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
527
528 if (in(node_arg_types, arg)
529 && within_size_limit
530 && search_space.node_map_weights.at(ret).at(args_type).at(name) > 0.0f ) {
531 // if it can be sampled
532 matches.push_back(node.args_type());
533 }
534 }
535 }
536
537 if (matches.size()==0)
538 return std::nullopt;
539
540 // we randomly select args type. This is what determines which bandit to use
541 auto args_type = *r.select_randomly(matches.begin(),
542 matches.end() );
543 auto& bandit = op_bandits[ret][args_type];
544 string node_name = bandit.choose();
545
546 // TODO: this could be more efficient
547 auto entries = search_space.node_map[ret][args_type];
548 for (const auto& [node_type, node_value]: entries)
549 {
550 if (node_value.name == node_name) {
551 return node_value;
552 }
553 }
554
555 return std::nullopt;
556 };
557
558 // bandit_sample_op
559 std::optional<Node> bandit_sample_op(DataType ret)
560 {
561 if (search_space.node_map.find(ret) == search_space.node_map.end())
562 return std::nullopt;
563
564 // any bandit to do the job
565 auto& [args_type, bandit] = *r.select_randomly(op_bandits[ret].begin(),
566 op_bandits[ret].end() );
567
568 string node_name = bandit.choose();
569
570 auto entries = search_space.node_map[ret][args_type];
571 for (const auto& [node_type, node_value]: entries)
572 {
573 if (node_value.name == node_name) {
574 return node_value;
575 }
576 }
577
578 return std::nullopt;
579 };
580
581 inline void log_simplification_table(std::ofstream& log) {
582 inexact_simplifier.log_simplification_table(log);
583 };
584
585 // bandit_sample_subtree // TODO: should I implement this? (it's going to be hard).
586 // without this one being performed directly by the bandits, we then rely on
587 // the sampled probabilities we update after every generation. Since there are lots
588 // of samplings, I think it is ok to not update them and just use the distribution they learned.
589
590 // only `data` is a reference; `search_space` and `parameters` are copies owned by this object, so changes made to them here do not reach the caller's instances
591 SearchSpace search_space; // The search space for the variation operator.
592 Dataset& data; // the data used to extract context and evaluate the models
593 Parameters parameters; // The parameters for the variation operator
594private:
595 // bandits will internally work as an interface between variation and its search space.
596 // they will sample from the SS (instead of letting the search space do it directly),
597 // and also propagate what they learn back to the search space at the end of the execution.
599 map<DataType, Bandit> terminal_bandits;
600 map<DataType, map<size_t, Bandit>> op_bandits;
601
602 // simplification methods
605};
606
607// // Explicitly instantiate the template for brush program types
608// template class Variation<ProgramType::Regressor>;
609// template class Variation<ProgramType::BinaryClassifier>;
610// template class Variation<ProgramType::MulticlassClassifier>;
611// template class Variation<ProgramType::Representer>;
612
614public:
615 using Iter = tree<Node>::pre_order_iterator;
616
617 template<Brush::ProgramType T>
618 static auto find_spots(Program<T>& program, Variation<T>& variator,
619 const Parameters& params)
620 {
621 vector<float> weights(program.Tree.size());
622
623 // by default, mutation can happen anywhere, based on node weights
624 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
625 [&](const auto& n){ return n.get_prob_change();});
626
627 // Must have same size as tree, even if all weights <= 0.0
628 return weights;
629 }
630
631 template<Brush::ProgramType T>
632 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
633 const Parameters& params);
634};
635
636} //namespace Var
637} //namespace Brush
638#endif
holds variable type data.
Definition data.h:51
Class for evaluating the fitness of individuals in a population.
Definition evaluation.h:27
void assign_fit(Individual< T > &ind, const Dataset &data, const Parameters &params, bool val=false)
Assign fitness to an individual.
static std::map< std::string, float > weightsMap
set parent ids using id values
Definition individual.h:139
void set_id(unsigned i)
Definition individual.h:122
vector< Node > get_sampled_nodes() const
Definition individual.h:119
Fitness fitness
aggregate fitness score
Definition individual.h:37
string get_variation() const
Definition individual.h:113
vector< string > get_objectives() const
Definition individual.h:152
void set_parents(const vector< Individual< T > > &parents)
Definition individual.h:123
void init(SearchSpace &ss, const Parameters &params)
Definition individual.h:52
Program< T > program
executable data structure
Definition individual.h:17
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
Definition variation.h:618
tree< Node >::pre_order_iterator Iter
Definition variation.h:615
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
Class representing the variation operators in Brush.
Definition variation.h:44
std::optional< Node > bandit_sample_op(DataType ret)
Definition variation.h:559
map< DataType, map< size_t, Bandit > > op_bandits
Definition variation.h:600
void vary_and_update(Population< T > &pop, int island, const vector< size_t > &parents, const Dataset &data, Evaluation< T > &evaluator, bool do_simplification)
Varies a population and updates the selection strategy based on rewards.
Definition variation.h:220
void vary(Population< T > &pop, int island, const vector< size_t > &parents)
Handles variation of a population.
void init()
Initializes the Variation object with parameters and search space.
Definition variation.h:76
Variation()=default
Default constructor.
Inexact_simplifier inexact_simplifier
Definition variation.h:604
SearchSpace search_space
Definition variation.h:591
~Variation()
Destructor.
Definition variation.h:68
std::optional< Node > bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
Definition variation.h:516
Constants_simplifier constants_simplifier
Definition variation.h:603
Variation(Parameters &params, SearchSpace &ss, Dataset &d)
Constructor that initializes the Variation object with parameters and search space.
Definition variation.h:57
std::optional< Individual< T > > cross(const Individual< T > &mom, const Individual< T > &dad)
Performs crossover operation on two individuals.
std::optional< Node > bandit_sample_terminal(DataType R)
Definition variation.h:458
Parameters parameters
Definition variation.h:593
void log_simplification_table(std::ofstream &log)
Definition variation.h:581
std::optional< Node > bandit_get_node_like(Node node)
Definition variation.h:483
map< DataType, Bandit > terminal_bandits
Definition variation.h:599
std::optional< Individual< T > > mutate(const Individual< T > &parent, string choice="")
Performs mutation operation on an individual.
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:291
float get_loss() const
Definition fitness.h:64
unsigned int get_prev_depth() const
Definition fitness.h:87
float get_prev_loss() const
Definition fitness.h:65
unsigned int get_prev_size() const
Definition fitness.h:73
unsigned int get_prev_complexity() const
Definition fitness.h:78
bool valid() const
Definition fitness.h:155
void set_linear_complexity(unsigned int new_lc)
Definition fitness.h:80
void set_complexity(unsigned int new_c)
Definition fitness.h:75
float get_loss_v() const
Definition fitness.h:68
void set_loss_v(float f_v)
Definition fitness.h:67
void set_depth(unsigned int new_d)
Definition fitness.h:85
unsigned int get_prev_linear_complexity() const
Definition fitness.h:83
unsigned int get_complexity() const
Definition fitness.h:77
void set_size(unsigned int new_s)
Definition fitness.h:71
unsigned int get_depth() const
Definition fitness.h:86
unsigned int get_linear_complexity() const
Definition fitness.h:82
unsigned int get_size() const
Definition fitness.h:72
void set_loss(float f)
Definition fitness.h:63
The Bandit struct represents a multi-armed bandit.
Definition bandit.h:32
class holding the data for a node in a tree.
Definition node.h:84
NodeType node_type
the node type
Definition node.h:89
DataType ret_type
return data type
Definition node.h:92
std::size_t args_type() const
Definition node.h:180
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
Program< PType > & fit(const Dataset &d)
Definition program.h:150
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...