Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
constants.cpp
Go to the documentation of this file.
1
2
#include "
constants.h
"
3
4
namespace
Brush
{
namespace
Simpl
{
5
6
// Constants_simplifier* Constants_simplifier::instance = NULL;
7
8
// Default constructor for the constants simplifier.
// There is nothing to set up here: the body is intentionally empty, so the
// definition is simply defaulted out-of-line (it stays user-provided).
Constants_simplifier::Constants_simplifier() = default;
11
12
// Constants_simplifier* Constants_simplifier::initSimplifier()
13
// {
14
// // creates the static simplifier instance by calling the constructor
15
// if (!instance)
16
// {
17
// instance = new Constants_simplifier();
18
// }
19
20
// return instance;
21
// }
22
23
// template <ProgramType PT>
24
// Program<PT> Constants_simplifier::simplify_tree(
25
// Program<PT>& program, const SearchSpace &ss, const Dataset &d)
26
// {
27
// // create a copy of the tree
28
// Program<PT> simplified_program(program);
29
30
// // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
31
// TreeIter spot = simplified_program.Tree.begin();
32
// while(spot != simplified_program.Tree.end())
33
// {
34
// Node n = spot.node->data;
35
36
// if (Isnt<NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>(n.node_type)
37
// && n.get_prob_change()>0)
38
// {
39
// // get new_pred with predictions after simplification
40
// VectorXf branch_pred;
41
// if constexpr (PT==ProgramType::Regressor || PT==ProgramType::BinaryClassifier)
42
// {
43
// branch_pred = (*spot.node).template predict<ArrayXf>(d);
44
// }
45
// else if constexpr (PT==ProgramType::MulticlassClassifier)
46
// {
47
// ArrayXXf out = (*spot.node).template predict<ArrayXXf>(d);
48
// auto argmax = Function<NodeType::ArgMax>{};
49
// branch_pred = ArrayXf(argmax(out).template cast<float>());
50
// }
51
// else
52
// {
53
// HANDLE_ERROR_THROW("No predict available for the class.");
54
// }
55
56
// if (variance(branch_pred) < 1e-4) // TODO: calculate threshold based on data
57
// {
58
// // get constant equivalent to its argtype (all data types should have
59
// // a constant defined in the search space for its given type). It will be
60
// // the last node of the terminal map for the given type
61
// Node cte = ss.terminal_map.at(n.ret_type).at(
62
// ss.terminal_map.at(n.ret_type).size()-1);
63
64
// cte.W = branch_pred.mean();
65
// simplified_program.Tree.erase_children(spot);
66
// spot = simplified_program.Tree.replace(spot, cte);
67
// }
68
// }
69
// ++spot;
70
// }
71
// program.Tree = simplified_program.Tree;
72
73
// return simplified_program;
74
// }
75
76
// void Constants_simplifier::destroy()
77
// {
78
// if (instance)
79
// delete instance;
80
81
// instance = NULL;
82
// }
83
84
// Destructor for the constants simplifier.
// No explicit cleanup is performed in this translation unit, so the
// definition is defaulted out-of-line instead of spelling an empty body.
Constants_simplifier::~Constants_simplifier() = default;
85
} }
Brush::Simpl::Constants_simplifier::Constants_simplifier
Constants_simplifier()
Definition
constants.cpp:8
Brush::Simpl::Constants_simplifier::~Constants_simplifier
~Constants_simplifier()
Definition
constants.cpp:84
constants.h
Brush::Simpl
Definition
constants.cpp:4
Brush
nsga2 selection operator for getting the front
Definition
bandit.cpp:4
src
simplification
constants.cpp
Generated by
1.13.2