Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
constants.h
Go to the documentation of this file.
1#ifndef CONSTANTS_H
2#define CONSTANTS_H
3
4#include "../init.h"
5#include "../types.h"
8#include "../util/utils.h"
9
10using namespace std;
11using Brush::Node;
12using Brush::DataType;
13
14namespace Brush { namespace Simpl{
16 {
17 public:
18 template <ProgramType P>
20 Program<P>& program, const SearchSpace &ss, const Dataset &d)
21 {
22 using RetType =
23 typename std::conditional_t<P == PT::Regressor, ArrayXf,
24 std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
25 >>;
26
27 // create a copy of the tree
28 Program<P> simplified_program(program);
29
30 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
31 TreeIter spot = simplified_program.Tree.begin();
32 while(spot != simplified_program.Tree.end())
33 {
34 Node n = spot.node->data;
35
36 // This is avoiding using booleans.
37 // non-weighable nodes are not simplified. TODO: revisit this and see if they should be (then implement it)
39 && n.get_prob_change()>0
41 )
42 {
43 // TODO: check if holds alternative and use this information, instead of making it templated. Also, return void.
44 // get new_pred with predictions after simplification
45 VectorXf branch_pred;
47 {
48 RetType pred = (*spot.node).predict<RetType>(d);
49 branch_pred = pred.template cast<float>();
50 }
51 else if constexpr (P==ProgramType::MulticlassClassifier)
52 {
53 ArrayXXf out = (*spot.node).template predict<ArrayXXf>(d);
54 auto argmax = Function<NodeType::ArgMax>{};
55 branch_pred = ArrayXf(argmax(out).template cast<float>());
56 }
57 else
58 {
59 HANDLE_ERROR_THROW("No predict available for the class.");
60 }
61
62 if (variance(branch_pred) < 1e-5) // TODO: calculate threshold based on data
63 {
64 // get the constant terminal matching this node's return type (every data
65 // type should have a constant defined in the search space). It is expected
66 // to be the last node of the terminal map for that type
67 Node cte = ss.terminal_map.at(n.ret_type).at(
68 ss.terminal_map.at(n.ret_type).size()-1);
69
70 cte.W = branch_pred.mean();
71 simplified_program.Tree.erase_children(spot);
72 spot = simplified_program.Tree.replace(spot, cte);
73 }
74 }
75 ++spot;
76 }
77 program.Tree = simplified_program.Tree;
78 return simplified_program;
79 }
80
83 private:
84
85 };
86} // Simpl
87} // Brush
88
89#endif
holds variable type data.
Definition data.h:51
Program< P > simplify_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition constants.h:19
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
float variance(const ArrayXf &v)
calculate variance
Definition utils.cpp:317
nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
auto IsWeighable() noexcept -> bool
Definition node.h:46
DataType
data types.
Definition types.h:143
tree< Node >::pre_order_iterator TreeIter
STL namespace.
class holding the data for a node in a tree.
Definition node.h:84
float get_prob_change() const
Definition node.h:254
NodeType node_type
the node type
Definition node.h:89
DataType ret_type
return data type
Definition node.h:92
float W
the weights of the node. also used for splitting thresholds.
Definition node.h:107
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...
unordered_map< DataType, vector< Node > > terminal_map
Maps return types to terminals.