Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
variation.cpp
Go to the documentation of this file.
1#include "variation.h"
2
3namespace Brush {
4namespace Var {
5
6
7using namespace Brush;
8using namespace Pop;
9using namespace MAB;
10
namespace {

/// Every variation operator this file can name, plus `Unknown` for strings
/// that match none of them.
enum class MutationType {
    Point,
    Insert,
    Delete,
    Subtree,
    ToggleWeightOn,
    ToggleWeightOff,
    Crossover,
    Unknown
};

/// One row of the name <-> enumerator table shared by both converters.
struct MutationName {
    const char* label;
    MutationType type;
};

/// Single source of truth for the textual names. Both converters below read
/// this table, so a name can no longer be changed in one direction only.
constexpr MutationName kMutationNames[] = {
    {"point",             MutationType::Point},
    {"insert",            MutationType::Insert},
    {"delete",            MutationType::Delete},
    {"subtree",           MutationType::Subtree},
    {"toggle_weight_on",  MutationType::ToggleWeightOn},
    {"toggle_weight_off", MutationType::ToggleWeightOff},
    {"cx",                MutationType::Crossover},
    {"unknown",           MutationType::Unknown},
};

/// Translate a textual operator name into its enumerator.
///
/// @param choice operator name; the comparison is exact and case-sensitive.
/// @return the matching enumerator, or MutationType::Unknown when no table
///         entry has that name.
MutationType mutation_type_from_string(const std::string& choice) {
    for (const auto& entry : kMutationNames) {
        if (choice == entry.label)
            return entry.type;
    }
    return MutationType::Unknown;
}

/// Translate an enumerator back into its textual name.
///
/// @param choice enumerator to name.
/// @return pointer to a string literal; "unknown" for MutationType::Unknown
///         and for any value that is not in the table.
const char* mutation_type_to_string(MutationType choice) {
    for (const auto& entry : kMutationNames) {
        if (entry.type == choice)
            return entry.label;
    }
    return "unknown";
}
} // namespace
64
72{
73public:
74 template<Brush::ProgramType T>
75 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
76 const Parameters& params)
77 {
78 // get_node_like will sample a similar node based on node_map_weights or
79 // terminal_weights, and maybe will return a Node.
80
81 optional<Node> newNode = variator.bandit_get_node_like(spot.node->data);
82
83 if (!newNode) // overload to check if newNode == nullopt
84 return false;
85
86 // keeping the weight if it is fixed
87 if (IsWeighable(spot.node->data.node_type) && spot.node->data.weight_is_fixed){
88 // we do not need to check if the new node is also weightable, because
89 // the return type will be equivalent.
90
91 (*newNode).W = spot.node->data.W;
92 (*newNode).set_is_weighted(true);
93 (*newNode).weight_is_fixed=true;
94 }
95
96 // if optional contains a Node, we access its contained value as an address
97 program.Tree.replace(spot, *newNode);
98
99 return true;
100 }
101};
102
110{
111public:
112 template<Brush::ProgramType T>
113 static auto find_spots(Program<T>& program, Variation<T>& variator,
114 const Parameters& params)
115 {
116 vector<float> weights;
117
118 if (program.Tree.size() < params.get_max_size()) {
119 Iter iter = program.Tree.begin();
120 std::transform(program.Tree.begin(), program.Tree.end(), std::back_inserter(weights),
121 [&](const auto& n){
122 size_t d = 1+program.Tree.depth(iter);
123 std::advance(iter, 1);
124
125 // check if SS holds an operator to avoid failing `check` in sample_op_with_arg
126 if ((d >= params.get_max_depth())
127 || (variator.search_space.node_map.find(n.ret_type) == variator.search_space.node_map.end())) {
128 return 0.0f;
129 }
130 else {
131 return n.get_prob_change();
132 }
133 });
134 }
135 else {
136 // fill the vector with zeros, since we're already at max_size
137 weights.resize(program.Tree.size());
138 std::fill(weights.begin(), weights.end(), 0.0f);
139 }
140
141 return weights;
142 }
143
144 template<Brush::ProgramType T>
145 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
146 const Parameters& params)
147 {
148 auto spot_type = spot.node->data.ret_type;
149
150 // pick a random compatible node to insert (with probabilities given by
151 // node_map_weights). The `-1` represents the node being inserted.
152 // Ideally, it should always find at least one match (the same node
153 // used as a reference when calling the function). However, we have a
154 // size restriction, which will be relaxed here (just as it is in the PTC2
155 // algorithm). This mutation can create a new expression that exceeds the
156 // maximum size by the highest arity among the operators.
157
158 std::optional<Node> n = variator.bandit_sample_op_with_arg(
159 spot_type, spot_type, params.max_size-program.Tree.size()-1);
160
161 if (!n) // there is no operator with compatible arguments
162 return false;
163
164 // moving the fixed weight to the new inserted node
165 // (this should be done before Tree.wrap, because that function affects the spot reference)
166 if (IsWeighable(spot.node->data.node_type) && spot.node->data.weight_is_fixed){
167 Node& prev_n = spot.node->data;
168
169 // moving the fixed weight to the inserted node (n is the new node).
170 // because n is optional<node>, we need to solve the reference to access the node itself
171 // (it is wrapped in the optional<>)
172 (*n).W = spot.node->data.W;
173 (*n).set_is_weighted(true);
174 (*n).weight_is_fixed=true;
175
176 // toggling off the weight of the previous node
177 prev_n.set_is_weighted(false);
178 prev_n.weight_is_fixed=false;
179 }
180
181 // make node `n` wrap the subtree at the chosen spot
182 auto parent_node = program.Tree.wrap(spot, *n);
183
184 // now fill the arguments of n appropriately
185 bool spot_filled = false;
186 for (auto a: (*n).arg_types)
187 {
188 if (spot_filled)
189 {
190 // if spot is in its child position, append children.
191 auto opt = variator.bandit_sample_terminal(a);
192
193 if (!opt)
194 return false;
195
196 program.Tree.append_child(parent_node, opt.value());
197 }
198 // if types match, treat this spot as filled by the spot node
199 else if (a == spot_type)
200 spot_filled = true;
201 // otherwise, add siblings before spot node
202 else {
203 auto opt = variator.bandit_sample_terminal(a);
204
205 if (!opt)
206 return false;
207
208 program.Tree.insert(spot, opt.value());
209 }
210 }
211
212 return true;
213 }
214};
215
223{
224public:
225 template<Brush::ProgramType T>
226 static auto find_spots(Program<T>& program, Variation<T>& variator,
227 const Parameters& params)
228 {
229 vector<float> weights(program.Tree.size());
230
231 // by default, mutation can happen anywhere, based on node weights
232 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
233 [&](const auto& n){
234 // keeping the node if the weight is fixed
235 if (n.weight_is_fixed){
236 // we cant delete a node if its weight is fixed.
237 // we will let other mutations do their job and avoid deletion.
238 return 0.0f;
239 }
240
241 return n.get_prob_change(); // this already checks for node_is_fixed
242 });
243
244 // Must have same size as tree, even if all weights <= 0.0
245 return weights;
246 }
247
248 template<Brush::ProgramType T>
249 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
250 const Parameters& params)
251 {
252 // sample_terminal will sample based on terminal_weights. If it succeeds,
253 // then the new terminal will be in `opt.value()`
254
255 auto opt = variator.bandit_sample_terminal(spot.node->data.ret_type);
256
257 if (!opt) // there is no terminal with compatible arguments
258 return false;
259
260 program.Tree.erase_children(spot);
261
262 program.Tree.replace(spot, opt.value());
263
264 return true;
265 }
266};
267
275{
276public:
277 template<Brush::ProgramType T>
278 static auto find_spots(Program<T>& program, Variation<T>& variator,
279 const Parameters& params)
280 {
281 vector<float> weights(program.Tree.size());
282
283 if (program.Tree.size() < params.max_size) {
284 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
285 [&](const auto& n){
286 // some nodetypes must always have a weight
287 if (Is<NodeType::OffsetSum>(n.node_type) || Is<NodeType::Constant>(n.node_type))
288 return 0.0f;
289
290 // only weighted nodes can be toggled off
291 if ((!n.get_is_weighted())
292 && (!n.weight_is_fixed)
293 && IsWeighable(n.node_type))
294 {
295 return n.get_prob_change();
296 }
297 else
298 return 0.0f;
299 });
300 }
301 else {
302 // fill the vector with zeros, since we're already at max_size
303 std::fill(weights.begin(), weights.end(), 0.0f);
304 }
305
306 return weights;
307 }
308
309 template<Brush::ProgramType T>
310 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
311 const Parameters& params)
312 {
313 if (spot.node->data.get_is_weighted()==true // cant turn on whats already on
314 || !IsWeighable(spot.node->data.node_type)) // does not accept weights (e.g. boolean)
315 return false; // false indicates that mutation failed and should return std::nullopt
316
317 spot.node->data.set_is_weighted(true);
318 return true;
319 }
320};
321
329{
330public:
331 template<Brush::ProgramType T>
332 static auto find_spots(Program<T>& program, Variation<T>& variator,
333 const Parameters& params)
334 {
335 vector<float> weights(program.Tree.size());
336
337 std::transform(program.Tree.begin(), program.Tree.end(), weights.begin(),
338 [&](const auto& n){
339 // some nodetypes must always have a weight
340 if (Is<NodeType::OffsetSum>(n.node_type) || Is<NodeType::Constant>(n.node_type))
341 return 0.0f;
342
343 if (n.get_is_weighted()
344 && (!n.weight_is_fixed)
345 && IsWeighable(n.node_type))
346 return n.get_prob_change();
347 else
348 return 0.0f;
349 });
350
351 return weights;
352 }
353
354 template<Brush::ProgramType T>
355 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
356 const Parameters& params)
357 {
358 if (spot.node->data.get_is_weighted()==false) // TODO: This condition should never happen. Verified by find_spots; keep guard for safety. (this is also true for toggleweighton, also fix that)
359 return false;
360
361 spot.node->data.set_is_weighted(false);
362 return true;
363 }
364};
365
373{
374public:
375 template<Brush::ProgramType T>
376 static auto find_spots(Program<T>& program, Variation<T>& variator,
377 const Parameters& params)
378 {
379 vector<float> weights;
380
381 auto node_map = variator.search_space.node_map;
382
383 // The minimal size increment would be 2 - replacing a constant with a weighted terminal.
384 // we dont check for size constraints because the replacement can shrink the tree.
385 Iter iter = program.Tree.begin();
386 std::transform(program.Tree.begin(), program.Tree.end(), std::back_inserter(weights),
387 [&](const auto& n){
388 size_t d = program.Tree.depth(iter);
389 size_t s = program.Tree.size(iter);
390 std::advance(iter, 1);
391
392 // we need to make sure there's some node to start the subtree
393 if ((d >= params.max_depth)
394 || (node_map.find(n.ret_type) == node_map.end()) )
395 return 0.0f;
396 else
397 return n.get_prob_change();
398 });
399
400 return weights;
401 }
402
403 template<Brush::ProgramType T>
404 static auto mutate(Program<T>& program, Iter spot, Variation<T>& variator,
405 const Parameters& params)
406 {
407 // check if we exceeded the size/depth constrains (without subtracting,
408 // to avoid overflow cases if the user sets max_size smaller than arity
409 // of smallest operator. The overflow would happen when calculating d and
410 // s in the following lines, to choose the PTC2 limits)
411 if ( params.max_size <= (program.Tree.size() - program.Tree.size(spot))
412 || params.max_depth <= program.Tree.depth(spot) )
413 return false;
414
415 auto spot_type = spot.node->data.ret_type;
416
417 // d and s must be compatible with PTC2 --- they should be based on
418 // tree structure, not program structure
419 size_t d = params.max_depth - program.Tree.depth(spot);
420 size_t s;
421
422 // since `s` is size_t, we need to ensure the operation below will not overflow
423 if (program.Tree.size() < params.max_size)
424 s = params.max_size - (program.Tree.size() - program.Tree.size(spot));
425 else
426 s = 1;
427
428 s = r.rnd_int(1, s+1);
429
430 // sample subtree uses PTC2, which operates on depth and size of the tree<Node>
431 // (and not on the program!). we shoudn't care for weights here
432
433 auto subtree = variator.search_space.sample_subtree(spot.node->data, d, s);
434
435 if (!subtree) // there is no terminal with compatible arguments
436 return false;
437
438 // keeping the weight if it is fixed.
439 // I need to manipulate spot before Tree.erase_children!
440 if (IsWeighable(spot.node->data.node_type) && spot.node->data.weight_is_fixed){
441 Node& n = subtree.value().begin().node->data;
442
443 // moving the weight and fixing it
444 n.W = spot.node->data.W;
445 n.set_is_weighted(true);
446 n.weight_is_fixed=true;
447 }
448
449 // if optional contains a Node, we access its contained value
450 program.Tree.erase_children(spot);
451
452 program.Tree.move_ontop(spot, subtree.value().begin());
453
454 return true;
455 }
456};
457
481template<Brush::ProgramType T>
482std::optional<Individual<T>> Variation<T>::cross(
483 const Individual<T>& mom, const Individual<T>& dad)
484{
485 /* subtree crossover between this and other, producing new Program */
486 // choose location by weighted sampling of program
487 // TODO: why doesn't this copy the search space reference to child?
488 Program<T> child(mom.program);
489
490 // pick a subtree to replace
491 vector<float> child_weights(child.Tree.size());
492
493 auto child_iter = child.Tree.begin();
494 std::transform(child.Tree.begin(), child.Tree.end(), child_weights.begin(),
495 [&](const auto& n){
496 auto s_at = child.size_at(child_iter);
497 auto d_at = child.depth_to_reach(child_iter);
498
499 std::advance(child_iter, 1);
500
501 // We don't have to check size here, because it will be replaced
502 // by something with a valid new size.
503 if (
504 // s_at<parameters.max_size &&
505 d_at<parameters.max_depth
506 )
507 return n.get_prob_change();
508 else
509 return 0.0f;
510 }
511 );
512
513 if (std::all_of(child_weights.begin(), child_weights.end(), [](const auto& w) {
514 return w<=0.0;
515 }))
516 { // There is no spot that has a probability to be selected
517 return std::nullopt;
518 }
519
520 // pick a subtree to insert. Selection is based on other_weights
521 Program<T> other(dad.program);
522
523 int attempts = 0;
524 while (++attempts <= 3)
525 {
526 auto child_spot = r.select_randomly(child.Tree.begin(),
527 child.Tree.end(),
528 child_weights.begin(),
529 child_weights.end()
530 );
531
532 auto child_ret_type = child_spot.node->data.ret_type;
533
534 auto allowed_size = parameters.max_size -
535 ( child.size() - child.size_at(child_spot) );
536 auto allowed_depth = parameters.max_depth -
537 ( child.depth_to_reach(child_spot) );
538
539 vector<float> other_weights(other.Tree.size());
540
541 // Iterator to traverse the tree during transformation
542 auto other_iter = other.Tree.begin();
543 std::transform(other.Tree.begin(), other.Tree.end(), other_weights.begin(),
544 [&other, &other_iter, allowed_size, allowed_depth, child_ret_type](const auto& n) mutable {
545 int s = other.size_at(other_iter);
546 int d = other.depth_at(other_iter);
547
548 std::advance(other_iter, 1);
549
550 // Check feasibility and matching return type
551 if ( (s <= allowed_size)
552 && (d <= allowed_depth)
553 && (n.ret_type == child_ret_type) // this condition helps making sure the crossover will succeed, and also that we can keep fixed weights
554 ) {
555 return n.get_prob_change();
556 }
557
558 return 0.0f; // Non-feasible crossover point
559 }
560 );
561
562 bool matching_spots_found = std::any_of(other_weights.begin(), other_weights.end(),
563 [](float w) { return w > 0.0f; });
564
565 if (matching_spots_found) {
566 auto other_spot = r.select_randomly(
567 other.Tree.begin(), other.Tree.end(),
568 other_weights.begin(), other_weights.end()
569 );
570
571 // manipulate before move_ontop (it will mess references)
572 if (IsWeighable(child_spot.node->data.node_type) && child_spot.node->data.weight_is_fixed){
573 Node& n = other_spot.node->data;
574
575 // moving the weight and fixing it
576 n.W = child_spot.node->data.W;
577 n.set_is_weighted(true);
578 n.weight_is_fixed=true;
579 }
580
581 // fmt::print("other_spot : {}\n",other_spot.node->data);
582 // swap subtrees at child_spot and other_spot
583 child.Tree.move_ontop(child_spot, other_spot);
584
585 Individual<T> ind(child);
586 ind.set_variation(mutation_type_to_string(MutationType::Crossover));
587
588 return ind;
589 }
590 }
591
592 return std::nullopt;
593};
594
635template<Brush::ProgramType T>
636std::optional<Individual<T>> Variation<T>::mutate(
637 const Individual<T>& parent, string choice)
638{
639 if (choice.empty())
640 {
641 auto options = parameters.mutation_probs;
642
643 bool all_zero = true;
644 for (auto &it : parameters.mutation_probs) {
645 if (it.second > 0.0) {
646 all_zero = false;
647 break;
648 }
649 }
650
651 if (all_zero) { // No mutation can be successfully applied to this solution
652 return std::nullopt;
653 }
654
655 // picking a valid mutation option
656 choice = r.random_choice(parameters.mutation_probs);
657 }
658
659 const auto mutation_choice = mutation_type_from_string(choice);
660 if (mutation_choice == MutationType::Unknown) {
661 std::string msg = fmt::format("{} not a valid mutation choice", choice);
663 }
664
665 Program<T> copy(parent.program);
666
667 vector<float> weights; // choose location by weighted sampling of program
668 switch (mutation_choice) {
669 case MutationType::Point:
670 weights = PointMutation::find_spots(copy, (*this), parameters);
671 break;
672 case MutationType::Insert:
673 weights = InsertMutation::find_spots(copy, (*this), parameters);
674 break;
675 case MutationType::Delete:
676 weights = DeleteMutation::find_spots(copy, (*this), parameters);
677 break;
678 case MutationType::Subtree:
679 weights = SubtreeMutation::find_spots(copy, (*this), parameters);
680 break;
681 case MutationType::ToggleWeightOn:
682 weights = ToggleWeightOnMutation::find_spots(copy, (*this), parameters);
683 break;
684 case MutationType::ToggleWeightOff:
685 weights = ToggleWeightOffMutation::find_spots(copy, (*this), parameters);
686 break;
687 case MutationType::Crossover:
688 case MutationType::Unknown:
689 HANDLE_ERROR_THROW("Crossover is not a valid mutation choice\n");
690 break;
691 }
692
693 if (std::all_of(weights.begin(), weights.end(), [](const auto& w) {
694 return w<=0.0;
695 }))
696 { // There is no spot that has a probability to be selected
697 return std::nullopt;
698 }
699
700 int attempts = 0;
701 while(attempts++ < 3)
702 {
703 Program<T> child(parent.program);
704
705 // apply the mutation and check if it succeeded
706 auto spot = r.select_randomly(child.Tree.begin(), child.Tree.end(),
707 weights.begin(), weights.end());
708
709 // Every mutation here works inplace, so they return bool instead of
710 // std::optional to indicare the result of their manipulation over the
711 // program tree. Here we call the mutation function and return the result
712
713 bool success;
714 switch (mutation_choice) {
715 case MutationType::Point:
716 success = PointMutation::mutate(child, spot, (*this), parameters);
717 break;
718 case MutationType::Insert:
719 success = InsertMutation::mutate(child, spot, (*this), parameters);
720 break;
721 case MutationType::Delete:
722 success = DeleteMutation::mutate(child, spot, (*this), parameters);
723 break;
724 case MutationType::Subtree:
725 success = SubtreeMutation::mutate(child, spot, (*this), parameters);
726 break;
727 case MutationType::ToggleWeightOn:
728 success = ToggleWeightOnMutation::mutate(child, spot, (*this), parameters);
729 break;
730 case MutationType::ToggleWeightOff:
731 success = ToggleWeightOffMutation::mutate(child, spot, (*this), parameters);
732 break;
733 case MutationType::Crossover:
734 case MutationType::Unknown:
735 success = false;
736 break;
737 }
738
739 if (// strict mutation --- returns only valid solutions.
740 ( success
741 && (child.size() <= parameters.max_size)
742 && (child.depth() <= parameters.max_depth) )
743
744 // TODO: delete 2 commented lines below
745 // loose mutation --- it will try its best, but may return something slightly larger.
746 // || attempts==3 // this is the final attempt, return whatever we got.
747 ){
748 Individual<T> ind(child);
749
750 ind.set_variation(choice);
751
752 // subtree performs several samplings, and it will leverate
753 // what point/insert/delete mutations learned about each node utility.
754
755 // TODO: handle subtree - it will sample too many nodes and it may
756 // be hard to track which ones actually improved the expression to
757 // update the bandits/ maybe we should skip it?
758 // mutations that sampled from search space
759 if (choice.compare("point") == 0
760 || choice.compare("insert") == 0
761 || choice.compare("delete") == 0
762 ) {
763 ind.set_sampled_nodes({spot.node->data});
764 }
765
766 return ind;
767 }
768 else { // reseting
769 }
770 }
771
772 return std::nullopt;
773};
774
775template<Brush::ProgramType T>
776void Variation<T>::vary(Population<T>& pop, int island,
777 const vector<size_t>& parents)
778{
779 auto indices = pop.get_island_indexes(island);
780
781 for (unsigned i = 0; i<indices.size(); ++i)
782 {
783 if (pop.individuals.at(indices.at(i)) != nullptr)
784 {
785 continue; // skipping if it is an individual
786 }
787
788 // pass check for children undergoing variation
789 std::optional<Individual<T>> opt=std::nullopt; // new individual
790
791 const Individual<T>& mom = pop[
792 *r.select_randomly(parents.begin(), parents.end())];
793
794 vector<Individual<T>> ind_parents;
795
796 bool crossover = ( r() < parameters.cx_prob );
797 if (crossover)
798 {
799 const Individual<T>& dad = pop[
800 *r.select_randomly(parents.begin(), parents.end())];
801
802 auto variation_result = cross(mom, dad);
803 ind_parents = {mom, dad};
804 opt = variation_result;
805 }
806 else
807 {
808 auto variation_result = mutate(mom);
809
810 ind_parents = {mom};
811 opt = variation_result;
812 }
813
814 // this assumes that islands do not share indexes before doing variation
815 unsigned id = parameters.current_gen*parameters.pop_size+indices.at(i);
816
817 // mutation and crossover already perform 3 attempts. If it fails, we just fill with a random individual
818
819 Individual<T> ind;
820 if (opt) // variation worked, lets keep this
821 {
822 ind = opt.value();
823 ind.set_parents(ind_parents);
824 }
825 else { // no optional value was returned. creating a new random individual
826 // It seems that the line below will not fix the root in clf programs
827 ind.init(search_space, parameters); // ind.variation is born by default
828
829 // Program<T> p = search_space.make_program<Program<T>>(parameters, 0, 0);
830 // ind = Individual<T>(p);
831 }
832
833 ind.set_objectives(mom.get_objectives()); // it will have an invalid fitness
834
835 ind.is_fitted_ = false;
836 ind.set_id(id);
837
838 // TODO: smarter way of copying the entire fitness
839 // copying mom fitness to the new individual (without making the fitnes valid)
840 // so the bandits can access this information. Fitness will be valid
841 // only when we do set_values(). We are setting these parameters below
842 // because we want to keep the previous values for the bandits, and
843 // we are not updating the fitness values here.
844 ind.fitness.set_loss(mom.fitness.get_loss());
846 ind.fitness.set_size(mom.fitness.get_size());
850
851 // dont set stuff that is not used to calculate the rewards, like crowding_dist
852 // ind.fitness.set_crowding_dist(0.0);
853
854 assert(ind.program.size()>0);
855 assert(ind.fitness.valid()==false);
856
857 pop.individuals.at(indices.at(i)) = std::make_shared<Individual<T>>(ind);
858 }
859};
860
861template <Brush::ProgramType T>
863{
864 // propagate bandits learnt information to the search space.
865
866 // variation: getting new probabilities for variation operators
867 auto variation_probs = variation_bandit.sample_probs(true);
868
869 if (variation_probs.find("cx") != variation_probs.end())
870 parameters.set_cx_prob(variation_probs.at("cx"));
871
872 for (const auto& variation : variation_probs)
873 if (variation.first != "cx")
874 parameters.mutation_probs[variation.first] = variation.second;
875
876 // terminal: getting new probabilities for terminal nodes in search space
877 for (auto& bandit : terminal_bandits) {
878 auto datatype = bandit.first;
879
880 auto terminal_probs = bandit.second.sample_probs(true);
881
882 for (auto& [terminal_name, terminal_prob] : terminal_probs) {
883 // Search for the index that matches the terminal name
884 auto it = std::find_if(
885 search_space.terminal_map.at(datatype).begin(),
886 search_space.terminal_map.at(datatype).end(),
887 [&](auto& node) { return node.get_feature() == terminal_name; });
888
889 if (it == search_space.terminal_map.at(datatype).end()) {
890 continue;
891 }
892
893 auto index = std::distance(search_space.terminal_map.at(datatype).begin(), it);
894
895 // Update the terminal weights with the second value
896 search_space.terminal_weights.at(datatype)[index] = terminal_prob;
897 }
898 }
899
900 // operators: getting new probabilities for op nodes
901 for (auto& [ret_type, bandit_map] : op_bandits) {
902 for (auto& [args_type, bandit] : bandit_map) {
903 auto op_probs = bandit.sample_probs(true);
904
905 for (auto& [op_name, op_prob] : op_probs) {
906 bool updated = false;
907 for (const auto& [node_type, node_value]: search_space.node_map.at(ret_type).at(args_type))
908 {
909 if (node_value.name == op_name) {
910
911 search_space.node_map_weights.at(ret_type).at(args_type).at(node_type) = op_prob;
912 updated = true;
913 break;
914 }
915 }
916 if (!updated) {
917 continue;
918 }
919 }
920 }
921 }
922};
923
924} //namespace Var
925} //namespace Brush
926
void set_variation(string v)
Definition individual.h:137
void set_sampled_nodes(const vector< Node > &nodes)
Definition individual.h:143
void set_id(unsigned i)
Definition individual.h:147
Fitness fitness
aggregate fitness score
Definition individual.h:37
vector< string > get_objectives() const
Definition individual.h:176
void set_objectives(vector< string > objs)
Definition individual.h:177
void set_parents(const vector< Individual< T > > &parents)
Definition individual.h:148
void init(SearchSpace &ss, const Parameters &params)
Definition individual.h:52
Program< T > program
executable data structure
Definition individual.h:17
vector< size_t > get_island_indexes(int island)
Definition population.h:39
vector< std::shared_ptr< Individual< T > > > individuals
Definition population.h:19
delete subtree and replace it with a terminal of the same return type
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
insert a node with spot as a child
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
Definition variation.h:640
tree< Node >::pre_order_iterator Iter
Definition variation.h:637
replace node with same typed node
Definition variation.cpp:72
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
Definition variation.cpp:75
replaces the subtree rooted in spot
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
toggle the node's weight OFF
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
toggle the node's weight ON
static auto mutate(Program< T > &program, Iter spot, Variation< T > &variator, const Parameters &params)
static auto find_spots(Program< T > &program, Variation< T > &variator, const Parameters &params)
Class representing the variation operators in Brush.
Definition variation.h:44
map< DataType, map< size_t, Bandit > > op_bandits
Definition variation.h:628
void vary(Population< T > &pop, int island, const vector< size_t > &parents)
Handles variation of a population.
SearchSpace search_space
Definition variation.h:619
std::optional< Node > bandit_sample_op_with_arg(DataType ret, DataType arg, int max_args=0)
Definition variation.h:544
std::optional< Individual< T > > cross(const Individual< T > &mom, const Individual< T > &dad)
Performs crossover operation on two individuals.
std::optional< Node > bandit_sample_terminal(DataType R)
Definition variation.h:486
Parameters parameters
Definition variation.h:621
std::optional< Node > bandit_get_node_like(Node node)
Definition variation.h:511
map< DataType, Bandit > terminal_bandits
Definition variation.h:627
std::optional< Individual< T > > mutate(const Individual< T > &parent, string choice="")
Performs mutation operation on an individual.
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:3
auto IsWeighable() noexcept -> bool
Definition node.h:51
float get_loss() const
Definition fitness.h:64
bool valid() const
Definition fitness.h:155
void set_linear_complexity(unsigned int new_lc)
Definition fitness.h:80
void set_complexity(unsigned int new_c)
Definition fitness.h:75
float get_loss_v() const
Definition fitness.h:68
void set_loss_v(float f_v)
Definition fitness.h:67
void set_depth(unsigned int new_d)
Definition fitness.h:85
unsigned int get_complexity() const
Definition fitness.h:77
void set_size(unsigned int new_s)
Definition fitness.h:71
unsigned int get_depth() const
Definition fitness.h:86
unsigned int get_linear_complexity() const
Definition fitness.h:82
unsigned int get_size() const
Definition fitness.h:72
void set_loss(float f)
Definition fitness.h:63
class holding the data for a node in a tree.
Definition node.h:89
bool weight_is_fixed
whether the weight should be kept during variation. Notice that weight_is_fixed allows us to fix the w...
Definition node.h:114
void set_is_weighted(bool is_weighted)
Definition node.h:296
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:122
unsigned get_max_size() const
Definition params.h:145
unsigned int max_depth
Definition params.h:37
unsigned int max_size
Definition params.h:38
An individual program, a.k.a. model.
Definition program.h:49
tree< Node > Tree
fitness
Definition program.h:72
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:128
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:146
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Map< Node > node_map
Maps return types to argument types to node types.
std::optional< tree< Node > > sample_subtree(Node root, int max_d, int max_size) const
create a subtree with maximum size and depth restrictions and root of type root_type