Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
variation.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef VARIATION_H
7#define VARIATION_H
8
9#include "../ind/individual.h"
10
11#include "../bandit/bandit.h"
13#include "../bandit/dummy.h"
14
15#include "../pop/population.h"
16#include "../eval/evaluation.h"
19
20#include <map>
21#include <optional>
22
23using namespace Brush::Pop;
24using namespace Brush::MAB;
25using namespace Brush::Eval;
26using namespace Brush::Simpl;
27
32namespace Brush {
33namespace Var {
34
43template<ProgramType T>
44class Variation {
45public:
 // Default constructor: members keep their defaults; init() is NOT called.
 49 Variation() = default;
 50
 // NOTE(review): the signature line for this constructor (variation.h:57,
 // `Variation(Parameters& params, SearchSpace& ss, Dataset& d)` per the
 // symbol index at the bottom of this page) was elided by the doc
 // extraction; only the member-initializer list and body remain visible.
 // It copies params/ss, binds the data reference, then calls init().
 58 : parameters(params)
 59 , search_space(ss)
 60 , data(d)
 61 {
 62 init();
 63 };
 64
69
76 void init(){
77 // initializing variation bandit with the probabilities of non-null variation
78 map<string, float> variation_probs;
79 for (const auto& mutation : parameters.get_mutation_probs())
80 // TODO: I need to figure out a better way of avoiding options that were not positive. Make sure that this does not break the code
81 if (mutation.second > 0.0)
82 variation_probs[mutation.first] = mutation.second;
83
84 if (parameters.cx_prob > 0.0)
85 variation_probs["cx"] = parameters.cx_prob;
86
87 this->variation_bandit = Bandit(parameters.bandit, variation_probs);
88
89 // TODO: should I set C parameter based on pop size or leave it fixed?
90 // TODO: update string comparisons to use .compare method
91 // if (parameters.bandit.compare("dynamic_thompson")==0)
92 // this->variation_bandit.pbandit.set_C(parameters.pop_size);
93
94 // initializing one bandit for each terminal type
95 for (const auto& entry : this->search_space.terminal_weights) {
96 // entry is a tuple <dataType, vector<float>> where the vector is the weights
97
98 // TODO: I dont think we need this find here
99 if (terminal_bandits.find(entry.first) == terminal_bandits.end())
100 {
101 map<string, float> terminal_probs;
102 for (int i = 0; i < entry.second.size(); i++)
103 // TODO: do not insert if coefficient is smaller than zero. Fix bandits to work with that
104 if (entry.second[i] > 0.0)
105 {
106 auto node_name = search_space.terminal_map.at(entry.first).at(i).get_feature();
107 terminal_probs[node_name] = entry.second[i];
108 }
109
110 if (terminal_probs.size()>0)
111 terminal_bandits[entry.first] = Bandit(parameters.bandit,
112 terminal_probs);
113 }
114 }
115
116 // one bandit for each return type. If we look at implementation of
117 // sample op, the thing that matters is the most nested probabilities, so we will
118 // learn only that
119 for (auto& [ret_type, arg_w_map]: search_space.node_map)
120 {
121 // if (op_bandits.find(ret_type) == op_bandits.end())
122 // op_bandits.at(ret_type) = map<size_t, Bandit>();
123
124 for (const auto& [args_type, node_map] : arg_w_map)
125 {
126 // if (op_bandits.at(ret_type).find(args_type) != op_bandits.at(ret_type).end())
127 // continue
128
129 // TODO: this could be made much easier using user_ops
130 map<string, float> node_probs;
131
132 for (const auto& [node_type, node]: node_map)
133 {
134 auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type);
135
136 if (weight > 0.0)
137 {
138 // Attempt to emplace; if the key exists, do nothing
139 auto [it, inserted] = node_probs.try_emplace(node.name, weight);
140
141 // If the key already existed, update its value
142 if (!inserted) {
143 it->second = weight;
144 }
145 }
146 }
147
148 if (node_probs.size() > 0)
149 op_bandits[ret_type][args_type] = Bandit(parameters.bandit,
150 node_probs);
151 }
152 }
153
154 // ensuring all terminals exists as a simplification option
155 // Only initialize simplification tables if inexact_simplification is enabled
156 if (parameters.inexact_simplification) {
157 inexact_simplifier.init(256, data, 1);
158 for (const auto& entry : this->search_space.terminal_weights) {
159 map<string, float> terminal_probs;
160 for (int i = 0; i < entry.second.size(); i++)
161 if (entry.second[i] > 0.0)
162 {
163 Node node = search_space.terminal_map.at(entry.first).at(i);
164
165 tree<Node> dummy_tree;
166 dummy_tree.insert(dummy_tree.begin(), node);
167 auto it = dummy_tree.begin();
168 inexact_simplifier.index<T>(it, data.get_training_data());
169 }
170 }
171 }
172 };
173
 /// @brief Performs crossover between two individuals.
 /// @param mom the first parent; @param dad the second parent.
 /// @return the offspring, or std::nullopt when crossover fails.
 182 std::optional<Individual<T>> cross(
 183 const Individual<T>& mom, const Individual<T>& dad);
 184
 /// @brief Performs mutation on an individual.
 /// @param parent the individual to mutate.
 /// @param choice optional mutation-operator name; empty lets the caller's
 /// policy pick one.
 /// @return the mutated individual, or std::nullopt when mutation fails.
 192 std::optional<Individual<T>> mutate(
 193 const Individual<T>& parent, string choice="");
 194
 /// @brief Varies the individuals of one island of a population.
 /// @param pop the population; @param island the island index;
 /// @param parents indices of the selected parents.
 203 void vary(Population<T>& pop, int island, const vector<size_t>& parents);
 204
 /// @brief Propagates what the bandits learned back to the search space.
 211 void update_ss();
 212
223 void vary_and_update(Population<T>& pop, int island, const vector<size_t>& parents,
224 const Dataset& data, Evaluation<T>& evaluator, bool do_simplification) {
225
226 // TODO: move implementation to cpp file and keep only declarations here
227 // TODO: rewrite this entire function to avoid repetition (this is a frankenstein)
228 auto indices = pop.get_island_indexes(island);
229
230 vector<std::shared_ptr<Individual<T>>> aux_individuals;
231 for (unsigned i = 0; i < indices.size(); ++i)
232 {
233 if (pop.individuals.at(indices.at(i)) != nullptr)
234 {
235 continue; // skipping if it is an individual --- we just want to fill invalid positions
236 }
237
238 // pass check for children undergoing variation
239 std::optional<Individual<T>> opt = std::nullopt; // new individual
240
241 // TODO: should this be randomly selected, or should I use each parent sequentially?
242 // auto idx = *r.select_randomly(parents.begin(), parents.end());
243 auto idx = parents.at(i % parents.size()); // use modulo to cycle through parents
244
245 const Individual<T>& mom = pop[idx];
246
247 // if we got here, then the individual is not fully locked and we can proceed with mutation
248 vector<Individual<T>> ind_parents = {mom};
249 string choice;
250
251 // this assumes that islands do not share indexes before doing variation
252 unsigned id = parameters.current_gen * parameters.pop_size + indices.at(i);
253
254 Individual<T> ind; // the new individual
255
256 // fully locked individuals should not be replaced by random ones. returning
257 // a copy
258 if (std::all_of(mom.program.Tree.begin(), mom.program.Tree.end(),
259 [](const auto& n) { return n.get_prob_change()<=0.0; }))
260 {
261 // Notice that if everything is locked then the entire population
262 // may be replaced (if the new random individuals dominates the old
263 // fixed ones). below we force to repeat individuals
264 ind = Individual<T>();
266
267 // Alternative: keep it as it is, and just re-fit the constants
268 // (comment out just the line below to disable, but keep ind.init)
269 Program<T> copy(mom.program);
270 ind.program = copy;
271 ind.variation = "clone";
272 }
273 else
274 {
275 choice = this->variation_bandit.choose();
276
277 if (choice.compare("cx") == 0)
278 {
279 // const Individual<T>& dad = pop[
280 // *r.select_randomly(parents.begin(), parents.end())];
281 const Individual<T>& dad = pop[parents.at((i+1) % parents.size())]; // use modulo to cycle through parents
282
283 opt = cross(mom, dad);
284 ind_parents.push_back(dad);
285 }
286 else
287 {
288 opt = mutate(mom, choice);
289 }
290
291 if (opt) // variation worked, lets keep this
292 {
293 ind = opt.value();
294 ind.set_parents(ind_parents);
295 }
296 else { // no optional value was returned. creating a new random individual
297 ind = Individual<T>();
299
300 // creates a new random individual
301 // ind.init(search_space, parameters);
302 // ind.variation = "born";
303 // ---------------------------------------------------------------
304
305 // instead of creating something new (code above), I will apply
306 // subtree mutation, so we can still have the fixed part of programs
307 int tries = 0;
308 while (tries++<=3 && !opt) { // try subtree mutation a few times before giving up
309 opt = this->mutate(mom, "subtree");
310 }
311
312 // it is very unlikely that subtree will fail, but in case it does, then
313 // we set variation to be a clone so we know it happened
314 if (opt) {
315 ind = opt.value();
316 ind.set_parents(ind_parents);
317 ind.variation = "subtree";
318 } else {
319 // fallback: mark as a subtree attempt that failed to produce a new individual
320 Program<T> copy(mom.program);
321 ind.program = copy;
322 ind.variation = "clone";
323 }
324 }
325 }
326
327 // for debugging
328 // cout << "tried " << choice << ", got " << ind.variation << endl;
329 // cout << "mom : " << mom.program.get_model() << endl;
330 // cout << "child: " << ind.program.get_model() << endl;
331
332 // ind.set_objectives(mom.get_objectives()); // it will have an invalid fitness
333
334 ind.set_id(id);
335 ind.fitness.set_loss(mom.fitness.get_loss());
337 ind.fitness.set_size(mom.fitness.get_size());
341
342
343 assert(ind.program.size() > 0);
344 assert(ind.fitness.valid() == false);
345
346 ind.program.fit(data.get_training_data());
347
348 // simplify before calculating fitness (order matters, as they are not refitted and constants simplifier does not replace with the right value.)
349 // TODO: constants_simplifier should set the correct value for the constant (so we dont have to refit).
350 // simplify constants first to avoid letting the lsh simplifier to visit redundant branches
351
352 if (parameters.constants_simplification && do_simplification)
353 {
354 constants_simplifier.simplify_tree<T>(ind.program, search_space, data.get_training_data());
355 }
356
357 if (parameters.inexact_simplification)
358 {
359 auto inputDim = std::min(inexact_simplifier.inputDim, data.get_training_data().get_n_samples());
360
361 vector<size_t> idx(inputDim);
362 std::iota(idx.begin(), idx.end(), 0);
363 Dataset data_simp = data(idx);
364
365 if (do_simplification)
366 {
367 // string prg_str = ind.program.get_model();
368
369 inexact_simplifier.simplify_tree<T>(ind.program, search_space, data_simp);
370
371 // if (ind.program.get_model().compare(prg_str)!= 0)
372 // cout << prg_str << endl << ind.program.get_model() << endl << "=====" << endl;
373 }
374 else
375 {
376 inexact_simplifier.analyze_tree<T>(ind.program, search_space, data_simp);
377 }
378 }
379
380 evaluator.assign_fit(ind, data, parameters, false);
381
382 // vector<float> deltas(ind.get_objectives().size(), 0.0f);
383 vector<float> deltas;
384
385 float delta = 0.0f;
386 float weight = 0.0f;
387
388 for (const auto& obj : ind.get_objectives())
389 {
390 // some objectives are unsigned int, which can have weird values if we
391 // do subtractions. Instead, for these cases, we calculate a placeholder
392 // value indicating only if it was greater or not, so we can deal with
393 // this issue.
394
395 if (obj.compare(parameters.scorer) == 0) {
397 }
398 else if (obj.compare("complexity") == 0) {
399 delta = ind.fitness.get_complexity() > ind.fitness.get_prev_complexity() ? 1.0 : -1.0 ;
400 }
401 else if (obj.compare("linear_complexity") == 0) {
403 }
404 else if (obj.compare("size") == 0) {
405 delta = ind.fitness.get_size() > ind.fitness.get_prev_size() ? 1.0 : -1.0;
406 }
407 else if (obj.compare("depth") == 0) {
408 delta = ind.fitness.get_depth() > ind.fitness.get_prev_depth() ? 1.0 : -1.0;
409 }
410 else {
411 HANDLE_ERROR_THROW(obj + " is not a known objective");
412 }
413
414 auto it = Individual<T>::weightsMap.find(obj);
415 if (it == Individual<T>::weightsMap.end()) {
416 HANDLE_ERROR_THROW("Weight not found for objective: " + obj);
417 }
418
419 weight = it->second;
420 float weighted_delta = delta * weight;
421 deltas.push_back(weighted_delta);
422 }
423
424 bool allPositive = true;
425 bool allNegative = true;
426 for (float d : deltas) {
427 if (d < 0)
428 allPositive = false;
429 if (d > 0)
430 allNegative = false;
431 }
432
433 float r = 0.0;
434 if (allPositive && !allNegative)
435 r = 1.0;
436
437 if (!ind.get_variation().compare("born")
438 && !ind.get_variation().compare("cx")
439 && !ind.get_variation().compare("subtree") // TODO: handle subtree
440 )
441 {
442 this->variation_bandit.update(ind.get_variation(), r);
443
444 if (ind.get_sampled_nodes().size() > 0) {
445 const auto& changed_nodes = ind.get_sampled_nodes();
446 for (auto& node : changed_nodes) {
447 if (node.get_arg_count() == 0) {
448 auto datatype = node.get_ret_type();
449
450 this->terminal_bandits[datatype].update(node.get_feature(), r);
451 }
452 else {
453 auto ret_type = node.get_ret_type();
454 auto args_type = node.args_type();
455 auto name = node.name;
456
457 this->op_bandits[ret_type][args_type].update(name, r);
458 }
459 }
460 }
461 }
462 else
463 { // giving zero reward if the variation failed
464 this->variation_bandit.update(choice, 0.0);
465 }
466
467 // aux_individuals.push_back(std::make_shared<Individual<T>>(ind));
468 pop.individuals.at(indices.at(i)) = std::make_shared<Individual<T>>(ind);
469
470 }
471
472 // updating the population with the new individual
473 // int aux_index = 0;
474 // for (unsigned i = 0; i < indices.size(); ++i)
475 // {
476 // if (pop.individuals.at(indices.at(i)) != nullptr)
477 // {
478 // // the nullptrs should be at the end of the vector
479 // pop.individuals.at(indices.at(i)) = aux_individuals.at(aux_index);
480 // aux_index++;
481 // }
482 // }
483 }
484
485 // these functions below will extract context and use it to choose the nodes to replace
486 // bandit_sample_terminal
487 std::optional<Node> bandit_sample_terminal(DataType R)
488 {
489 if (terminal_bandits.find(R) == terminal_bandits.end()) {
490
491 return std::nullopt;
492 }
493
494 auto& bandit = terminal_bandits.at(R);
495 string terminal_name = bandit.choose();
496
497 auto it = std::find_if(
498 search_space.terminal_map.at(R).begin(),
499 search_space.terminal_map.at(R).end(),
500 [&](auto& node) { return node.get_feature() == terminal_name; });
501
502 if (it != search_space.terminal_map.at(R).end()) {
503 auto index = std::distance(search_space.terminal_map.at(R).begin(), it);
504
505 return search_space.terminal_map.at(R).at(index);
506 }
507
508 return std::nullopt;
509 };
510
511 // bandit_get_node_like
512 std::optional<Node> bandit_get_node_like(Node node)
513 {
514 // TODO: use search_space.terminal_types here (and in search_space get_node_like as well)
516
517 return bandit_sample_terminal(node.ret_type);
518 }
519
520 if (op_bandits.find(node.ret_type) == op_bandits.end()) {
521
522 return std::nullopt;
523 }
524 if (op_bandits.at(node.ret_type).find(node.args_type()) == op_bandits.at(node.ret_type).end()) {
525
526 return std::nullopt;
527 }
528
529 auto& bandit = op_bandits[node.ret_type][node.args_type()];
530 string node_name = bandit.choose();
531
532 auto entries = search_space.node_map[node.ret_type][node.args_type()];
533
534 for (const auto& [node_type, node_value]: entries)
535 {
536 if (node_value.name == node_name) {
537 return node_value;
538 }
539 }
540
541 return std::nullopt;
542 };
543
544 // bandit_sample_op_with_arg
545 std::optional<Node> bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
546 {
547 auto args_map = search_space.node_map.at(ret);
548 vector<size_t> matches;
549 vector<float> weights;
550
551 for (const auto& [args_type, name_map]: args_map) {
552 for (const auto& [name, node]: name_map) {
553 auto node_arg_types = node.get_arg_types();
554
555 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
556
557 if (in(node_arg_types, arg)
558 && within_size_limit
559 && search_space.node_map_weights.at(ret).at(args_type).at(name) > 0.0f ) {
560 // if it can be sampled
561 matches.push_back(node.args_type());
562 }
563 }
564 }
565
566 if (matches.size()==0)
567 return std::nullopt;
568
569 // we randomly select args type. This is what determines which bandit to use
570 auto args_type = *r.select_randomly(matches.begin(),
571 matches.end() );
572 auto& bandit = op_bandits[ret][args_type];
573 string node_name = bandit.choose();
574
575 // TODO: this could be more efficient
576 auto entries = search_space.node_map[ret][args_type];
577 for (const auto& [node_type, node_value]: entries)
578 {
579 if (node_value.name == node_name) {
580 return node_value;
581 }
582 }
583
584 return std::nullopt;
585 };
586
587 // bandit_sample_op
588 std::optional<Node> bandit_sample_op(DataType ret)
589 {
590 if (search_space.node_map.find(ret) == search_space.node_map.end())
591 return std::nullopt;
592
593 // any bandit to do the job
594 auto& [args_type, bandit] = *r.select_randomly(op_bandits[ret].begin(),
595 op_bandits[ret].end() );
596
597 string node_name = bandit.choose();
598
599 auto entries = search_space.node_map[ret][args_type];
600 for (const auto& [node_type, node_value]: entries)
601 {
602 if (node_value.name == node_name) {
603 return node_value;
604 }
605 }
606
607 return std::nullopt;
608 };
609
 /// @brief Writes the inexact simplifier's simplification table to `log`.
 /// Thin forwarding wrapper around Inexact_simplifier::log_simplification_table.
 610 inline void log_simplification_table(std::ofstream& log) {
 611 inexact_simplifier.log_simplification_table(log);
 612 };
613
 614 // bandit_sample_subtree // TODO: should I implement this? (its going to be hard).
 615 // without this one being performed directly by the bandits, we then rely on
 616 // the sampled probabilities we update after every generation. Since there are lots
 617 // of samplings, I think it is ok to not update them and just use the distribution they learned.
 618
 // NOTE(review): despite the original comment claiming these "need to be
 // references", only `data` is a reference; search_space and parameters
 // are owned copies (see the constructor's init-list above).
 620 SearchSpace search_space; // The search space for the variation operator.
 621 Dataset& data; // the data used to extract context and evaluate the models
 622 Parameters parameters; // The parameters for the variation operator
 623private:
 624 // bandits internally work as an interface between variation and its search space.
 625 // they will sample from the SS (instead of letting the search space do it directly),
 626 // and also propagate what they learn back to the search space at the end of the execution.
 628 map<DataType, Bandit> terminal_bandits; // one bandit per terminal return type
 629 map<DataType, map<size_t, Bandit>> op_bandits; // one bandit per (ret type, args signature)
 630
 631 // simplification methods
 // NOTE(review): the declarations of constants_simplifier (variation.h:632)
 // and inexact_simplifier (variation.h:633) were elided by the doc
 // extraction here; the symbol index below confirms both members exist.
 634};
