Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
search_space.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5#ifndef SEARCHSPACE_H
6#define SEARCHSPACE_H
7//internal includes
8#include "../init.h"
9#include "../program/node.h"
10#include "../program/nodetype.h"
12// #include "program/program.h"
13#include "../util/error.h"
14#include "../util/utils.h"
15#include "../util/rnd.h"
16#include "../params.h"
17#include <utility>
18#include <optional>
19#include <iostream>
20
21/* Defines the search space of Brush.
22 * The search space consists of nodes and their accompanying probability
23 * distribution.
24 * Nodes can be accessed by type, signature, or a combination.
25 * You may also sample the search space by return type.
26 * Sampling is done in proportion to the weight associated with
27 * each node. By default, sampling is uniformly random.
28*/
29using namespace Brush::Data;
30using namespace Brush::Util;
31using Brush::Node;
32using Brush::DataType;
33using std::type_index;
34
35
36namespace Brush
37{
39// node generation routines
40/* template<typename T> */
41/* tuple<set<Node>,set<type_index>> generate_nodes(vector<string>& op_names); */
42/* tuple<set<Node>,set<type_index>> generate_split_nodes(vector<string>& op_names); */
43
44// forward declarations
45using TreeIter = tree<Node>::pre_order_iterator;
46// template<typename T> struct Program;
47// enum class ProgramType: uint32_t;
48// template<typename T> struct ProgramTypeEnum;
49
50vector<Node> generate_terminals(const Dataset& d, const bool weights_init);
51
53
54extern std::unordered_map<std::size_t, std::string> ArgsName;
55
84{
85 using ArgsHash = std::size_t;
86
87 template<typename T>
88 using Map = unordered_map<DataType, // return type
89 unordered_map<ArgsHash, // hash of arg types
90 unordered_map<NodeType, // node type
91 T>>>; // the data!
92
101
104
113 unordered_map<DataType, vector<Node>> terminal_map;
114
116 unordered_map<DataType, vector<float>> terminal_weights;
117
119 vector<DataType> terminal_types;
120
121 // serialization
122#ifndef DOXYGEN_SKIP
123
125 node_map,
130 )
131
132#endif
133
148 template<typename PT>
149 PT make_program(const Parameters& params, int max_d=0, int max_size=0);
150
155 RegressorProgram make_regressor(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
156
161 ClassifierProgram make_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
162
167 MulticlassClassifierProgram make_multiclass_classifier(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
168
173 RepresenterProgram make_representer(int max_d = 0, int max_size = 0, const Parameters& params=Parameters());
174
175 SearchSpace() = default;
176
181 SearchSpace(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true){
183 }
184
189 void init(const Dataset& d, const unordered_map<string,float>& user_ops = {}, bool weights_init = true);
190
194 bool check(DataType R) const {
195 if (node_map.find(R) == node_map.end()){
196 auto msg = fmt::format("{} not in node_map\n",R);
198 }
199 return true;
200 }
201
206 bool check(DataType R, size_t sig_hash) const
207 {
208 if (check(R)){
209 if (node_map.at(R).find(sig_hash) == node_map.at(R).end()){
210 auto msg = fmt::format("{} not in node_map.at({})\n", sig_hash, R);
212 }
213 }
214 return true;
215 }
216
222 bool check(DataType R, size_t sig_hash, NodeType type) const
223 {
224 if (check(R,sig_hash)){
225 if (node_map.at(R).at(sig_hash).find(type) == node_map.at(R).at(sig_hash).end()){
226
227 auto msg = fmt::format("{} not in node_map[{}][{}]\n",type, sig_hash, R);
229 }}
230 return true;
231 }
232
240 template<typename Iter>
242 return !std::all_of(start, end, [](const auto& w) { return w<=0.0; });
243 }
244
245 template<typename F> Node get(const string& name);
246
252 Node get(NodeType type, DataType R, size_t sig_hash)
253 {
254 check(R, sig_hash, type);
255 return node_map.at(R).at(sig_hash).at(type);
256 };
257
264 template<typename S>
265 Node get(NodeType type, DataType R, S sig){ return get(type, R, sig.hash()); };
266
269 vector<float> get_weights() const
270 {
271 vector<float> v;
272 for (auto& [ret, arg_w_map]: node_map_weights)
273 {
274 v.push_back(0);
275 for (const auto& [arg, name_map] : arg_w_map)
276 {
277 for (const auto& [name, w]: name_map)
278 {
279 v.back() += w;
280 }
281 }
282 }
283 return v;
284 };
285
289 vector<float> get_weights(DataType ret) const
290 {
291 vector<float> v;
292 for (const auto& [arg, name_map] : node_map_weights.at(ret))
293 {
294 v.push_back(0);
295 for (const auto& [name, w]: name_map)
296 {
297 v.back() += w;
298 }
299
300 }
301 return v;
302 };
303
308 vector<float> get_weights(DataType ret, ArgsHash sig_hash) const
309 {
310 vector<float> v;
311 for (const auto& [name, w]: node_map_weights.at(ret).at(sig_hash))
312 v.push_back(w);
313
314 return v;
315 };
316
319 std::optional<Node> sample_terminal(bool force_return=false) const
320 {
321 //TODO: match terminal args_type (probably '{}' or something?)
322 // make a separate terminal_map
323
324 // We'll make terminal types to have its weights proportional to the
325 // DataTypes Weights they hold
326 vector<float> data_type_weights(terminal_weights.size());
327 if (force_return)
328 {
329 std::fill(data_type_weights.begin(), data_type_weights.end(), 1.0f);
330 }
331 else
332 {
333 std::transform(
334 terminal_weights.begin(),
335 terminal_weights.end(),
336 data_type_weights.begin(),
337 [](const auto& tw){
338 return std::reduce(tw.second.begin(), tw.second.end()); }
339 );
340
342 data_type_weights.end()))
343 return std::nullopt;
344 }
345
346 // If we got this far, then it is garanteed that we'll return something
347 // The match take into account datatypes with non-zero weights
348 auto match = *r.select_randomly(
349 terminal_map.begin(),
350 terminal_map.end(),
351 data_type_weights.begin(),
353 );
354
355 // theres always a constant of each data type
356 vector<float> match_weights(match.second.size());
357 if (force_return)
358 {
359 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
360 }
361 else
362 {
363 std::transform(
364 terminal_weights.at(match.first).begin(),
365 terminal_weights.at(match.first).end(),
366 match_weights.begin(),
367 [](const auto& w){ return w; });
368
369 if (!has_solution_space(match_weights.begin(),
370 match_weights.end()))
371 return std::nullopt;
372 }
373
374 return *r.select_randomly(match.second.begin(), match.second.end(),
375 match_weights.begin(), match_weights.end());
376 };
377
380 std::optional<Node> sample_terminal(DataType R, bool force_return=false) const
381 {
382 // should I keep doing this check?
383 // if (terminal_map.find(R) == terminal_map.end()){
384 // auto msg = fmt::format("{} not in terminal_map\n",R);
385 // HANDLE_ERROR_THROW(msg);
386 // }
387
388 // If there's at least one constant for every data type, its always possible to force sample_terminal to return something
389
390 // TODO: try to combine with above function
391 vector<float> match_weights(terminal_weights.at(R).size());
392 if (force_return)
393 {
394 std::fill(match_weights.begin(), match_weights.end(), 1.0f);
395 }
396 else
397 {
398 std::transform(
399 terminal_weights.at(R).begin(),
400 terminal_weights.at(R).end(),
401 match_weights.begin(),
402 [](const auto& w){ return w; }
403 );
404
405 if ( (terminal_map.find(R) == terminal_map.end())
406 || (!has_solution_space(match_weights.begin(),
407 match_weights.end())) )
408 return std::nullopt;
409 }
410
411 return *r.select_randomly(terminal_map.at(R).begin(),
412 terminal_map.at(R).end(),
413 match_weights.begin(),
414 match_weights.end());
415 };
416
420 std::optional<Node> sample_op(DataType ret) const
421 {
422 // check(ret);
423 if (node_map.find(ret) == node_map.end())
424 return std::nullopt;
425
426 //TODO: match terminal args_type (probably '{}' or something?)
427 auto ret_match = node_map.at(ret);
428
429 vector<float> args_w = get_weights(ret);
430
431 if (!has_solution_space(args_w.begin(), args_w.end()))
432 return std::nullopt;
433
434 auto arg_match = *r.select_randomly(ret_match.begin(),
435 ret_match.end(),
436 args_w.begin(),
437 args_w.end());
438
439 vector<float> name_w = get_weights(ret, arg_match.first);
440
441 if (!has_solution_space(name_w.begin(), name_w.end()))
442 return std::nullopt;
443
444 return (*r.select_randomly(arg_match.second.begin(),
445 arg_match.second.end(),
446 name_w.begin(),
447 name_w.end())).second;
448 };
449
454 std::optional<Node> sample_op(NodeType type, DataType R)
455 {
456 // check(R);
457 if (node_map.find(R) == node_map.end())
458 return std::nullopt;
459
460 auto ret_match = node_map.at(R);
461
462 vector<Node> matches;
463 vector<float> weights;
464 for (const auto& kv: ret_match)
465 {
466 auto arg_hash = kv.first;
467 auto node_type_map = kv.second;
468 if (node_type_map.find(type) != node_type_map.end())
469 {
470 matches.push_back(node_type_map.at(type));
471 weights.push_back(node_map_weights.at(R).at(arg_hash).at(type));
472 }
473 }
474
475 if ( (weights.size()==0)
476 || (!has_solution_space(weights.begin(),
477 weights.end())) )
478 return std::nullopt;
479
480 return (*r.select_randomly(matches.begin(),
481 matches.end(),
482 weights.begin(),
483 weights.end()));
484 };
485
493 bool terminal_compatible=true,
494 int max_args=0) const
495 {
496 // thoughts (TODO):
497 // this could be templated by return type and arg. although the lookup in the map should be
498 // fairly fast.
499 //TODO: these needs to be overhauled
500 // fmt::print("sample_op_with_arg");
501 check(ret);
502
503 auto args_map = node_map.at(ret);
504 vector<Node> matches;
505 vector<float> weights;
506
507 for (const auto& [args_type, name_map]: args_map) {
508 for (const auto& [name, node]: name_map) {
509 auto node_arg_types = node.get_arg_types();
510
511 // has no size limit (max_arg_count==0) or the number of
512 // arguments woudn't exceed the maximum number of arguments
513 auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);
514
516 // if checking terminal compatibility, make sure there's
517 // a compatible terminal for the node's other arguments
519 bool compatible = true;
520 for (const auto& arg_type: node_arg_types) {
521 if (arg_type != arg) {
522 if ( ! in(terminal_types, arg_type) ) {
523 compatible = false;
524 break;
525 }
526 }
527 }
528 if (! compatible)
529 continue;
530 }
531 // if we made it this far, include the node as a match!
532 matches.push_back(node);
533 weights.push_back(node_map_weights.at(ret).at(args_type).at(name));
534 }
535 }
536 }
537
538 if ( (weights.size()==0)
539 || (!has_solution_space(weights.begin(),
540 weights.end())) )
541 return std::nullopt;
542
543 return (*r.select_randomly(matches.begin(), matches.end(),
544 weights.begin(), weights.end()));
545 };
546
550 std::optional<Node> get_node_like(Node node) const
551 {
553 return sample_terminal(node.ret_type);
554 }
555
556 auto matches = node_map.at(node.ret_type).at(node.args_type());
557 auto match_weights = get_weights(node.ret_type, node.args_type());
558
559 if ( (match_weights.size()==0)
560 || (!has_solution_space(match_weights.begin(),
561 match_weights.end())) )
562 return std::nullopt;
563
564 return (*r.select_randomly(matches.begin(),
565 matches.end(),
566 match_weights.begin(),
567 match_weights.end())
568 ).second;
569 };
570
576 std::optional<tree<Node>> sample_subtree(Node root, int max_d, int max_size) const;
577
579 void print() const;
580
581 private:
582 tree<Node>& PTC2(tree<Node>& Tree, tree<Node>::iterator root, int max_d, int max_size) const;
583
584 template<NodeType NT, typename S>
586 static constexpr std::optional<Node> CreateNode(
587 const auto& unique_data_types,
588 bool use_all,
589 bool weighted
590 )
591 {
592 // prune the operators out that don't have argument types that
593 // overlap with feature data types
594 for (auto arg: S::get_arg_types()){
595 if (! in(unique_data_types,arg) ){
596 return {};
597 }
598 }
599 ArgsName[S::hash()] = fmt::format("{}", S::get_arg_types());
600 return Node(NT, S{}, weighted);
601 }
602
603 template<NodeType NT, typename S>
604 constexpr void AddNode(
605 const unordered_map<string,float>& user_ops,
606 const vector<DataType>& unique_data_types
607 )
608 {
609 bool use_all = user_ops.size() == 0;
610 auto name = NodeTypeName[NT];
611
612 bool weighted = false;
613 if (Is<NodeType::OffsetSum>(NT)) // this has to have weights on by default
614 weighted = true;
615
616 auto n_maybe = CreateNode<NT,S>(unique_data_types, use_all, weighted);
617
618 if (n_maybe){
619 auto n = n_maybe.value();
620 node_map[n.ret_type][n.args_type()][n.node_type] = n;
621 // sampling probability map
622 float w = use_all? 1.0 : user_ops.at(name);
623 node_map_weights[n.ret_type][n.args_type()][n.node_type] = w;
624 }
625 }
626
627 template <NodeType NT, typename Sigs, std::size_t... Is>
628 constexpr void AddNodes(const unordered_map<string, float> &user_ops,
629 const vector<DataType> &unique_data_types,
630 std::index_sequence<Is...>)
631 {
632 (AddNode<NT, std::tuple_element_t<Is, Sigs>>(user_ops, unique_data_types), ...);
633 }
634
635 template<NodeType NT>
636 void MakeNodes(const unordered_map<string,float>& user_ops,
637 const vector<DataType>& unique_data_types
638 )
639 {
641 return;
642 bool use_all = user_ops.size() == 0;
643 auto name = NodeTypeName.at(NT);
644
645 // skip operators not defined by user
646 if (!use_all & user_ops.find(name) == user_ops.end())
647 return;
648
650 constexpr auto size = std::tuple_size<signatures>::value;
652 user_ops,
653 unique_data_types,
654 std::make_index_sequence<size>()
655 );
656 }
657
658 template<std::size_t... Is>
659 void GenerateNodeMap(const unordered_map<string,float>& user_ops,
660 const vector<DataType>& unique_data_types,
661 std::index_sequence<Is...>
662 )
663 {
664 auto nt = [](auto i) { return static_cast<NodeType>(1UL << i); };
665 (MakeNodes<nt(Is)>(user_ops, unique_data_types), ...);
666 }
667}; // SearchSpace
668
670template<typename T>
671T RandomDequeue(std::vector<T>& Q)
672{
673 int loc = r.rnd_int(0, Q.size()-1);
674 std::swap(Q[loc], Q[Q.size()-1]);
675 T val = Q.back();
676 Q.pop_back();
677 return val;
678};
679
680template<typename P>
681P SearchSpace::make_program(const Parameters& params, int max_d, int max_size)
682{
683 // this is what makes `make_program` create uniformly distributed
684 // individuals to feed initial population
685 if (max_d < 1)
686 max_d = r.rnd_int(1, params.max_depth);
687 if (max_size < 1)
688 max_size = r.rnd_int(1, params.max_size);
689
691 ProgramType program_type = P::program_type;
692 // ProgramType program_type = ProgramTypeEnum<PT>::value;
693
694 // Tree is pre-filled with some fixed nodes depending on program type
695 auto Tree = tree<Node>();
696
697 // building the tree for each program case. Then, we give the spot to PTC2,
698 // and it will fill the rest of the tree
699 tree<Node>::iterator spot;
700
701 // building the root node for each program case
702 if (P::program_type == ProgramType::BinaryClassifier)
703 {
705 node_logit.set_prob_change(0.0);
706 node_logit.fixed=true;
707 auto spot_logit = Tree.insert(Tree.begin(), node_logit);
708
709 if (true) { // Logistic(Add(Constant, <>)).
711 node_offset.set_prob_change(0.0);
712 node_offset.fixed=true;
713
714 auto spot_offset = Tree.append_child(spot_logit);
715
716 spot = Tree.replace(spot_offset, node_offset);
717 }
718 else { // If false, then model will be Logistic(<>)
720 }
721 }
722 else if (P::program_type == ProgramType::MulticlassClassifier)
723 {
725 node_softmax.set_prob_change(0.0);
726 node_softmax.fixed=true;
727
728 spot = Tree.insert(Tree.begin(), node_softmax);
729 }
730 else // regression or representer --- sampling any candidate op or terminal
731 {
732 Node root;
733
734 std::optional<Node> opt=std::nullopt;
735
736 if (max_size>1 && max_d>1)
738
739 if (!opt) // if failed, then we dont have any operator to use as root...
741
742 root = opt.value();
743
744 spot = Tree.insert(Tree.begin(), root);
745 }
746
747 // max_d-1 because we always pick the root before calling ptc2
748 PTC2(Tree, spot, max_d-1, max_size); // change inplace
749
750 return P(*this, Tree);
751};
752
753extern SearchSpace SS;
754
755} // Brush
756
757// format overload
758template <> struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
759 template <typename FormatContext>
761 string output = "Search Space\n===\n";
762 output += fmt::format("terminal_map: {}\n", SS.terminal_map);
763 output += fmt::format("terminal_weights: {}\n", SS.terminal_weights);
764 for (const auto& [ret_type, v] : SS.node_map) {
765 for (const auto& [args_type, v2] : v) {
766 for (const auto& [node_type, node] : v2) {
767 output += fmt::format("node_map[{}][{}][{}] = {}, weight = {}\n",
768 ret_type,
769 ArgsName[args_type],
770 node_type,
771 node,
772 SS.node_map_weights.at(ret_type).at(args_type).at(node_type)
773 );
774 }
775 }
776 }
777 output += "===";
778 return formatter<string_view>::format(output, ctx);
779 }
780};
781#endif
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
Iter select_randomly(Iter start, Iter end)
Definition rnd.h:61
int rnd_int(int lowerLimit, int upperLimit)
Definition rnd.cpp:68
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
namespace containing Data structures used in Brush
Definition data.cpp:49
namespace containing various utility functions
Definition error.cpp:11
bool in(const V &v, const T &i)
check if element is in vector.
Definition utils.h:192
static Rnd & r
Definition rnd.h:174
< nsga2 selection operator for getting the front
Definition data.cpp:12
NodeType
Definition nodetype.h:31
DataType
data types.
Definition types.h:143
auto Is(NodeType nt) -> bool
Definition node.h:260
std::unordered_map< std::size_t, std::string > ArgsName
vector< Node > generate_terminals(const Dataset &d, const bool weights_init)
generate terminals from the dataset features and random constants.
T RandomDequeue(std::vector< T > &Q)
queue for make program
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::pre_order_iterator TreeIter
NodeType NT
Definition node.cpp:129
std::map< NodeType, std::string > NodeTypeName
Definition nodetype.cpp:81
SearchSpace SS
ProgramType
Definition types.h:70
class holding the data for a node in a tree.
Definition node.h:84
unsigned int max_depth
Definition params.h:35
unsigned int max_size
Definition params.h:36
An individual program, a.k.a. model.
Definition program.h:50
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
unordered_map< DataType, unordered_map< ArgsHash, unordered_map< NodeType, T > > > Map
void print() const
prints the search space map.
static constexpr std::optional< Node > CreateNode(const auto &unique_data_types, bool use_all, bool weighted)
Node get(const string &name)
constexpr void AddNode(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
Map< Node > node_map
Maps return types to argument types to node types.
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.
constexpr void AddNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
std::optional< Node > sample_terminal(DataType R, bool force_return=false) const
Get a random terminal with return type R
void init(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Called by the constructor to initialize the search space.
RegressorProgram make_regressor(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random regressor program. Convenience wrapper for make_program.
bool check(DataType R, size_t sig_hash, NodeType type) const
check if a typed Node is in the search space
Node get(NodeType type, DataType R, S sig)
get a typed node.
bool check(DataType R) const
check if a return type is in the node map
unordered_map< DataType, vector< float > > terminal_weights
A map of weights corresponding to elements in terminal_map, used to weight probabilities of each term...
std::optional< Node > sample_op(NodeType type, DataType R)
Get a specific node type that matches a return value.
vector< float > get_weights() const
get weights of the return types
Node get(NodeType type, DataType R, size_t sig_hash)
get a typed node
vector< float > get_weights(DataType ret) const
get weights of the argument types matching return type ret.
bool check(DataType R, size_t sig_hash) const
check if a function signature is in the search space
void GenerateNodeMap(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types, std::index_sequence< Is... >)
tree< Node > & PTC2(tree< Node > &Tree, tree< Node >::iterator root, int max_d, int max_size) const
std::optional< Node > sample_op(DataType ret) const
get an operator matching return type ret.
SearchSpace(const Dataset &d, const unordered_map< string, float > &user_ops={}, bool weights_init=true)
Construct a search space.
std::optional< Node > get_node_like(Node node) const
get a node with a signature matching node
vector< float > get_weights(DataType ret, ArgsHash sig_hash) const
get the weights of nodes matching a signature.
RepresenterProgram make_representer(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random representer program. Convenience wrapper for make_program.
std::optional< Node > sample_op_with_arg(DataType ret, DataType arg, bool terminal_compatible=true, int max_args=0) const
get operator with at least one argument matching arg
Map< float > node_map_weights
A map of weights corresponding to elements in node_map, used to weight probabilities of each node bei...
MulticlassClassifierProgram make_multiclass_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random multiclass classifier program. Convenience wrapper for make_program.
std::optional< tree< Node > > sample_subtree(Node root, int max_d, int max_size) const
create a subtree with maximum size and depth restrictions and root of type root_type
std::size_t ArgsHash
PT make_program(const Parameters &params, int max_d=0, int max_size=0)
Makes a random program.
vector< DataType > terminal_types
A vector storing the available return types of terminals.
bool has_solution_space(Iter start, Iter end) const
Takes iterators to weight vectors and checks if they have a non-empty solution space....
void MakeNodes(const unordered_map< string, float > &user_ops, const vector< DataType > &unique_data_types)
SearchSpace()=default
ClassifierProgram make_classifier(int max_d=0, int max_size=0, const Parameters &params=Parameters())
Makes a random classifier program. Convenience wrapper for make_program.
std::optional< Node > sample_terminal(bool force_return=false) const
Get a random terminal.
auto format(const Brush::SearchSpace &SS, FormatContext &ctx) const