Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
search_space.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5#ifndef SEARCHSPACE_H
6#define SEARCHSPACE_H
7//internal includes
8#include "../init.h"
9#include "../program/node.h"
10#include "../program/nodetype.h"
12// #include "program/program.h"
13#include "../util/error.h"
14#include "../util/utils.h"
15#include "../util/rnd.h"
16#include "../params.h"
17#include <utility>
18#include <optional>
19#include <iostream>
20
/* Defines the search space of Brush.
 * The search space consists of nodes and their accompanying probability
 * distribution.
 * Nodes can be accessed by type, signature, or a combination of both.
 * You may also sample the search space by return type.
 * Sampling is done in proportion to the weight associated with
 * each node. By default, every node has the same weight, so sampling is
 * uniformly random.
*/
29using namespace Brush::Data;
30using namespace Brush::Util;
31using Brush::Node;
32using Brush::DataType;
33using std::type_index;
34
35
36namespace Brush
37{
39// node generation routines
40/* template<typename T> */
41/* tuple<set<Node>,set<type_index>> generate_nodes(vector<string>& op_names); */
42/* tuple<set<Node>,set<type_index>> generate_split_nodes(vector<string>& op_names); */
43
44// forward declarations
// Pre-order iterator over a program tree of Nodes.
45using TreeIter = tree<Node>::pre_order_iterator;
46// template<typename T> struct Program;
47// enum class ProgramType: uint32_t;
48// template<typename T> struct ProgramTypeEnum;
49
// Generate terminals from the dataset features and random constants.
// Only declared in this header; the definition lives in a source file.
50vector<Node> generate_terminals(const Dataset& d, const bool weights_init);
51
53
// Global table mapping a signature hash (S::hash()) to a printable list of that
// signature's argument types. Written by SearchSpace::CreateNode and read by
// SearchSpace::repr.
54extern std::unordered_map<std::size_t, std::string> ArgsName;
55
84{
// Hash identifying an argument-type signature (the value of S::hash() / Node::args_type()).
85 using ArgsHash = std::size_t;
86
// Three-level lookup table: return type -> argument-signature hash -> node type -> T.
// Used as Map<Node> for node_map and Map<float> for node_map_weights.
87 template<typename T>
88 using Map = unordered_map<DataType, // return type
89 unordered_map<ArgsHash, // hash of arg types
90 unordered_map<NodeType, // node type
91 T>>>; // the data!
92
// NOTE(review): the declarations of node_map (Map<Node>) and node_map_weights
// (Map<float>) sit on source lines 100 and 103, which are absent from this extract.
101
104
105 // TODO: maybe we could flatten this terminal map
106
// Maps return types to terminals.
115 unordered_map<DataType, vector<Node>> terminal_map;
116
// Weights parallel to terminal_map: terminal_weights[t][i] weights terminal_map[t][i]
// when terminals are sampled.
118 unordered_map<DataType, vector<float>> terminal_weights;
119
// The available return types of terminals.
121 vector<DataType> terminal_types;
122
// The available operator names (used by bandits).
124 vector<string> op_names;
125
126 // serialization
127#ifndef DOXYGEN_SKIP
128
// Generates nlohmann::json to_json/from_json for the members listed in the macro.
// NOTE(review): source lines 131-135 of this member list are absent from this
// extract; only node_map is visible here.
129 NLOHMANN_DEFINE_TYPE_INTRUSIVE(SearchSpace,
130 node_map,
136 )
137
138#endif
139
/// Makes a random program of type PT. A max_d / max_size below 1 is replaced by
/// r.rnd_int(1, params.max_depth) / r.rnd_int(1, params.max_size).
/// Note that `params` comes first here but last in the make_* wrappers below.
154 template<typename PT>
155 PT make_program(const Parameters& params, int max_d=0, int max_size=0);
156
/// Makes a random regressor program. Convenience wrapper for make_program.
161 RegressorProgram make_regressor(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
162
/// Makes a random binary classifier program. Convenience wrapper for make_program.
167 ClassifierProgram make_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
168
/// Makes a random multiclass classifier program. Convenience wrapper for make_program.
173 MulticlassClassifierProgram make_multiclass_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
174
/// Makes a random representer program. Convenience wrapper for make_program.
179 RepresenterProgram make_representer(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
180
/// Empty search space; call init() before sampling from it.
181 SearchSpace() = default;
182
187 SearchSpace(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true){
188 init(d,user_ops,weights_init);
189 }
190
/// Called by the constructor to initialize the search space from dataset `d`;
/// the arguments have the same meaning as the constructor's.
195 void init(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true);
196
200 bool check(DataType R) const {
201 if (node_map.find(R) == node_map.end()){
202 auto msg = fmt::format("{} not in node_map\n",R);
203 HANDLE_ERROR_THROW(msg);
204 }
205 return true;
206 }
207
212 bool check(DataType R, size_t sig_hash) const
213 {
214 if (check(R)){
215 if (node_map.at(R).find(sig_hash) == node_map.at(R).end()){
216 auto msg = fmt::format("{} not in node_map.at({})\n", sig_hash, R);
217 HANDLE_ERROR_THROW(msg);
218 }
219 }
220 return true;
221 }
222
228 bool check(DataType R, size_t sig_hash, NodeType type) const
229 {
230 if (check(R,sig_hash)){
231 if (node_map.at(R).at(sig_hash).find(type) == node_map.at(R).at(sig_hash).end()){
232
233 auto msg = fmt::format("{} not in node_map[{}][{}]\n",type, sig_hash, R);
234 HANDLE_ERROR_THROW(msg);
235 }}
236 return true;
237 }
238
246 template<typename Iter>
247 bool has_solution_space(Iter start, Iter end) const {
248 return !std::all_of(start, end, [](const auto& w) { return w<=0.0; });
249 }
250
/// Look up a node by name. Only declared here; the meaning of template
/// parameter F is not visible in this header (TODO: confirm in the definition).
251 template<typename F> Node get(const string& name);
252
258 Node get(NodeType type, DataType R, size_t sig_hash) const
259 {
260 check(R, sig_hash, type);
261 return node_map.at(R).at(sig_hash).at(type);
262 };
263
270 template<typename S>
271 Node get(NodeType type, DataType R, S sig) const { return get(type, R, sig.hash()); };
272
275 vector<float> get_weights() const
276 {
277 vector<float> v;
278 for (auto& [ret, arg_w_map]: node_map_weights)
279 {
280 v.push_back(0);
281 for (const auto& [arg, name_map] : arg_w_map)
282 {
283 for (const auto& [name, w]: name_map)
284 {
285 v.back() += w;
286 }
287 }
288 }
289 return v;
290 };
291
295 vector<float> get_weights(DataType ret) const
296 {
297 vector<float> v;
298 for (const auto& [arg, name_map] : node_map_weights.at(ret))
299 {
300 v.push_back(0);
301 for (const auto& [name, w]: name_map)
302 {
303 v.back() += w;
304 }
305
306 }
307 return v;
308 };
309
314 vector<float> get_weights(DataType ret, ArgsHash sig_hash) const
315 {
316 vector<float> v;
317 for (const auto& [name, w]: node_map_weights.at(ret).at(sig_hash))
318 v.push_back(w);
319
320 return v;
321 };
322
325 std::optional<Node> sample_terminal(bool force_return=false) const
326 {
327 //TODO: match terminal args_type (probably '{}' or something?)
328 // make a separate terminal_map
329
330 // We'll make terminal types to have its weights proportional to the
331 // DataTypes Weights they hold
332 vector<float> data_type_weights(terminal_weights.size());
333 if (force_return)
334 {
335 std::fill(data_type_weights.begin(), data_type_weights.end(), 1.0f);
336 }
337 else
338 {
339 std::transform(
340 terminal_weights.begin(),
341 terminal_weights.end(),
342 data_type_weights.begin(),
343 [](const auto& tw){
344 return std::reduce(tw.second.begin(), tw.second.end()); }
345 );
346
347 if (!has_solution_space(data_type_weights.begin(),
348 data_type_weights.end()))
349 return std::nullopt;
350 }
351
352 // If we got this far, then it is garanteed that we'll return something
353 // The match take into account datatypes with non-zero weights
354 auto match = *r.select_randomly(
355 terminal_map.begin(),
356 terminal_map.end(),
357 data_type_weights.begin(),
358 data_type_weights.end()
359 );
360
361 // theres always a constant of each data type
362 vector<float> match_weights(match.second.size());
363 if (force_return)
364 {
365 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
366 }
367 else
368 {
369 std::transform(
370 terminal_weights.at(match.first).begin(),
371 terminal_weights.at(match.first).end(),
372 match_weights.begin(),
373 [](const auto& w){ return w; });
374
375 if (!has_solution_space(match_weights.begin(),
376 match_weights.end()))
377 return std::nullopt;
378 }
379
380 return *r.select_randomly(match.second.begin(), match.second.end(),
381 match_weights.begin(), match_weights.end());
382 };
383
386 std::optional<Node> sample_terminal(DataType R, bool force_return=false) const
387 {
388 // should I keep doing this check?
389 // if (terminal_map.find(R) == terminal_map.end()){
390 // auto msg = fmt::format("{} not in terminal_map\n",R);
391 // HANDLE_ERROR_THROW(msg);
392 // }
393
394 // If there's at least one constant for every data type, its always possible to force sample_terminal to return something
395
396 // TODO: try to combine with above function
397 vector<float> match_weights(terminal_weights.at(R).size());
398 if (force_return)
399 {
400 // This should have at least the constant
401 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
402 }
403 else
404 {
405 if (terminal_map.find(R) == terminal_map.end())
406 return std::nullopt;
407
408 std::transform(
409 terminal_weights.at(R).begin(),
410 terminal_weights.at(R).end(),
411 match_weights.begin(),
412 [](const auto& w){ return w; }
413 );
414
415 if (!has_solution_space(match_weights.begin(),
416 match_weights.end()))
417 return std::nullopt;
418 }
419
420 return *r.select_randomly(terminal_map.at(R).begin(),
421 terminal_map.at(R).end(),
422 match_weights.begin(),
423 match_weights.end());
424 };
425
429 std::optional<Node> sample_op(DataType ret) const
430 {
431 check(ret);
432 if (node_map.find(ret) == node_map.end())
433 return std::nullopt;
434
435 //TODO: match terminal args_type (probably '{}' or something?)
436 auto ret_match = node_map.at(ret);
437
438 vector<float> args_w = get_weights(ret);
439
440 if (!has_solution_space(args_w.begin(), args_w.end()))
441 return std::nullopt;
442
443 auto arg_match = *r.select_randomly(ret_match.begin(),
444 ret_match.end(),
445 args_w.begin(),
446 args_w.end());
447
448 vector<float> name_w = get_weights(ret, arg_match.first);
449
450 if (!has_solution_space(name_w.begin(), name_w.end()))
451 return std::nullopt;
452
453 return (*r.select_randomly(arg_match.second.begin(),
454 arg_match.second.end(),
455 name_w.begin(),
456 name_w.end())).second;
457 };
458
463 std::optional<Node> sample_op(NodeType type, DataType R)
464 {
465 check(R);
466 if (node_map.find(R) == node_map.end())
467 return std::nullopt;
468
469 auto ret_match = node_map.at(R);
470
471 vector<Node> matches;
472 vector<float> weights;
473 for (const auto& [arg_hash, node_type_map]: ret_match)
474 {
475 if (node_type_map.find(type) != node_type_map.end()
476 && node_map_weights.at(R).at(arg_hash).at(type) > 0.0f)
477 {
478 matches.push_back(node_type_map.at(type));
479 weights.push_back(node_map_weights.at(R).at(arg_hash).at(type));
480 }
481 }
482
483 if ( (weights.size()==0)
484 || (!has_solution_space(weights.begin(),
485 weights.end())) )
486 return std::nullopt;
487
488 return (*r.select_randomly(matches.begin(),
489 matches.end(),
490 weights.begin(),
491 weights.end()));
492 };
493
500 std::optional<Node> sample_op_with_arg(DataType ret, DataType arg,
501 bool terminal_compatible=true,
502 int max_args=0) const
503 {
504 // thoughts (TODO):
505 // this could be templated by return type and arg. although the lookup in the map should be
506 // fairly fast.
507 //TODO: these needs to be overhauled
508 // fmt::print("sample_op_with_arg");
509 check(ret);
510
511 auto args_map = node_map.at(ret);
512 vector<Node> matches;
513 vector<float> weights;
514
515 for (const auto& [args_type, name_map]: args_map) {
516 for (const auto& [name, node]: name_map) {
517 auto node_arg_types = node.get_arg_types();
518
519 // has no size limit (max_arg_count==0) or the number of
520 // arguments woudn't exceed the maximum number of arguments
521 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
522
523 // TODO: I created constant terminals for all datatypes. Can I stop performing this check? (there is always gonna be a terminal)
524 if ( in(node_arg_types, arg)
525 && within_size_limit
526 && node_map_weights.at(ret).at(args_type).at(name) > 0.0f )
527 {
528 // if checking terminal compatibility, make sure there's
529 // a compatible terminal for the node's other arguments
530 if (terminal_compatible) {
531 bool compatible = true;
532 for (const auto& arg_type: node_arg_types) {
533 if (arg_type != arg) {
534 if ( ! in(terminal_types, arg_type) ) {
535 compatible = false;
536 break;
537 }
538 }
539 }
540 if (! compatible)
541 continue;
542 }
543 // if we made it this far, include the node as a match!
544 matches.push_back(node);
545 weights.push_back(node_map_weights.at(ret).at(args_type).at(name));
546 }
547 }
548 }
549
550 if ( (weights.size()==0)
551 || (!has_solution_space(weights.begin(),
552 weights.end())) )
553 return std::nullopt;
554
555 return (*r.select_randomly(matches.begin(), matches.end(),
556 weights.begin(), weights.end()));
557 };
558
/**
 * @brief Get a node whose signature matches `node`.
 *
 * Draws among the nodes registered under node.ret_type and node.args_type(),
 * weighted by get_weights(node.ret_type, node.args_type()).
 *
 * @param node node whose return type and argument signature must be matched.
 * @return the sampled node; std::nullopt when no candidate has a positive
 *         weight; or sample_terminal(node.ret_type) on the early-return path.
 *         node_map.at(...) throws std::out_of_range for an unregistered signature.
 *
 * NOTE(review): source line 564, the condition that opens the block closed at
 * `566 }`, is absent from this extract. Presumably it detects terminal-like
 * nodes; confirm against the original header.
 */
562 std::optional<Node> get_node_like(Node node) const
563 {
565 return sample_terminal(node.ret_type);
566 }
567
// NOTE(review): `matches` is a copy of node_map's inner map, while match_weights is
// filled by iterating node_map_weights' inner map. The i-th weight is assumed to
// belong to the i-th node, i.e. both unordered_maps iterate in the same order.
568 auto matches = node_map.at(node.ret_type).at(node.args_type());
569 auto match_weights = get_weights(node.ret_type, node.args_type());
570
571 if ( (match_weights.size()==0)
572 || (!has_solution_space(match_weights.begin(),
573 match_weights.end())) )
574 return std::nullopt;
575
// select_randomly yields an iterator to a (node type, Node) pair; .second is the Node.
576 return (*r.select_randomly(matches.begin(),
577 matches.end(),
578 match_weights.begin(),
579 match_weights.end())
580 ).second;
581 };
582
/// Create a subtree rooted at `root`, respecting the maximum depth and size limits.
/// NOTE(review): the conditions under which std::nullopt is returned are defined
/// in the source file, not here.
588 std::optional<tree<Node>> sample_subtree(Node root, int max_d, int max_size) const;
589
/// Prints the search space map.
591 void print() const;
592
594 std::string repr() const {
595 string output = "=== Search space ===\n";
596 output += fmt::format("terminal_map: {}\n", this->terminal_map);
597 output += fmt::format("terminal_weights: {}\n", this->terminal_weights);
598
599 for (const auto& [ret_type, v] : this->node_map) {
600 for (const auto& [args_type, v2] : v) {
601 for (const auto& [node_type, node] : v2) {
602 output += fmt::format("node_map[{}][{}][{}] = {}, weight = {}\n",
603 ret_type,
604 ArgsName[args_type],
605 node_type,
606 node,
607 this->node_map_weights.at(ret_type).at(args_type).at(node_type)
608 );
609 }
610 }
611 }
612 return output;
613 };
614
615 private:
// Fills `Tree` in place below the iterator `root`, within the max_d / max_size
// limits, and returns a reference to `Tree`. make_program calls it after placing
// the root node itself.
// NOTE(review): presumably Luke's PTC2 tree-creation algorithm; confirm in the .cpp.
616 tree<Node>& PTC2(tree<Node>& Tree, tree<Node>::iterator root, int max_d, int max_size) const;
617
/**
 * @brief Build a node of type NT with signature S, if the dataset supports it.
 *
 * @param unique_data_types data types available in the dataset.
 * @param use_all           not read in this body.
 * @param weighted          passed to the Node constructor.
 * @return std::nullopt when any argument type of S is missing from
 *         unique_data_types; otherwise Node(NT, S{}, weighted).
 *
 * Side effect: stores fmt::format("{}", S::get_arg_types()) under S::hash() in
 * the global ArgsName, which repr() later prints.
 * NOTE(review): source line 619, between the template header and the
 * declaration, is absent from this extract.
 */
618 template<NodeType NT, typename S>
620 static constexpr std::optional<Node> CreateNode(
621 const auto& unique_data_types,
622 bool use_all,
623 bool weighted
624 )
625 {
626 // reject S unless every one of its argument types
627 // is among the data types present in the dataset
628 for (auto arg: S::get_arg_types()){
629 if (! in(unique_data_types,arg) ){
630 return {};
631 }
632 }
633 ArgsName[S::hash()] = fmt::format("{}", S::get_arg_types());
634 return Node(NT, S{}, weighted);
635 }
636
637 template<NodeType NT, typename S>
638 constexpr void AddNode(
639 const unordered_map<string,float>& user_ops,
640 const vector<DataType>& unique_data_types
641 )
642 {
643 bool use_all = user_ops.size() == 0;
644 auto name = NodeTypeName[NT];
645
646 bool weighted = false;
647 if (Is<NodeType::OffsetSum>(NT)) // this has to have weights on by default
648 weighted = true;
649
650 auto n_maybe = CreateNode<NT,S>(unique_data_types, use_all, weighted);
651
652 if (n_maybe){
653 auto n = n_maybe.value();
654 node_map[n.ret_type][n.args_type()][n.node_type] = n;
655 // sampling probability map
656 float w = use_all? 1.0 : user_ops.at(name);
657 node_map_weights[n.ret_type][n.args_type()][n.node_type] = w;
658 }
659 }
660
661 template <NodeType NT, typename Sigs, std::size_t... Is>
662 constexpr void AddNodes(const unordered_map<string, float> &user_ops,
663 const vector<DataType> &unique_data_types,
664 std::index_sequence<Is...>)
665 {
666 (AddNode<NT, std::tuple_element_t<Is, Sigs>>(user_ops, unique_data_types), ...);
667 }
668
/**
 * @brief Register every signature of node type NT.
 *
 * Returns early, without registering anything, when the user supplied an
 * operator list (user_ops non-empty) that does not contain NT's name.
 *
 * NOTE(review): two source lines are absent from this extract:
 * - line 674, the condition guarding the bare `return;` at 675;
 * - line 685, the call that receives the arguments at 686-689 (presumably
 *   AddNodes<NT, signatures>( ).
 */
669 template<NodeType NT>
670 void MakeNodes(const unordered_map<string,float>& user_ops,
671 const vector<DataType>& unique_data_types
672 )
673 {
675 return;
676 bool use_all = user_ops.size() == 0;
677 auto name = NodeTypeName.at(NT);
678
679 // skip operators not defined by user
// `&` on two bools: == binds tighter than &, so this equals `&&` here,
// except that both operands are always evaluated.
680 if (!use_all & user_ops.find(name) == user_ops.end())
681 return;
682
683 using signatures = Signatures<NT>::type;
684 constexpr auto size = std::tuple_size<signatures>::value;
686 user_ops,
687 unique_data_types,
688 std::make_index_sequence<size>()
689 );
690 }
691
692 template<std::size_t... Is>
693 void GenerateNodeMap(const unordered_map<string,float>& user_ops,
694 const vector<DataType>& unique_data_types,
695 std::index_sequence<Is...>
696 )
697 {
698 auto nt = [](auto i) { return static_cast<NodeType>(1UL << i); };
699 (MakeNodes<nt(Is)>(user_ops, unique_data_types), ...);
700 }
701}; // SearchSpace
702
704template<typename T>
705T RandomDequeue(std::vector<T>& Q)
706{
707 int loc = r.rnd_int(0, Q.size()-1);
708 std::swap(Q[loc], Q[Q.size()-1]);
709 T val = Q.back();
710 Q.pop_back();
711 return val;
712};
713
/**
 * @brief Makes a random program of type P.
 *
 * The root is fixed by P::program_type:
 * - BinaryClassifier: Logistic over a fixed OffsetSum.
 * - MulticlassClassifier: Softmax.
 * - Otherwise: an operator returning root_type when max_size > 1 and max_d > 1,
 *   falling back to a forced terminal.
 * PTC2 then fills the tree below that spot in place.
 *
 * @param params   supplies max_depth / max_size when the explicit limits are below 1.
 * @param max_d    maximum depth; values < 1 are replaced by r.rnd_int(1, params.max_depth).
 * @param max_size maximum size; values < 1 are replaced by r.rnd_int(1, params.max_size).
 * @return P constructed from this search space and the generated tree.
 *
 * NOTE(review): source line 724, which defines `root_type` (used at 776 and 779),
 * is absent from this extract.
 */
714template<typename P>
715P SearchSpace::make_program(const Parameters& params, int max_d, int max_size)
716{
717 // this is what makes `make_program` create uniformly distributed
718 // individuals to feed initial population
719 if (max_d < 1)
720 max_d = r.rnd_int(1, params.max_depth);
721 if (max_size < 1)
722 max_size = r.rnd_int(1, params.max_size);
723
// NOTE(review): program_type is not read below; the branches test P::program_type directly.
725 ProgramType program_type = P::program_type;
726 // ProgramType program_type = ProgramTypeEnum<PT>::value;
727
728 // Tree is pre-filled with some fixed nodes depending on program type
729 auto Tree = tree<Node>();
730
731 // building the tree for each program case. Then, we give the spot to PTC2,
732 // and it will fill the rest of the tree
733 tree<Node>::iterator spot;
734
735 // building the root node for each program case
// Plain (runtime) `if` on a compile-time constant, so every branch is
// instantiated for every P.
736 if (P::program_type == ProgramType::BinaryClassifier)
737 {
738 Node node_logit = get(NodeType::Logistic, DataType::ArrayF, Signature<ArrayXf(ArrayXf)>());
739 node_logit.set_is_weighted(false);
740 node_logit.set_prob_change(0.0);
741 node_logit.fixed=true;
742
743 auto spot_logit = Tree.insert(Tree.begin(), node_logit);
744
// Always taken, so binary classifiers are currently always Logistic(OffsetSum(<>)).
745 if (true) { // Logistic(Add(Constant, <>)). TODO: let the user control this
746 Node node_offset = get(NodeType::OffsetSum, DataType::ArrayF, Signature<ArrayXf(ArrayXf)>());
747
748 node_offset.set_prob_change(0.0);
749 node_offset.fixed=true;
750
751 auto spot_offset = Tree.append_child(spot_logit);
752
753 spot = Tree.replace(spot_offset, node_offset);
754 }
755 else { // If false, then model will be Logistic(<>)
756 spot = spot_logit;
757 }
758 }
759 else if (P::program_type == ProgramType::MulticlassClassifier)
760 {
761 Node node_softmax = get(NodeType::Softmax, DataType::MatrixF, Signature<ArrayXXf(ArrayXXf)>());
762
763 node_softmax.set_prob_change(0.0);
764 node_softmax.set_is_weighted(false);
765 node_softmax.fixed=true;
766
767 spot = Tree.insert(Tree.begin(), node_softmax);
768 }
769 else // regression or representer --- sampling any candidate op or terminal
770 {
771 Node root;
772
773 std::optional<Node> opt=std::nullopt;
774
775 if (max_size>1 && max_d>1)
776 opt = sample_op(root_type);
777
778 if (!opt) // if that failed, we don't have any operator to use as the root
779 opt = sample_terminal(root_type, true);
780
// value() throws std::bad_optional_access if even the forced terminal sample came back empty.
781 root = opt.value();
782
783 spot = Tree.insert(Tree.begin(), root);
784 }
785
786 // max_d-1 because we always pick the root before calling ptc2
787 PTC2(Tree, spot, max_d-1, max_size); // modifies Tree in place
788
789 // weighting the tree if classification problem (so it can optimize the scale by default)
790 // if (P::program_type == ProgramType::BinaryClassifier
791 // || P::program_type == ProgramType::MulticlassClassifier)
792 // {
793 // // Get the child of the spot
794 // auto child = spot.begin();
795
796 // // Turn on the weight for the child
797 // child->set_prob_change(1.0);
798 // }
799
800 return P(*this, Tree);
801};
802
// Global search-space instance. It is declared `extern` here, so its definition
// lives in a source file elsewhere.
803extern SearchSpace SS;
804
805} // Brush
806
807// format overload
808template <>
809struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
810 template <typename FormatContext>
811 auto format(const Brush::SearchSpace& SS, FormatContext& ctx) const {
812 string output = SS.repr();
813 return formatter<string_view>::format(output, ctx);
814 }
815};
816#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing Data structures used in Brush
Definition data.cpp:49
namespace containing various utility functions
Definition error.cpp:11
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:174
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NodeType
Definition nodetype.h:31
Program< PT::Representer > RepresenterProgram
Definition types.h:81
ProgramType PT
Definition program.h:40
Program< PT::BinaryClassifier > ClassifierProgram
Definition types.h:79
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:272
std::unordered_map< std::size_t, std::string > ArgsName
vector< Node > generate_terminals(const Dataset &d, const bool weights_init)
generate terminals from the dataset features and random constants.
T RandomDequeue(std::vector< T > &Q)
queue for make program
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::pre_order_iterator TreeIter
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
Program< PT::Regressor > RegressorProgram
Definition types.h:78
SearchSpace SS
ProgramType
Definition types.h:70
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
Definition types.h:80
static constexpr bool is_in_v
Definition nodetype.h:268
class holding the data for a node in a tree.
Definition node.h:84
bool fixed
whether node is modifiable
Definition node.h:93
NodeType node_type
the node type
Definition node.h:95
DataType ret_type
return data type
Definition node.h:101
void set_is_weighted(bool is_weighted)
Definition node.h:256
std::size_t args_type() const
Definition node.h:180
void set_prob_change(float w)
Definition node.h:246
unsigned int max_depth
Definition params.h:37
unsigned int max_size
Definition params.h:38
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
void print() const
prints the search space map.
Node get(NodeType type, DataType R, S sig) const
get a typed node.
static constexpr std::optional< Node > CreateNode(const auto &unique_data_types, bool use_all, bool weighted)
Node get(const string &name)
Node get(NodeType type, DataType R, size_t sig_hash) const
get a typed node
constexpr void AddNode(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
Map< Node > node_map
Maps return types to argument types to node types.
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.
unordered_map< DataType, unordered_map< ArgsHash, unordered_map< NodeType, T > > > Map
constexpr void AddNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
std::optional< Node > sample_terminal(DataType R, bool force_return=false) const
Get a random terminal with return type R
void init(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Called by the constructor to initialize the search space.
RegressorProgram make_regressor(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random regressor program. Convenience wrapper for make_program.
bool check(DataType R, size_t sig_hash, NodeType type) const
check if a typed Node is in the search space
bool check(DataType R) const
check if a return type is in the node map
unordered_map< DataType, vector< float > > terminal_weights
A map of weights corresponding to elements in terminal_map, used to weight probabilities of each term...
std::optional< Node > sample_op(NodeType type, DataType R)
Get a specific node type that matches a return value.
vector< float > get_weights() const
get weights of the return types
vector< string > op_names
A vector storing the available operator names (used by bandits).
vector< float > get_weights(DataType ret) const
get weights of the argument types matching return type ret.
bool check(DataType R, size_t sig_hash) const
check if a function signature is in the search space
void GenerateNodeMap(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
tree< Node > & PTC2(tree< Node > &Tree, tree< Node >::iterator root, int max_d, int max_size) const
std::optional< Node > sample_op(DataType ret) const
get an operator matching return type ret.
SearchSpace(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Construct a search space.
std::string repr() const
returns a string with a json representation of the search space map
std::optional< Node > get_node_like(Node node) const
get a node with a signature matching node
vector< float > get_weights(DataType ret, ArgsHash sig_hash) const
get the weights of nodes matching a signature.
RepresenterProgram make_representer(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random representer program. Convenience wrapper for make_program.
std::optional< Node > sample_op_with_arg(DataType ret, DataType arg, bool terminal_compatible=true, int max_args=0) const
get operator with at least one argument matching arg
Map< float > node_map_weights
A map of weights corresponding to elements in node_map, used to weight probabilities of each node bei...
MulticlassClassifierProgram make_multiclass_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random multiclass classifier program. Convenience wrapper for make_program.
std::optional< tree< Node > > sample_subtree(Node root, int max_d, int max_size) const
create a subtree with maximum size and depth restrictions and root of type root_type
std::size_t ArgsHash
PT make_program(const Parameters &params, int max_d=0, int max_size=0)
Makes a random program.
vector< DataType > terminal_types
A vector storing the available return types of terminals.
bool has_solution_space(Iter start, Iter end) const
Takes iterators to weight vectors and checks if they have a non-empty solution space....
void MakeNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
SearchSpace()=default
ClassifierProgram make_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random classifier program. Convenience wrapper for make_program.
std::optional< Node > sample_terminal(bool force_return=false) const
Get a random terminal.
auto format(const Brush::SearchSpace &SS, FormatContext &ctx) const