Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
variation.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef VARIATION_H
7#define VARIATION_H
8
9#include "../ind/individual.h"
10
11#include "../bandit/bandit.h"
13#include "../bandit/dummy.h"
14
15#include "../pop/population.h"
16#include "../eval/evaluation.h"
19
20#include <map>
21#include <optional>
22
23using namespace Brush::Pop;
24using namespace Brush::MAB;
25using namespace Brush::Eval;
26using namespace Brush::Simpl;
27
32namespace Brush {
33namespace Var {
34
43template<ProgramType T>
44class Variation {
45public:
49 Variation() = default;
50
58 : parameters(params)
59 , search_space(ss)
60 , data(d)
61 {
62 init();
63 };
64
69
76 void init(){
77 // initializing variation bandit with the probabilities of non-null variation
78 map<string, float> variation_probs;
79 for (const auto& mutation : parameters.get_mutation_probs())
80 // TODO: I need to figure out a better way of avoiding options that were not positive. Make sure that this does not break the code
81 if (mutation.second > 0.0)
82 variation_probs[mutation.first] = mutation.second;
83
84 if (parameters.cx_prob > 0.0)
85 variation_probs["cx"] = parameters.cx_prob;
86
87 this->variation_bandit = Bandit(parameters.bandit, variation_probs);
88
89 // TODO: should I set C parameter based on pop size or leave it fixed?
90 // TODO: update string comparisons to use .compare method
91 // if (parameters.bandit.compare("dynamic_thompson")==0)
92 // this->variation_bandit.pbandit.set_C(parameters.pop_size);
93
94 // initializing one bandit for each terminal type
95 for (const auto& entry : this->search_space.terminal_weights) {
96 // entry is a tuple <dataType, vector<float>> where the vector is the weights
97
98 // TODO: I dont think we need this find here
99 if (terminal_bandits.find(entry.first) == terminal_bandits.end())
100 {
101 map<string, float> terminal_probs;
102 for (int i = 0; i < entry.second.size(); i++)
103 // TODO: do not insert if coefficient is smaller than zero. Fix bandits to work with that
104 if (entry.second[i] > 0.0)
105 {
106 auto node_name = search_space.terminal_map.at(entry.first).at(i).get_feature();
107 terminal_probs[node_name] = entry.second[i];
108 }
109
110 if (terminal_probs.size()>0)
111 terminal_bandits[entry.first] = Bandit(parameters.bandit,
112 terminal_probs);
113 }
114 }
115
116 // one bandit for each return type. If we look at implementation of
117 // sample op, the thing that matters is the most nested probabilities, so we will
118 // learn only that
119 for (auto& [ret_type, arg_w_map]: search_space.node_map)
120 {
121 // if (op_bandits.find(ret_type) == op_bandits.end())
122 // op_bandits.at(ret_type) = map<size_t, Bandit>();
123
124 for (const auto& [args_type, node_map] : arg_w_map)
125 {
126 // if (op_bandits.at(ret_type).find(args_type) != op_bandits.at(ret_type).end())
127 // continue
128
129 // TODO: this could be made much easier using user_ops
130 map<string, float> node_probs;
131
132 for (const auto& [node_type, node]: node_map)
133 {
134 auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type);
135
136 if (weight > 0.0)
137 {
138 // Attempt to emplace; if the key exists, do nothing
139 auto [it, inserted] = node_probs.try_emplace(node.name, weight);
140
141 // If the key already existed, update its value
142 if (!inserted) {
143 it->second = weight;
144 }
145 }
146 }
147
148 if (node_probs.size() > 0)
149 op_bandits[ret_type][args_type] = Bandit(parameters.bandit,
150 node_probs);
151 }
152 }
153
154 // ensuring all terminals exists as a simplification option
155 // Only initialize simplification tables if inexact_simplification is enabled
156 if (parameters.inexact_simplification) {
157 inexact_simplifier.init(256, data, 1);
158 for (const auto& entry : this->search_space.terminal_weights) {
159 map<string, float> terminal_probs;
160 for (int i = 0; i < entry.second.size(); i++)
161 if (entry.second[i] > 0.0)
162 {
163 Node node = search_space.terminal_map.at(entry.first).at(i);
164
165 tree<Node> dummy_tree;
166 dummy_tree.insert(dummy_tree.begin(), node);
167 auto it = dummy_tree.begin();
168 inexact_simplifier.index<T>(it, data.get_training_data());
169 }
170 }
171 }
172 };
173
182 std::optional<Individual<T>> cross(
183 const Individual<T>& mom, const Individual<T>& dad);
184
192 std::optional<Individual<T>> mutate(
193 const Individual<T>& parent, string choice="");
194
203 void vary(Population<T>& pop, int island, const vector<size_t>& parents);
204
211 void update_ss();
212
/// @brief Fills the island's empty (nullptr) slots with new individuals bred
/// from `parents`, then fits/simplifies/evaluates each child and updates the
/// variation, terminal, and operator bandits with a reward derived from the
/// child's per-objective deltas.
/// @param pop population varied in place (only nullptr slots are filled).
/// @param island index of the island whose empty slots are filled.
/// @param parents indices of parent individuals, cycled with modulo.
/// @param data dataset used for fitting, simplification, and evaluation.
/// @param evaluator assigns fitness to each new individual.
/// @param do_simplification whether to simplify (vs. only analyze) programs.
223 void vary_and_update(Population<T>& pop, int island, const vector<size_t>& parents,
224 const Dataset& data, Evaluation<T>& evaluator, bool do_simplification) {
225
226 // TODO: move implementation to cpp file and keep only declarations here
227 // TODO: rewrite this entire function to avoid repetition (this is a frankenstein)
228 auto indices = pop.get_island_indexes(island);
229
230 vector<std::shared_ptr<Individual<T>>> aux_individuals;
231 for (unsigned i = 0; i < indices.size(); ++i)
232 {
233 if (pop.individuals.at(indices.at(i)) != nullptr)
234 {
235 continue; // skipping if it is an individual --- we just want to fill invalid positions
236 }
237
238 // pass check for children undergoing variation
239 std::optional<Individual<T>> opt = std::nullopt; // new individual
240
241 // TODO: should this be randomly selected, or should I use each parent sequentially?
242 // auto idx = *r.select_randomly(parents.begin(), parents.end());
243 auto idx = parents.at(i % parents.size()); // use modulo to cycle through parents
244
245 const Individual<T>& mom = pop[idx];
246
247 // if we got here, then the individual is not fully locked and we can proceed with mutation
248 vector<Individual<T>> ind_parents = {mom};
249 string choice;
250
251 // this assumes that islands do not share indexes before doing variation
252 unsigned id = parameters.current_gen * parameters.pop_size + indices.at(i);
253
254 Individual<T> ind; // the new individual
255
256 // fully locked individuals should not be replaced by random ones. returning
257 // a copy
258 if (std::all_of(mom.program.Tree.begin(), mom.program.Tree.end(),
259 [](const auto& n) { return n.get_prob_change()<=0.0; }))
260 {
261 // Notice that if everything is locked then the entire population
262 // may be replaced (if the new random individuals dominates the old
263 // fixed ones). below we force to repeat individuals
264 ind = Individual<T>();
// NOTE(review): original line 265 appears elided from this listing (the
// comment below suggests it is ind.init(search_space, parameters)). Confirm
// against the full source.
266
267 // Alternative: keep it as it is, and just re-fit the constants
268 // (comment out just the line below to disable, but keep ind.init)
269 Program<T> copy(mom.program);
270 ind.program = copy;
271 ind.variation = "clone";
272 }
273 else
274 {
275 choice = this->variation_bandit.choose();
276
277 if (choice.compare("cx") == 0)
278 {
279 // const Individual<T>& dad = pop[
280 // *r.select_randomly(parents.begin(), parents.end())];
281 const Individual<T>& dad = pop[parents.at((i+1) % parents.size())]; // use modulo to cycle through parents
282
283 opt = cross(mom, dad);
284 ind_parents.push_back(dad);
285 }
286 else
287 {
288 opt = mutate(mom, choice);
289 }
290
291 if (opt) // variation worked, lets keep this
292 {
293 ind = opt.value();
294 ind.set_parents(ind_parents);
295 }
296 else { // no optional value was returned. creating a new random individual
297 ind = Individual<T>();
// NOTE(review): original line 298 appears elided from this listing (likely
// ind.init(search_space, parameters), per the commented-out code below).
// Confirm against the full source.
299
300 // creates a new random individual
301 // ind.init(search_space, parameters);
302 // ind.variation = "born";
303 // ---------------------------------------------------------------
304
305 // instead of creating something new (code above), I will apply
306 // subtree mutation, so we can still have the fixed part of programs
307 int tries = 0;
308 while (tries++<=3 && !opt) { // try subtree mutation a few times before giving up
309 opt = this->mutate(mom, "subtree");
310 }
311
312 // it is very unlikely that subtree will fail, but in case it does, then
313 // we set variation to be a clone so we know it happened
314 if (opt) {
315 ind = opt.value();
316 ind.set_parents(ind_parents);
317 ind.variation = "subtree";
318 } else {
319 // fallback: mark as a subtree attempt that failed to produce a new individual
320 Program<T> copy(mom.program);
321 ind.program = copy;
322 ind.variation = "clone";
323 }
324 }
325 }
326
327 // for debugging
328 // cout << "tried " << choice << ", got " << ind.variation << endl;
329 // cout << "mom : " << mom.program.get_model() << endl;
330 // cout << "child: " << ind.program.get_model() << endl;
331
332 // ind.set_objectives(mom.get_objectives()); // it will have an invalid fitness
333
334 ind.set_id(id);
// Seed the child with the parent's fitness so "prev" values exist for the
// delta computation below.
335 ind.fitness.set_loss(mom.fitness.get_loss());
337 ind.fitness.set_size(mom.fitness.get_size());
// NOTE(review): original lines 336 and 338-341 appear elided from this
// listing (presumably the remaining parent-fitness setters, e.g. complexity
// and depth). Confirm against the full source.
341
342
343 assert(ind.program.size() > 0);
344 assert(ind.fitness.valid() == false);
345
346 ind.program.fit(data.get_training_data());
347
348 // simplify before calculating fitness (order matters, as they are not refitted and constants simplifier does not replace with the right value.)
349 // simplify constants first to avoid letting the lsh simplifier to visit redundant branches
350
351 if (parameters.constants_simplification && do_simplification)
352 {
353 constants_simplifier.simplify_tree<T>(ind.program, search_space, data.get_training_data());
354 }
355
356 if (parameters.inexact_simplification)
357 {
// Cap the simplification sample at the available training rows.
358 auto inputDim = std::min(inexact_simplifier.inputDim, data.get_training_data().get_n_samples());
359
360 vector<size_t> idx(inputDim);
361 std::iota(idx.begin(), idx.end(), 0);
362 Dataset data_simp = data(idx);
363
364 if (do_simplification)
365 {
366 // string prg_str = ind.program.get_model();
367
368 inexact_simplifier.simplify_tree<T>(ind.program, search_space, data_simp);
369
370 // if (ind.program.get_model().compare(prg_str)!= 0)
371 // cout << prg_str << endl << ind.program.get_model() << endl << "=====" << endl;
372 }
373 else
374 {
375 inexact_simplifier.analyze_tree<T>(ind.program, search_space, data_simp);
376 }
377 }
378
379 evaluator.assign_fit(ind, data, parameters, false);
380
381 // vector<float> deltas(ind.get_objectives().size(), 0.0f);
382 vector<float> deltas;
383
384 float delta = 0.0f;
385 float weight = 0.0f;
386
387 for (const auto& obj : ind.get_objectives())
388 {
389 // some objectives are unsigned int, which can have weird values if we
390 // do subtractions. Instead, for these cases, we calculate a placeholder
391 // value indicating only if it was greater or not, so we can deal with
392 // this issue.
393
394 if (obj.compare(parameters.scorer) == 0) {
// NOTE(review): empty branch -- `delta` keeps the value from the previous
// iteration here; the scorer's delta assignment (original line 395) appears
// elided from this listing. Confirm against the full source.
396 }
397 else if (obj.compare("complexity") == 0) {
398 delta = ind.fitness.get_complexity() > ind.fitness.get_prev_complexity() ? 1.0 : -1.0 ;
399 }
400 else if (obj.compare("linear_complexity") == 0) {
// NOTE(review): empty branch -- the linear_complexity delta (original line
// 401) appears elided from this listing. Confirm against the full source.
402 }
403 else if (obj.compare("size") == 0) {
404 delta = ind.fitness.get_size() > ind.fitness.get_prev_size() ? 1.0 : -1.0;
405 }
406 else if (obj.compare("depth") == 0) {
407 delta = ind.fitness.get_depth() > ind.fitness.get_prev_depth() ? 1.0 : -1.0;
408 }
409 else {
410 HANDLE_ERROR_THROW(obj + " is not a known objective");
411 }
412
413 auto it = Individual<T>::weightsMap.find(obj);
414 if (it == Individual<T>::weightsMap.end()) {
415 HANDLE_ERROR_THROW("Weight not found for objective: " + obj);
416 }
417
418 weight = it->second;
419 float weighted_delta = delta * weight;
420 deltas.push_back(weighted_delta);
421 }
422
// Reward is 1.0 iff no weighted delta is negative AND at least one is
// positive; otherwise 0.0.
423 bool allPositive = true;
424 bool allNegative = true;
425 for (float d : deltas) {
426 if (d < 0)
427 allPositive = false;
428 if (d > 0)
429 allNegative = false;
430 }
431
432 float r = 0.0;
433 if (allPositive && !allNegative)
434 r = 1.0;
435
// NOTE(review): this condition is always false -- ind.get_variation() cannot
// equal "born", "cx", and "subtree" at the same time (!s.compare(x) is true
// only when s == x). As written, the else branch below always runs, so every
// variation is rewarded 0 via `choice` (which is an empty string on the
// clone path). The intent was probably `.compare(...) != 0` for each term,
// i.e. "variation is none of these". Confirm and fix in the real source.
436 if (!ind.get_variation().compare("born")
437 && !ind.get_variation().compare("cx")
438 && !ind.get_variation().compare("subtree") // TODO: handle subtree
439 )
440 {
441 this->variation_bandit.update(ind.get_variation(), r);
442
443 if (ind.get_sampled_nodes().size() > 0) {
444 const auto& changed_nodes = ind.get_sampled_nodes();
445 for (auto& node : changed_nodes) {
// Zero-argument nodes are terminals; credit the terminal bandit, otherwise
// credit the operator bandit for the node's (return, args) signature.
446 if (node.get_arg_count() == 0) {
447 auto datatype = node.get_ret_type();
448
449 this->terminal_bandits[datatype].update(node.get_feature(), r);
450 }
451 else {
452 auto ret_type = node.get_ret_type();
453 auto args_type = node.args_type();
454 auto name = node.name;
455
456 this->op_bandits[ret_type][args_type].update(name, r);
457 }
458 }
459 }
460 }
461 else
462 { // giving zero reward if the variation failed
463 this->variation_bandit.update(choice, 0.0);
464 }
465
466 // aux_individuals.push_back(std::make_shared<Individual<T>>(ind));
467 pop.individuals.at(indices.at(i)) = std::make_shared<Individual<T>>(ind);
468
469 }
470
471 // updating the population with the new individual
472 // int aux_index = 0;
473 // for (unsigned i = 0; i < indices.size(); ++i)
474 // {
475 // if (pop.individuals.at(indices.at(i)) != nullptr)
476 // {
477 // // the nullptrs should be at the end of the vector
478 // pop.individuals.at(indices.at(i)) = aux_individuals.at(aux_index);
479 // aux_index++;
480 // }
481 // }
482 }
483
484 // these functions below will extract context and use it to choose the nodes to replace
485 // bandit_sample_terminal
486 std::optional<Node> bandit_sample_terminal(DataType R)
487 {
488 if (terminal_bandits.find(R) == terminal_bandits.end()) {
489
490 return std::nullopt;
491 }
492
493 auto& bandit = terminal_bandits.at(R);
494 string terminal_name = bandit.choose();
495
496 auto it = std::find_if(
497 search_space.terminal_map.at(R).begin(),
498 search_space.terminal_map.at(R).end(),
499 [&](auto& node) { return node.get_feature() == terminal_name; });
500
501 if (it != search_space.terminal_map.at(R).end()) {
502 auto index = std::distance(search_space.terminal_map.at(R).begin(), it);
503
504 return search_space.terminal_map.at(R).at(index);
505 }
506
507 return std::nullopt;
508 };
509
510 // bandit_get_node_like
511 std::optional<Node> bandit_get_node_like(Node node)
512 {
513 // TODO: use search_space.terminal_types here (and in search_space get_node_like as well)
515
516 return bandit_sample_terminal(node.ret_type);
517 }
518
519 if (op_bandits.find(node.ret_type) == op_bandits.end()) {
520
521 return std::nullopt;
522 }
523 if (op_bandits.at(node.ret_type).find(node.args_type()) == op_bandits.at(node.ret_type).end()) {
524
525 return std::nullopt;
526 }
527
528 auto& bandit = op_bandits[node.ret_type][node.args_type()];
529 string node_name = bandit.choose();
530
531 auto entries = search_space.node_map[node.ret_type][node.args_type()];
532
533 for (const auto& [node_type, node_value]: entries)
534 {
535 if (node_value.name == node_name) {
536 return node_value;
537 }
538 }
539
540 return std::nullopt;
541 };
542
543 // bandit_sample_op_with_arg
544 std::optional<Node> bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
545 {
546 auto args_map = search_space.node_map.at(ret);
547 vector<size_t> matches;
548 vector<float> weights;
549
550 for (const auto& [args_type, name_map]: args_map) {
551 for (const auto& [name, node]: name_map) {
552 auto node_arg_types = node.get_arg_types();
553
554 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
555
556 if (in(node_arg_types, arg)
557 && within_size_limit
558 && search_space.node_map_weights.at(ret).at(args_type).at(name) > 0.0f ) {
559 // if it can be sampled
560 matches.push_back(node.args_type());
561 }
562 }
563 }
564
565 if (matches.size()==0)
566 return std::nullopt;
567
568 // we randomly select args type. This is what determines which bandit to use
569 auto args_type = *r.select_randomly(matches.begin(),
570 matches.end() );
571 auto& bandit = op_bandits[ret][args_type];
572 string node_name = bandit.choose();
573
574 // TODO: this could be more efficient
575 auto entries = search_space.node_map[ret][args_type];
576 for (const auto& [node_type, node_value]: entries)
577 {
578 if (node_value.name == node_name) {
579 return node_value;
580 }
581 }
582
583 return std::nullopt;
584 };
585
586 // bandit_sample_op
587 std::optional<Node> bandit_sample_op(DataType ret)
588 {
589 if (search_space.node_map.find(ret) == search_space.node_map.end())
590 return std::nullopt;
591
592 // any bandit to do the job
593 auto& [args_type, bandit] = *r.select_randomly(op_bandits[ret].begin(),
594 op_bandits[ret].end() );
595
596 string node_name = bandit.choose();
597
598 auto entries = search_space.node_map[ret][args_type];
599 for (const auto& [node_type, node_value]: entries)
600 {
601 if (node_value.name == node_name) {
602 return node_value;
603 }
604 }
605
606 return std::nullopt;
607 };
608
609 inline void log_simplification_table(std::ofstream& log) {
610 inexact_simplifier.log_simplification_table(log);
611 };
612
613 // bandit_sample_subtree // TODO: should I implement this? (its going to be hard).
614 // without this one being performed directly by the bandits, we then rely on
615 // the sampled probabilities we update after every generation. Since there are lots
616 // of samplings, I think it is ok to not update them and just use the distribution they learned.
617
618 // they need to be references because we are going to modify them
619 SearchSpace search_space; // The search space for the variation operator.
620 Dataset& data; // the data used to extract context and evaluate the models
621 Parameters parameters; // The parameters for the variation operator
622private:
623 // bandits will internaly work as an interface between variation and its searchspace.
624 // they will sample from the SS (instead of letting the search space do it directly),
625 // and also propagate what they learn back to the search space at the end of the execution.
627 map<DataType, Bandit> terminal_bandits;
628 map<DataType, map<size_t, Bandit>> op_bandits;
629
630 // simplification methods
633};
634
636public:
637 using Iter = tree<Node>::pre_order_iterator;
638
639 template<Brush::ProgramType T>
640 static auto find_spots(Program<T>& program, Variation<T>& variator,
641 const Parameters& params)
642 {
643 vector<float> weights(program.Tree.size());
644
645 // by default, mutation can happen anywhere, based on node weights
646 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
647 [&](const auto& n){ return n.get_prob_change();});
648
649 // Must have same size as tree, even if all weights <= 0.0
650 return weights;
651 }
652
653 template<Brush::ProgramType T>
654 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
655 const Parameters& params);
656};
657
658extern template class Variation<PT::Regressor>;
659extern template class Variation<PT::BinaryClassifier>;
660extern template class Variation<PT::MulticlassClassifier>;
661extern template class Variation<PT::Representer>;
662
663} //namespace Var
664} //namespace Brush
665#endif
holds variable type data.
Definition data.h:51
Class for evaluating the fitness of individuals in a population.
Definition evaluation.h:27
void assign_fit(Individual< T > &ind, const Dataset &data, const Parameters &params, bool val=false)
Assign fitness to an individual.
static std::map< std::string, float > weightsMap
set parent ids using id values
Definition individual.h:163
void set_id(unsigned i)
Definition individual.h:147
vector< Node > get_sampled_nodes() const
Definition individual.h:144
Fitness fitness
aggregate fitness score
Definition individual.h:37
string get_variation() const
Definition individual.h:138
vector< string > get_objectives() const
Definition individual.h:176
void set_parents(const vector< Individual< T > > &parents)
Definition individual.h:148
void init(SearchSpace &ss, const Parameters &params)
Definition individual.h:52
Program< T > program
executable data structure
Definition individual.h:17
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
Definition variation.h:640
tree< Node >::pre_order_iterator Iter
Definition variation.h:637
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
Class representing the variation operators in Brush.
Definition variation.h:44
std::optional< Node > bandit_sample_op(DataType ret)
Definition variation.h:587
map< DataType, map< size_t, Bandit > > op_bandits
Definition variation.h:628
void vary_and_update(Population< T > &pop, int island, const vector< size_t > &parents, const Dataset &data, Evaluation< T > &evaluator, bool do_simplification)
Varies a population and updates the selection strategy based on rewards.
Definition variation.h:223
void vary(Population< T > &pop, int island, const vector< size_t > &parents)
Handles variation of a population.
void init()
Initializes the Variation object with parameters and search space.
Definition variation.h:76
Variation()=default
Default constructor.
Inexact_simplifier inexact_simplifier
Definition variation.h:632
SearchSpace search_space
Definition variation.h:619
~Variation()
Destructor.
Definition variation.h:68
std::optional< Node > bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
Definition variation.h:544
Constants_simplifier constants_simplifier
Definition variation.h:631
Variation(Parameters &params, SearchSpace &ss, Dataset &d)
Constructor that initializes the Variation object with parameters and search space.
Definition variation.h:57
std::optional< Individual< T > > cross(const Individual< T > &mom, const Individual< T > &dad)
Performs crossover operation on two individuals.
std::optional< Node > bandit_sample_terminal(DataType R)
Definition variation.h:486
Parameters parameters
Definition variation.h:621
void log_simplification_table(std::ofstream &log)
Definition variation.h:609
std::optional< Node > bandit_get_node_like(Node node)
Definition variation.h:511
map< DataType, Bandit > terminal_bandits
Definition variation.h:627
std::optional< Individual< T > > mutate(const Individual< T > &parent, string choice="")
Performs mutation operation on an individual.
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:3
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:313
float get_loss() const
Definition fitness.h:64
unsigned int get_prev_depth() const
Definition fitness.h:87
float get_prev_loss() const
Definition fitness.h:65
unsigned int get_prev_size() const
Definition fitness.h:73
unsigned int get_prev_complexity() const
Definition fitness.h:78
bool valid() const
Definition fitness.h:155
void set_linear_complexity(unsigned int new_lc)
Definition fitness.h:80
void set_complexity(unsigned int new_c)
Definition fitness.h:75
float get_loss_v() const
Definition fitness.h:68
void set_loss_v(float f_v)
Definition fitness.h:67
void set_depth(unsigned int new_d)
Definition fitness.h:85
unsigned int get_prev_linear_complexity() const
Definition fitness.h:83
unsigned int get_complexity() const
Definition fitness.h:77
void set_size(unsigned int new_s)
Definition fitness.h:71
unsigned int get_depth() const
Definition fitness.h:86
unsigned int get_linear_complexity() const
Definition fitness.h:82
unsigned int get_size() const
Definition fitness.h:72
void set_loss(float f)
Definition fitness.h:63
The Bandit struct represents a multi-armed bandit.
Definition bandit.h:32
class holding the data for a node in a tree.
Definition node.h:89
NodeType node_type
the node type
Definition node.h:94
DataType ret_type
return data type
Definition node.h:97
std::size_t args_type() const
Definition node.h:201
An individual program, a.k.a. model.
Definition program.h:49
tree< Node > Tree
fitness
Definition program.h:72
Program< PType > & fit(const Dataset &d)
Definition program.h:150
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...