Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
constants.h
Go to the documentation of this file.
1#ifndef CONSTANTS_H
2#define CONSTANTS_H
3
4#include "../init.h"
5#include "../types.h"
8#include "../util/utils.h"
9
10using namespace std;
11using Brush::Node;
12using Brush::DataType;
13
14namespace Brush { namespace Simpl{
15
17 {
18 public:
19 // static Constants_simplifier* initSimplifier();
20
21 // static void destroy();
22
23 template <ProgramType P>
25 Program<P>& program, const SearchSpace &ss, const Dataset &d)
26 {
27 using RetType =
28 typename std::conditional_t<P == PT::Regressor, ArrayXf,
29 std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
30 >>;
31
32 // create a copy of the tree
33 Program<P> simplified_program(program);
34
35 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
36 TreeIter spot = simplified_program.Tree.begin();
37 while(spot != simplified_program.Tree.end())
38 {
39 Node n = spot.node->data;
40
41 // This is avoiding using booleans.
42 // non-weighable nodes are not simplified. TODO: revisit this and see if they should be (then implement it)
44 && n.get_prob_change()>0
46 )
47 {
48 // TODO: check if holds alternative and use this information, instead of making it templated. Also, return void.
49 // get new_pred with predictions after simplification
50 VectorXf branch_pred;
52 {
53 RetType pred = (*spot.node).predict<RetType>(d);
54 branch_pred = pred.template cast<float>();
55 }
56 else if constexpr (P==ProgramType::MulticlassClassifier)
57 {
58 ArrayXXf out = (*spot.node).template predict<ArrayXXf>(d);
59 auto argmax = Function<NodeType::ArgMax>{};
60 branch_pred = ArrayXf(argmax(out).template cast<float>());
61 }
62 else
63 {
64 HANDLE_ERROR_THROW("No predict available for the class.");
65 }
66
67 if (variance(branch_pred) < 1e-4) // TODO: calculate threshold based on data
68 {
69 // get constant equivalent to its argtype (all data types should have
70 // a constant defined in the search space for its given type). It will be
71 // the last node of the terminal map for the given type
72 Node cte = ss.terminal_map.at(n.ret_type).at(
73 ss.terminal_map.at(n.ret_type).size()-1);
74
75 cte.W = branch_pred.mean();
76 simplified_program.Tree.erase_children(spot);
77 spot = simplified_program.Tree.replace(spot, cte);
78 }
79 }
80 ++spot;
81 }
82 program.Tree = simplified_program.Tree;
83 return simplified_program;
84 }
85
88 private:
89
90 // private static attribute used by every instance of the class
91 // static Constants_simplifier* instance;
92 };
93
94 // TODO: get rid of static reference
95 // static attribute holding a singleton instance of Constants_simplifier.
96 // the instance is created by calling `initSimplifier`, which creates
97 // an instance of the private static attribute `instance`. `r` will contain
98 // one generator for each thread (since it called the constructor)
99 // static Constants_simplifier &constants_simplifier = *Constants_simplifier::initSimplifier();
100
101} // Simpl
102} // Brush
103
104#endif
holds variable type data.
Definition data.h:51
Program< P > simplify_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition constants.h:24
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
float variance(const ArrayXf &v)
calculate variance
Definition utils.cpp:317
nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
auto IsWeighable() noexcept -> bool
Definition node.h:46
DataType
data types.
Definition types.h:143
tree< Node >::pre_order_iterator TreeIter
STL namespace.
class holding the data for a node in a tree.
Definition node.h:84
float get_prob_change() const
Definition node.h:245
NodeType node_type
the node type
Definition node.h:95
DataType ret_type
return data type
Definition node.h:101
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:107
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.