635
636// // Explicitly instantiate the template for brush program types
637// template class Variation<ProgramType::Regressor>;
638// template class Variation<ProgramType::BinaryClassifier>;
639// template class Variation<ProgramType::MulticlassClassifier>;
640// template class Variation<ProgramType::Representer>;
641
643public:
644 using Iter = tree<Node>::pre_order_iterator;
645
646 template<Brush::ProgramType T>
647 static auto find_spots(Program<T>& program, Variation<T>& variator,
648 const Parameters& params)
649 {
650 vector<float> weights(program.Tree.size());
651
652 // by default, mutation can happen anywhere, based on node weights
653 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
654 [&](const auto& n){ return n.get_prob_change();});
655
656 // Must have same size as tree, even if all weights <= 0.0
657 return weights;
658 }
659
660 template<Brush::ProgramType T>
661 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
662 const Parameters& params);
663};
664
665} //namespace Var
666} //namespace Brush
667#endif
holds variable type data.
Definition data.h:51
Class for evaluating the fitness of individuals in a population.
Definition evaluation.h:27
void assign_fit(Individual< T > &ind, const Dataset &data, const Parameters &params, bool val=false)
Assign fitness to an individual.
static std::map< std::string, float > weightsMap
set parent ids using id values
Definition individual.h:165
void set_id(unsigned i)
Definition individual.h:148
vector< Node > get_sampled_nodes() const
Definition individual.h:145
Fitness fitness
aggregate fitness score
Definition individual.h:37
string get_variation() const
Definition individual.h:139
vector< string > get_objectives() const
Definition individual.h:178
void set_parents(const vector< Individual< T > > &parents)
Definition individual.h:149
void init(SearchSpace &ss, const Parameters &params)
Definition individual.h:52
Program< T > program
executable data structure
Definition individual.h:17
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
Definition variation.h:647
tree< Node >::pre_order_iterator Iter
Definition variation.h:644
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
Class representing the variation operators in Brush.
Definition variation.h:44
std::optional< Node > bandit_sample_op(DataType ret)
Definition variation.h:588
map< DataType, map< size_t, Bandit > > op_bandits
Definition variation.h:629
void vary_and_update(Population< T > &pop, int island, const vector< size_t > &parents, const Dataset &data, Evaluation< T > &evaluator, bool do_simplification)
Varies a population and updates the selection strategy based on rewards.
Definition variation.h:223
void vary(Population< T > &pop, int island, const vector< size_t > &parents)
Handles variation of a population.
void init()
Initializes the Variation object with parameters and search space.
Definition variation.h:76
Variation()=default
Default constructor.
Inexact_simplifier inexact_simplifier
Definition variation.h:633
SearchSpace search_space
Definition variation.h:620
~Variation()
Destructor.
Definition variation.h:68
std::optional< Node > bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
Definition variation.h:545
Constants_simplifier constants_simplifier
Definition variation.h:632
Variation(Parameters &params, SearchSpace &ss, Dataset &d)
Constructor that initializes the Variation object with parameters and search space.
Definition variation.h:57
std::optional< Individual< T > > cross(const Individual< T > &mom, const Individual< T > &dad)
Performs crossover operation on two individuals.
std::optional< Node > bandit_sample_terminal(DataType R)
Definition variation.h:487
Parameters parameters
Definition variation.h:622
void log_simplification_table(std::ofstream &log)
Definition variation.h:610
std::optional< Node > bandit_get_node_like(Node node)
Definition variation.h:512
map< DataType, Bandit > terminal_bandits
Definition variation.h:628
std::optional< Individual< T > > mutate(const Individual< T > &parent, string choice="")
Performs mutation operation on an individual.
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:313
float get_loss() const
Definition fitness.h:64
unsigned int get_prev_depth() const
Definition fitness.h:87
float get_prev_loss() const
Definition fitness.h:65
unsigned int get_prev_size() const
Definition fitness.h:73
unsigned int get_prev_complexity() const
Definition fitness.h:78
bool valid() const
Definition fitness.h:155
void set_linear_complexity(unsigned int new_lc)
Definition fitness.h:80
void set_complexity(unsigned int new_c)
Definition fitness.h:75
float get_loss_v() const
Definition fitness.h:68
void set_loss_v(float f_v)
Definition fitness.h:67
void set_depth(unsigned int new_d)
Definition fitness.h:85
unsigned int get_prev_linear_complexity() const
Definition fitness.h:83
unsigned int get_complexity() const
Definition fitness.h:77
void set_size(unsigned int new_s)
Definition fitness.h:71
unsigned int get_depth() const
Definition fitness.h:86
unsigned int get_linear_complexity() const
Definition fitness.h:82
unsigned int get_size() const
Definition fitness.h:72
void set_loss(float f)
Definition fitness.h:63
The Bandit struct represents a multi-armed bandit.
Definition bandit.h:32
class holding the data for a node in a tree.
Definition node.h:89
NodeType node_type
the node type
Definition node.h:94
DataType ret_type
return data type
Definition node.h:97
std::size_t args_type() const
Definition node.h:201
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
Program< PType > & fit(const Dataset &d)
Definition program.h:151
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:111
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...