Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
search_space.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5#ifndef SEARCHSPACE_H
6#define SEARCHSPACE_H
7//internal includes
8#include "../init.h"
9#include "../program/node.h"
10#include "../program/nodetype.h"
12// #include "program/program.h"
13#include "../util/error.h"
14#include "../util/utils.h"
15#include "../util/rnd.h"
16#include "../params.h"
17#include <utility>
18#include <optional>
19#include <iostream>
20
21/* Defines the search space of Brush.
22 * The search space consists of nodes and their accompanying probability
23 * distribution.
24 * Nodes can be accessed by type, signature, or a combination.
25 * You may also sample the search space by return type.
26 * Sampling is done in proportion to the weight associated with
27 * each node. By default, sampling is uniformly random.
28*/
29using namespace Brush::Data;
30using namespace Brush::Util;
31using Brush::Node;
32using Brush::DataType;
33using std::type_index;
34
35
36namespace Brush
37{
39// node generation routines
40/* template<typename T> */
41/* tuple<set<Node>,set<type_index>> generate_nodes(vector<string>& op_names); */
42/* tuple<set<Node>,set<type_index>> generate_split_nodes(vector<string>& op_names); */
43
44// forward declarations
45using TreeIter = tree<Node>::pre_order_iterator;
46// template<typename T> struct Program;
47// enum class ProgramType: uint32_t;
48// template<typename T> struct ProgramTypeEnum;
49
50vector<Node> generate_terminals(const Dataset& d, const bool weights_init);
51
53
54extern std::unordered_map<std::size_t, std::string> ArgsName;
55
84{
85 using ArgsHash = std::size_t;
86
87 template<typename T>
88 using Map = unordered_map<DataType, // return type
89 unordered_map<ArgsHash, // hash of arg types
90 unordered_map<NodeType, // node type
91 T>>>; // the data!
92
101
104
105 // TODO: maybe we could flatten this terminal map
106
115 unordered_map<DataType, vector<Node>> terminal_map;
116
118 unordered_map<DataType, vector<float>> terminal_weights;
119
121 vector<DataType> terminal_types;
122
124 vector<string> op_names;
125
126 // serialization
127#ifndef DOXYGEN_SKIP
128
129 NLOHMANN_DEFINE_TYPE_INTRUSIVE(SearchSpace,
130 node_map,
136 )
137
138#endif
139
154 template<typename PT>
155 PT make_program(const Parameters& params, int max_d=0, int max_size=0);
156
161 RegressorProgram make_regressor(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
162
167 ClassifierProgram make_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
168
173 MulticlassClassifierProgram make_multiclass_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
174
179 RepresenterProgram make_representer(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
180
181 SearchSpace() = default;
182
187 SearchSpace(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true){
188 init(d,user_ops,weights_init);
189 }
190
195 void init(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true);
196
200 bool check(DataType R) const {
201 if (node_map.find(R) == node_map.end()){
202 auto msg = fmt::format("{} not in node_map\n",R);
203 HANDLE_ERROR_THROW(msg);
204 }
205 return true;
206 }
207
212 bool check(DataType R, size_t sig_hash) const
213 {
214 if (check(R)){
215 if (node_map.at(R).find(sig_hash) == node_map.at(R).end()){
216 auto msg = fmt::format("{} not in node_map.at({})\n", sig_hash, R);
217 HANDLE_ERROR_THROW(msg);
218 }
219 }
220 return true;
221 }
222
228 bool check(DataType R, size_t sig_hash, NodeType type) const
229 {
230 if (check(R,sig_hash)){
231 if (node_map.at(R).at(sig_hash).find(type) == node_map.at(R).at(sig_hash).end()){
232
233 auto msg = fmt::format("{} not in node_map[{}][{}]\n",type, sig_hash, R);
234 HANDLE_ERROR_THROW(msg);
235 }}
236 return true;
237 }
238
246 template<typename Iter>
247 bool has_solution_space(Iter start, Iter end) const {
248 return !std::all_of(start, end, [](const auto& w) { return w<=0.0; });
249 }
250
256 Node get(NodeType type, DataType R, size_t sig_hash) const
257 {
258 check(R, sig_hash, type);
259 return node_map.at(R).at(sig_hash).at(type);
260 };
261
268 template<typename S>
269 Node get(NodeType type, DataType R, S sig) const { return get(type, R, sig.hash()); };
270
273 vector<float> get_weights() const
274 {
275 vector<float> v;
276 for (auto& [ret, arg_w_map]: node_map_weights)
277 {
278 v.push_back(0);
279 for (const auto& [arg, name_map] : arg_w_map)
280 {
281 for (const auto& [name, w]: name_map)
282 {
283 v.back() += w;
284 }
285 }
286 }
287 return v;
288 };
289
293 vector<float> get_weights(DataType ret) const
294 {
295 vector<float> v;
296 for (const auto& [arg, name_map] : node_map_weights.at(ret))
297 {
298 v.push_back(0);
299 for (const auto& [name, w]: name_map)
300 {
301 v.back() += w;
302 }
303
304 }
305 return v;
306 };
307
312 vector<float> get_weights(DataType ret, ArgsHash sig_hash) const
313 {
314 vector<float> v;
315 for (const auto& [name, w]: node_map_weights.at(ret).at(sig_hash))
316 v.push_back(w);
317
318 return v;
319 };
320
323 std::optional<Node> sample_terminal(bool force_return=false) const
324 {
325 //TODO: match terminal args_type (probably '{}' or something?)
326 // make a separate terminal_map
327
328 // We'll make terminal types to have its weights proportional to the
329 // DataTypes Weights they hold
330 vector<float> data_type_weights(terminal_weights.size());
331 if (force_return)
332 {
333 std::fill(data_type_weights.begin(), data_type_weights.end(), 1.0f);
334 }
335 else
336 {
337 std::transform(
338 terminal_weights.begin(),
339 terminal_weights.end(),
340 data_type_weights.begin(),
341 [](const auto& tw){
342 return std::reduce(tw.second.begin(), tw.second.end()); }
343 );
344
345 if (!has_solution_space(data_type_weights.begin(),
346 data_type_weights.end()))
347 return std::nullopt;
348 }
349
350 // If we got this far, then it is garanteed that we'll return something
351 // The match take into account datatypes with non-zero weights
352 auto match = *r.select_randomly(
353 terminal_map.begin(),
354 terminal_map.end(),
355 data_type_weights.begin(),
356 data_type_weights.end()
357 );
358
359 // theres always a constant of each data type
360 vector<float> match_weights(match.second.size());
361 if (force_return)
362 {
363 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
364 }
365 else
366 {
367 std::transform(
368 terminal_weights.at(match.first).begin(),
369 terminal_weights.at(match.first).end(),
370 match_weights.begin(),
371 [](const auto& w){ return w; });
372
373 if (!has_solution_space(match_weights.begin(),
374 match_weights.end()))
375 return std::nullopt;
376 }
377
378 return *r.select_randomly(match.second.begin(), match.second.end(),
379 match_weights.begin(), match_weights.end());
380 };
381
384 std::optional<Node> sample_terminal(DataType R, bool force_return=false) const
385 {
386 // should I keep doing this check?
387 // if (terminal_map.find(R) == terminal_map.end()){
388 // auto msg = fmt::format("{} not in terminal_map\n",R);
389 // HANDLE_ERROR_THROW(msg);
390 // }
391
392 // If there's at least one constant for every data type, its always possible to force sample_terminal to return something
393
394 // TODO: try to combine with above function
395 vector<float> match_weights(terminal_weights.at(R).size());
396 if (force_return)
397 {
398 // This should have at least the constant
399 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
400 }
401 else
402 {
403 if (terminal_map.find(R) == terminal_map.end())
404 return std::nullopt;
405
406 std::transform(
407 terminal_weights.at(R).begin(),
408 terminal_weights.at(R).end(),
409 match_weights.begin(),
410 [](const auto& w){ return w; }
411 );
412
413 if (!has_solution_space(match_weights.begin(),
414 match_weights.end()))
415 return std::nullopt;
416 }
417
418 return *r.select_randomly(terminal_map.at(R).begin(),
419 terminal_map.at(R).end(),
420 match_weights.begin(),
421 match_weights.end());
422 };
423
427 std::optional<Node> sample_op(DataType ret) const
428 {
429 check(ret);
430 if (node_map.find(ret) == node_map.end())
431 return std::nullopt;
432
433 //TODO: match terminal args_type (probably '{}' or something?)
434 auto ret_match = node_map.at(ret);
435
436 vector<float> args_w = get_weights(ret);
437
438 if (!has_solution_space(args_w.begin(), args_w.end()))
439 return std::nullopt;
440
441 auto arg_match = *r.select_randomly(ret_match.begin(),
442 ret_match.end(),
443 args_w.begin(),
444 args_w.end());
445
446 vector<float> name_w = get_weights(ret, arg_match.first);
447
448 if (!has_solution_space(name_w.begin(), name_w.end()))
449 return std::nullopt;
450
451 return (*r.select_randomly(arg_match.second.begin(),
452 arg_match.second.end(),
453 name_w.begin(),
454 name_w.end())).second;
455 };
456
461 std::optional<Node> sample_op(NodeType type, DataType R, bool force_return=false)
462 {
463 check(R);
464 if (node_map.find(R) == node_map.end())
465 return std::nullopt;
466
467 auto ret_match = node_map.at(R);
468
469 vector<Node> matches;
470 vector<float> weights;
471 for (const auto& [arg_hash, node_type_map]: ret_match)
472 {
473 if (node_type_map.find(type) != node_type_map.end())
474 // && node_map_weights.at(R).at(arg_hash).at(type) > 0.0f)
475 {
476 matches.push_back(node_type_map.at(type));
477 weights.push_back(node_map_weights.at(R).at(arg_hash).at(type));
478 }
479 }
480
481 if (force_return)
482 {
483 std::fill(weights.begin(), weights.end(), 1.0f);
484 }
485
486 if ( (weights.size()==0)
487 || (!has_solution_space(weights.begin(),
488 weights.end())) )
489 return std::nullopt;
490
491 return (*r.select_randomly(matches.begin(),
492 matches.end(),
493 weights.begin(),
494 weights.end()));
495 };
496
503 std::optional<Node> sample_op_with_arg(DataType ret, DataType arg,
504 bool terminal_compatible=true,
505 int max_args=0) const
506 {
507 // thoughts (TODO):
508 // this could be templated by return type and arg. although the lookup in the map should be
509 // fairly fast.
510 //TODO: these needs to be overhauled
511 // fmt::print("sample_op_with_arg");
512 check(ret);
513
514 auto args_map = node_map.at(ret);
515 vector<Node> matches;
516 vector<float> weights;
517
518 for (const auto& [args_type, name_map]: args_map) {
519 for (const auto& [name, node]: name_map) {
520 auto node_arg_types = node.get_arg_types();
521
522 // has no size limit (max_arg_count==0) or the number of
523 // arguments woudn't exceed the maximum number of arguments
524 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
525
526 // TODO: I created constant terminals for all datatypes. Can I stop performing this check? (there is always gonna be a terminal)
527 if ( in(node_arg_types, arg)
528 && within_size_limit
529 && node_map_weights.at(ret).at(args_type).at(name) > 0.0f )
530 {
531 // if checking terminal compatibility, make sure there's
532 // a compatible terminal for the node's other arguments
533 if (terminal_compatible) {
534 bool compatible = true;
535 for (const auto& arg_type: node_arg_types) {
536 if (arg_type != arg) {
537 if ( ! in(terminal_types, arg_type) ) {
538 compatible = false;
539 break;
540 }
541 }
542 }
543 if (! compatible)
544 continue;
545 }
546 // if we made it this far, include the node as a match!
547 matches.push_back(node);
548 weights.push_back(node_map_weights.at(ret).at(args_type).at(name));
549 }
550 }
551 }
552
553 if ( (weights.size()==0)
554 || (!has_solution_space(weights.begin(),
555 weights.end())) )
556 return std::nullopt;
557
558 return (*r.select_randomly(matches.begin(), matches.end(),
559 weights.begin(), weights.end()));
560 };
561
565 std::optional<Node> get_node_like(Node node) const
566 {
568 return sample_terminal(node.ret_type);
569 }
570
571 auto matches = node_map.at(node.ret_type).at(node.args_type());
572 auto match_weights = get_weights(node.ret_type, node.args_type());
573
574 if ( (match_weights.size()==0)
575 || (!has_solution_space(match_weights.begin(),
576 match_weights.end())) )
577 return std::nullopt;
578
579 return (*r.select_randomly(matches.begin(),
580 matches.end(),
581 match_weights.begin(),
582 match_weights.end())
583 ).second;
584 };
585
591 std::optional<tree<Node>> sample_subtree(Node root, int max_d, int max_size) const;
592
594 void print() const;
595
597 std::string repr() const {
598 string output = "=== Search space ===\n";
599 output += fmt::format("terminal_map: {}\n", this->terminal_map);
600 output += fmt::format("terminal_weights: {}\n", this->terminal_weights);
601
602 for (const auto& [ret_type, v] : this->node_map) {
603 for (const auto& [args_type, v2] : v) {
604 for (const auto& [node_type, node] : v2) {
605 output += fmt::format("{} node_map[{}][{}][{}] = {}, weight = {}\n",
606 node_type,
607 ret_type,
608 ArgsName[args_type],
609 node_type,
610 node,
611 this->node_map_weights.at(ret_type).at(args_type).at(node_type)
612 );
613 }
614 }
615 }
616 return output;
617 };
618
619 private:
620 tree<Node>& PTC2(tree<Node>& Tree, tree<Node>::iterator root, int max_d, int max_size) const;
621
622 template<NodeType NT, typename S>
624 static constexpr std::optional<Node> CreateNode(
625 const auto& unique_data_types,
626 bool use_all,
627 bool weighted
628 )
629 {
630 // prune the operators out that don't have argument types that
631 // overlap with feature data types
632 for (auto arg: S::get_arg_types()){
633 if (! in(unique_data_types,arg) ){
634 return {};
635 }
636 }
637 ArgsName[S::hash()] = fmt::format("{}", S::get_arg_types());
638 return Node(NT, S{}, weighted);
639 }
640
641 template<NodeType NT, typename S>
642 constexpr void AddNode(
643 const unordered_map<string,float>& user_ops,
644 const vector<DataType>& unique_data_types
645 )
646 {
647 bool use_all = user_ops.size() == 0;
648 auto name = NodeTypeName[NT];
649
650 bool weighted = false;
651 if (Is<NodeType::OffsetSum>(NT)) // this has to have weights on by default
652 weighted = true;
653
654 auto n_maybe = CreateNode<NT,S>(unique_data_types, use_all, weighted);
655
656 if (n_maybe){
657 auto n = n_maybe.value();
658 node_map[n.ret_type][n.args_type()][n.node_type] = n;
659 // sampling probability map
660 float w = use_all? 1.0 : user_ops.at(name);
661 node_map_weights[n.ret_type][n.args_type()][n.node_type] = w;
662 }
663 }
664
665 template <NodeType NT, typename Sigs, std::size_t... Is>
666 constexpr void AddNodes(const unordered_map<string, float> &user_ops,
667 const vector<DataType> &unique_data_types,
668 std::index_sequence<Is...>)
669 {
670 (AddNode<NT, std::tuple_element_t<Is, Sigs>>(user_ops, unique_data_types), ...);
671 }
672
673 template<NodeType NT>
674 void MakeNodes(const unordered_map<string,float>& user_ops,
675 const vector<DataType>& unique_data_types
676 )
677 {
679 return;
680
681 bool use_all = user_ops.size() == 0;
682 auto name = NodeTypeName.at(NT);
683
684 // skip operators not defined by user
685 if (!use_all & user_ops.find(name) == user_ops.end())
686 return;
687
688 using signatures = Signatures<NT>::type;
689 constexpr auto size = std::tuple_size<signatures>::value;
691 user_ops,
692 unique_data_types,
693 std::make_index_sequence<size>()
694 );
695 }
696
697 template<std::size_t... Is>
698 void GenerateNodeMap(const unordered_map<string,float>& user_ops,
699 const vector<DataType>& unique_data_types,
700 std::index_sequence<Is...>
701 )
702 {
703 auto nt = [](auto i) { return static_cast<NodeType>(1UL << i); };
704 (MakeNodes<nt(Is)>(user_ops, unique_data_types), ...);
705 }
706}; // SearchSpace
707
709template<typename T>
710T RandomDequeue(std::vector<T>& Q)
711{
712 int loc = r.rnd_int(0, Q.size()-1);
713 std::swap(Q[loc], Q[Q.size()-1]);
714 T val = Q.back();
715 Q.pop_back();
716 return val;
717};
718
719template<typename P>
720P SearchSpace::make_program(const Parameters& params, int max_d, int max_size)
721{
722 // this is what makes `make_program` create uniformly distributed
723 // individuals to feed initial population
724 if (max_d < 1)
725 max_d = r.rnd_int(1, params.max_depth);
726 if (max_size < 1)
727 max_size = r.rnd_int(1, params.max_size);
728
730 ProgramType program_type = P::program_type;
731
732 auto Tree = tree<Node>();
733 tree<Node>::iterator spot;
734
735 if (P::program_type == ProgramType::BinaryClassifier)
736 {
737 // sample_op should never return the empty value of optional
738 Node node_logit = sample_op(NodeType::Logistic, DataType::ArrayF, true).value();
739
740 node_logit.set_is_weighted(false);
741 node_logit.set_prob_change(0.0);
742 node_logit.fixed=true;
743
744 auto spot_logit = Tree.insert(Tree.begin(), node_logit);
745
746 if (true) {
747 Node node_offset = sample_op(NodeType::OffsetSum, DataType::ArrayF, true).value();
748
749 node_offset.set_prob_change(0.0);
750 node_offset.fixed=true;
751
752 auto spot_offset = Tree.append_child(spot_logit);
753
754 spot = Tree.replace(spot_offset, node_offset);
755 }
756 else {
757 spot = spot_logit;
758 }
759 }
760 else if (P::program_type == ProgramType::MulticlassClassifier)
761 {
762 Node node_softmax = sample_op(NodeType::Softmax, DataType::MatrixF, true).value();
763
764 node_softmax.set_prob_change(0.0);
765 node_softmax.set_is_weighted(false);
766 node_softmax.fixed=true;
767
768 spot = Tree.insert(Tree.begin(), node_softmax);
769 }
770 else
771 {
772 Node root;
773 std::optional<Node> opt=std::nullopt;
774
775 if (max_size>1 && max_d>1)
776 opt = sample_op(root_type);
777
778 if (!opt)
779 opt = sample_terminal(root_type, true);
780
781 root = opt.value();
782
783 spot = Tree.insert(Tree.begin(), root);
784 }
785
786 // max_d-1 because we always pick the root before calling ptc2
787 PTC2(Tree, spot, max_d-1, max_size); // change inplace
788
789 // weighting the tree if classification problem (so it can optimize the scale by default)
790 // if (P::program_type == ProgramType::BinaryClassifier
791 // || P::program_type == ProgramType::MulticlassClassifier)
792 // {
793 // // Get the child of the spot
794 // auto child = spot.begin();
795
796 // // Turn on the weight for the child
797 // child->set_prob_change(1.0);
798 // }
799
800 return P(*this, Tree);
801};
802
803extern SearchSpace SS;
804
805} // Brush
806
807// format overload
808template <>
809struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
810 template <typename FormatContext>
811 auto format(const Brush::SearchSpace& SS, FormatContext& ctx) const {
812 string output = SS.repr();
813 return formatter<string_view>::format(output, ctx);
814 }
815};
816#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing Data structures used in Brush
Definition data.cpp:49
namespace containing various utility functions
Definition error.cpp:11
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
NodeType
Definition nodetype.h:31
Program< PT::Representer > RepresenterProgram
Definition types.h:81
ProgramType PT
Definition program.h:40
Program< PT::BinaryClassifier > ClassifierProgram
Definition types.h:79
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:291
std::unordered_map< std::size_t, std::string > ArgsName
vector< Node > generate_terminals(const Dataset &d, const bool weights_init)
generate terminals from the dataset features and random constants.
T RandomDequeue(std::vector< T > &Q)
queue for make program
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::pre_order_iterator TreeIter
NodeType NT
Definition node.cpp:134
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
Program< PT::Regressor > RegressorProgram
Definition types.h:78
SearchSpace SS
ProgramType
Definition types.h:70
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
Definition types.h:80
static constexpr bool is_in_v
Definition nodetype.h:269
class holding the data for a node in a tree.
Definition node.h:84
bool fixed
whether the node is replaceable. Weights are still optimized.
Definition node.h:101
NodeType node_type
the node type
Definition node.h:89
DataType ret_type
return data type
Definition node.h:92
void set_is_weighted(bool is_weighted)
Definition node.h:274
std::size_t args_type() const
Definition node.h:180
void set_prob_change(float w)
Definition node.h:255
unsigned int max_depth
Definition params.h:37
unsigned int max_size
Definition params.h:38
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
void print() const
prints the search space map.
Node get(NodeType type, DataType R, S sig) const
get a typed node.
static constexpr std::optional< Node > CreateNode(const auto &unique_data_types, bool use_all, bool weighted)
Node get(NodeType type, DataType R, size_t sig_hash) const
get a typed node
constexpr void AddNode(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
Map< Node > node_map
Maps return types to argument types to node types.
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.
unordered_map< DataType, unordered_map< ArgsHash, unordered_map< NodeType, T > > > Map
constexpr void AddNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
std::optional< Node > sample_terminal(DataType R, bool force_return=false) const
Get a random terminal with return type R
void init(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Called by the constructor to initialize the search space.
RegressorProgram make_regressor(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random regressor program. Convenience wrapper for make_program.
bool check(DataType R, size_t sig_hash, NodeType type) const
check if a typed Node is in the search space
bool check(DataType R) const
check if a return type is in the node map
std::optional< Node > sample_op(NodeType type, DataType R, bool force_return=false)
Get a specific node type that matches a return value.
unordered_map< DataType, vector< float > > terminal_weights
A map of weights corresponding to elements in terminal_map, used to weight probabilities of each term...
vector< float > get_weights() const
get weights of the return types
vector< string > op_names
A vector storing the available operator names (used by bandits).
vector< float > get_weights(DataType ret) const
get weights of the argument types matching return type ret.
bool check(DataType R, size_t sig_hash) const
check if a function signature is in the search space
void GenerateNodeMap(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
tree< Node > & PTC2(tree< Node > &Tree, tree< Node >::iterator root, int max_d, int max_size) const
std::optional< Node > sample_op(DataType ret) const
get an operator matching return type ret.
SearchSpace(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Construct a search space.
std::string repr() const
returns a string with a json representation of the search space map
std::optional< Node > get_node_like(Node node) const
get a node with a signature matching node
vector< float > get_weights(DataType ret, ArgsHash sig_hash) const
get the weights of nodes matching a signature.
RepresenterProgram make_representer(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random representer program. Convenience wrapper for make_program.
std::optional< Node > sample_op_with_arg(DataType ret, DataType arg, bool terminal_compatible=true, int max_args=0) const
get operator with at least one argument matching arg
Map< float > node_map_weights
A map of weights corresponding to elements in node_map, used to weight probabilities of each node bei...
MulticlassClassifierProgram make_multiclass_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random multiclass classifier program. Convenience wrapper for make_program.
std::optional< tree< Node > > sample_subtree(Node root, int max_d, int max_size) const
create a subtree with maximum size and depth restrictions and root of type root_type
std::size_t ArgsHash
PT make_program(const Parameters &params, int max_d=0, int max_size=0)
Makes a random program.
vector< DataType > terminal_types
A vector storing the available return types of terminals.
bool has_solution_space(Iter start, Iter end) const
Takes iterators to weight vectors and checks if they have a non-empty solution space....
void MakeNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
SearchSpace()=default
ClassifierProgram make_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random classifier program. Convenience wrapper for make_program.
std::optional< Node > sample_terminal(bool force_return=false) const
Get a random terminal.
auto format(const Brush::SearchSpace &SS, FormatContext &ctx) const