45using TreeIter = tree<Node>::pre_order_iterator;
54extern std::unordered_map<std::size_t, std::string>
ArgsName;
154 template<
typename PT>
187 SearchSpace(
const Dataset& d,
const unordered_map<string,float>& user_ops = {},
bool weights_init =
true){
188 init(d,user_ops,weights_init);
195 void init(
const Dataset& d,
const unordered_map<string,float>& user_ops = {},
bool weights_init =
true);
202 auto msg = fmt::format(
"{} not in node_map\n",R);
216 auto msg = fmt::format(
"{} not in node_map.at({})\n", sig_hash, R);
230 if (
check(R,sig_hash)){
231 if (
node_map.at(R).at(sig_hash).find(type) ==
node_map.at(R).at(sig_hash).end()){
233 auto msg = fmt::format(
"{} not in node_map[{}][{}]\n",type, sig_hash, R);
246 template<
typename Iter>
248 return !std::all_of(start, end, [](
const auto& w) {
return w<=0.0; });
251 template<
typename F>
Node get(
const string& name);
260 check(R, sig_hash, type);
261 return node_map.at(R).at(sig_hash).at(type);
281 for (
const auto& [arg, name_map] : arg_w_map)
283 for (
const auto& [name, w]: name_map)
301 for (
const auto& [name, w]: name_map)
335 std::fill(data_type_weights.begin(), data_type_weights.end(), 1.0f);
342 data_type_weights.begin(),
344 return std::reduce(tw.second.begin(), tw.second.end()); }
348 data_type_weights.end()))
354 auto match = *
r.select_randomly(
357 data_type_weights.begin(),
358 data_type_weights.end()
362 vector<float> match_weights(match.second.size());
365 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
372 match_weights.begin(),
373 [](
const auto& w){ return w; });
376 match_weights.end()))
380 return *
r.select_randomly(match.second.begin(), match.second.end(),
381 match_weights.begin(), match_weights.end());
401 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
411 match_weights.begin(),
412 [](
const auto& w){ return w; }
416 match_weights.end()))
422 match_weights.begin(),
423 match_weights.end());
443 auto arg_match = *
r.select_randomly(ret_match.begin(),
448 vector<float> name_w =
get_weights(ret, arg_match.first);
453 return (*
r.select_randomly(arg_match.second.begin(),
454 arg_match.second.end(),
456 name_w.end())).second;
471 vector<Node> matches;
472 vector<float> weights;
473 for (
const auto& [arg_hash, node_type_map]: ret_match)
475 if (node_type_map.find(type) != node_type_map.end()
478 matches.push_back(node_type_map.at(type));
483 if ( (weights.size()==0)
488 return (*
r.select_randomly(matches.begin(),
501 bool terminal_compatible=
true,
502 int max_args=0)
const
512 vector<Node> matches;
513 vector<float> weights;
515 for (
const auto& [args_type, name_map]: args_map) {
516 for (
const auto& [name, node]: name_map) {
517 auto node_arg_types = node.get_arg_types();
521 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
524 if (
in(node_arg_types, arg)
530 if (terminal_compatible) {
531 bool compatible =
true;
532 for (
const auto& arg_type: node_arg_types) {
533 if (arg_type != arg) {
544 matches.push_back(node);
550 if ( (weights.size()==0)
555 return (*
r.select_randomly(matches.begin(), matches.end(),
556 weights.begin(), weights.end()));
571 if ( (match_weights.size()==0)
573 match_weights.end())) )
576 return (*
r.select_randomly(matches.begin(),
578 match_weights.begin(),
595 string output =
"=== Search space ===\n";
596 output += fmt::format(
"terminal_map: {}\n", this->terminal_map);
597 output += fmt::format(
"terminal_weights: {}\n", this->terminal_weights);
599 for (
const auto& [ret_type, v] : this->node_map) {
600 for (
const auto& [args_type, v2] : v) {
601 for (
const auto& [node_type, node] : v2) {
602 output += fmt::format(
"node_map[{}][{}][{}] = {}, weight = {}\n",
607 this->node_map_weights.at(ret_type).at(args_type).at(node_type)
616 tree<Node>&
PTC2(tree<Node>& Tree, tree<Node>::iterator root,
int max_d,
int max_size)
const;
618 template<NodeType NT,
typename S>
621 const auto& unique_data_types,
628 for (
auto arg: S::get_arg_types()){
629 if (!
in(unique_data_types,arg) ){
633 ArgsName[S::hash()] = fmt::format(
"{}", S::get_arg_types());
634 return Node(
NT, S{}, weighted);
637 template<NodeType NT,
typename S>
639 const unordered_map<string,float>& user_ops,
640 const vector<DataType>& unique_data_types
643 bool use_all = user_ops.size() == 0;
646 bool weighted =
false;
653 auto n = n_maybe.value();
654 node_map[n.ret_type][n.args_type()][n.node_type] = n;
656 float w = use_all? 1.0 : user_ops.at(name);
661 template <
NodeType NT,
typename Sigs, std::size_t...
Is>
662 constexpr void AddNodes(
const unordered_map<string, float> &user_ops,
663 const vector<DataType> &unique_data_types,
664 std::index_sequence<Is...>)
669 template<NodeType NT>
670 void MakeNodes(
const unordered_map<string,float>& user_ops,
671 const vector<DataType>& unique_data_types
676 bool use_all = user_ops.size() == 0;
680 if (!use_all & user_ops.find(name) == user_ops.end())
684 constexpr auto size = std::tuple_size<signatures>::value;
688 std::make_index_sequence<size>()
692 template<std::size_t...
Is>
694 const vector<DataType>& unique_data_types,
695 std::index_sequence<Is...>
698 auto nt = [](
auto i) {
return static_cast<NodeType>(1UL << i); };
707 int loc =
r.rnd_int(0, Q.size()-1);
708 std::swap(Q[loc], Q[Q.size()-1]);
722 max_size =
r.rnd_int(1, params.
max_size);
729 auto Tree = tree<Node>();
733 tree<Node>::iterator spot;
741 node_logit.
fixed=
true;
743 auto spot_logit = Tree.insert(Tree.begin(), node_logit);
749 node_offset.
fixed=
true;
751 auto spot_offset = Tree.append_child(spot_logit);
753 spot = Tree.replace(spot_offset, node_offset);
765 node_softmax.
fixed=
true;
767 spot = Tree.insert(Tree.begin(), node_softmax);
773 std::optional<Node> opt=std::nullopt;
775 if (max_size>1 && max_d>1)
783 spot = Tree.insert(Tree.begin(), root);
787 PTC2(Tree, spot, max_d-1, max_size);
800 return P(*
this, Tree);
810 template <
typename FormatContext>
812 string output =
SS.repr();
813 return formatter<string_view>::format(output, ctx);
holds variable type data.
#define HANDLE_ERROR_THROW(err)
namespace containing Data structures used in Brush
namespace containing various utility functions
bool in(const V &v, const T &i)
check if element is in vector.
< nsga2 selection operator for getting the front
Program< PT::Representer > RepresenterProgram
Program< PT::BinaryClassifier > ClassifierProgram
auto Is(NodeType nt) -> bool
std::unordered_map< std::size_t, std::string > ArgsName
vector< Node > generate_terminals(const Dataset &d, const bool weights_init)
generate terminals from the dataset features and random constants.
T RandomDequeue(std::vector< T > &Q)
queue for make program
tree< Node >::pre_order_iterator Iter
tree< Node >::pre_order_iterator TreeIter
std::map< NodeType, std::string > NodeTypeName
Program< PT::Regressor > RegressorProgram
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
static constexpr bool is_in_v
class holding the data for a node in a tree.
bool fixed
whether node is modifiable
NodeType node_type
the node type
DataType ret_type
return data type
void set_is_weighted(bool is_weighted)
std::size_t args_type() const
void set_prob_change(float w)
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
void print() const
prints the search space map.
Node get(NodeType type, DataType R, S sig) const
get a typed node.
static constexpr std::optional< Node > CreateNode(const auto &unique_data_types, bool use_all, bool weighted)
Node get(const string &name)
Node get(NodeType type, DataType R, size_t sig_hash) const
get a typed node
constexpr void AddNode(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
Map< Node > node_map
Maps return types to argument types to node types.
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.
unordered_map< DataType, unordered_map< ArgsHash, unordered_map< NodeType, T > > > Map
constexpr void AddNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
std::optional< Node > sample_terminal(DataType R, bool force_return=false) const
Get a random terminal with return type R
void init(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Called by the constructor to initialize the search space.
RegressorProgram make_regressor(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random regressor program. Convenience wrapper for make_program.
bool check(DataType R, size_t sig_hash, NodeType type) const
check if a typed Node is in the search space
bool check(DataType R) const
check if a return type is in the node map
unordered_map< DataType, vector< float > > terminal_weights
A map of weights corresponding to elements in terminal_map, used to weight probabilities of each term...
std::optional< Node > sample_op(NodeType type, DataType R)
Get a specific node type that matches a return value.
vector< float > get_weights() const
get weights of the return types
vector< string > op_names
A vector storing the available operator names (used by bandits).
vector< float > get_weights(DataType ret) const
get weights of the argument types matching return type ret.
bool check(DataType R, size_t sig_hash) const
check if a function signature is in the search space
void GenerateNodeMap(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
tree< Node > & PTC2(tree< Node > &Tree, tree< Node >::iterator root, int max_d, int max_size) const
std::optional< Node > sample_op(DataType ret) const
get an operator matching return type ret.
SearchSpace(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Construct a search space.
std::string repr() const
returns a string with a json representation of the search space map
std::optional< Node > get_node_like(Node node) const
get a node with a signature matching node
vector< float > get_weights(DataType ret, ArgsHash sig_hash) const
get the weights of nodes matching a signature.
RepresenterProgram make_representer(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random representer program. Convenience wrapper for make_program.
std::optional< Node > sample_op_with_arg(DataType ret, DataType arg, bool terminal_compatible=true, int max_args=0) const
get operator with at least one argument matching arg
Map< float > node_map_weights
A map of weights corresponding to elements in node_map, used to weight probabilities of each node bei...
MulticlassClassifierProgram make_multiclass_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random multiclass classifier program. Convenience wrapper for make_program.
std::optional< tree< Node > > sample_subtree(Node root, int max_d, int max_size) const
create a subtree with maximum size and depth restrictions and root of type root_type
PT make_program(const Parameters &params, int max_d=0, int max_size=0)
Makes a random program.
vector< DataType > terminal_types
A vector storing the available return types of terminals.
bool has_solution_space(Iter start, Iter end) const
Takes iterators to weight vectors and checks if they have a non-empty solution space....
void MakeNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
ClassifierProgram make_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random classifier program. Convenience wrapper for make_program.
std::optional< Node > sample_terminal(bool force_return=false) const
Get a random terminal.