Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
constants.cpp
Go to the documentation of this file.
1
2#include "constants.h"
3
4namespace Brush { namespace Simpl{
5
6 // Constants_simplifier* Constants_simplifier::instance = NULL;
7
11
12 // Constants_simplifier* Constants_simplifier::initSimplifier()
13 // {
14 // // creates the static simplifier instance by calling the constructor
15 // if (!instance)
16 // {
17 // instance = new Constants_simplifier();
18 // }
19
20 // return instance;
21 // }
22
23 // template <ProgramType PT>
24 // Program<PT> Constants_simplifier::simplify_tree(
25 // Program<PT>& program, const SearchSpace &ss, const Dataset &d)
26 // {
27 // // create a copy of the tree
28 // Program<PT> simplified_program(program);
29
30 // // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
31 // TreeIter spot = simplified_program.Tree.begin();
32 // while(spot != simplified_program.Tree.end())
33 // {
34 // Node n = spot.node->data;
35
36 // if (Isnt<NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>(n.node_type)
37 // && n.get_prob_change()>0)
38 // {
39 // // get new_pred with predictions after simplification
40 // VectorXf branch_pred;
41 // if constexpr (PT==ProgramType::Regressor || PT==ProgramType::BinaryClassifier)
42 // {
43 // branch_pred = (*spot.node).template predict<ArrayXf>(d);
44 // }
45 // else if constexpr (PT==ProgramType::MulticlassClassifier)
46 // {
47 // ArrayXXf out = (*spot.node).template predict<ArrayXXf>(d);
48 // auto argmax = Function<NodeType::ArgMax>{};
49 // branch_pred = ArrayXf(argmax(out).template cast<float>());
50 // }
51 // else
52 // {
53 // HANDLE_ERROR_THROW("No predict available for the class.");
54 // }
55
56 // if (variance(branch_pred) < 1e-4) // TODO: calculate threshold based on data
57 // {
58 // // get constant equivalent to its return type (all data types should have
59 // // a constant defined in the search space for its given type). It will be
60 // // the last node of the terminal map for the given type
61 // Node cte = ss.terminal_map.at(n.ret_type).at(
62 // ss.terminal_map.at(n.ret_type).size()-1);
63
64 // cte.W = branch_pred.mean();
65 // simplified_program.Tree.erase_children(spot);
66 // spot = simplified_program.Tree.replace(spot, cte);
67 // }
68 // }
69 // ++spot;
70 // }
71 // program.Tree = simplified_program.Tree;
72
73 // return simplified_program;
74 // }
75
76 // void Constants_simplifier::destroy()
77 // {
78 // if (instance)
79 // delete instance;
80
81 // instance = NULL;
82 // }
83
85} }
< nsga2 selection operator for getting the front
Definition bandit.cpp:4