Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
inexact.h
Go to the documentation of this file.
1// simplification maps are based on training data
2// should ignore fixed nodes: does not change subtrees if they contain fixed nodes (this should also be applied to the constants simplification)
3// number of tables and max_samples_to_use are parameters; set default values large enough to make it work off-the-shelf
4// should I implement JSON serialization?
5// maybe this function should be templated. It should handle constant predictions
6
7
8#ifndef INEXACT_H
9#define INEXACT_H
10
11#include "../init.h"
12#include "../types.h"
13#include "../program/program.h"
15#include "../util/utils.h"
16
17using namespace std;
18using Brush::Node;
19using Brush::DataType;
20
21namespace Brush { namespace Simpl{
22
24public:
27
28 void append(const string& key, const ArrayXf& value) {
29 if (storage.find(key) == storage.end())
30 storage[key] = vector<ArrayXf>{value};
31 else
32 storage[key].push_back(value);
33 }
34
35 vector<ArrayXf> getList(const string& key) {
36 auto it = storage.find(key);
37 if (it != storage.end())
38 return it->second;
39
40 static vector<ArrayXf> empty_list;
41 return empty_list;
42 }
43
44 void clear() {
45 storage.clear();
46 }
47
48 vector<string> keys() {
49 vector<string> result;
50 for (const auto& pair : storage) {
51 result.push_back(pair.first);
52 }
53 return result;
54 }
55
56private:
57 map<string, vector<ArrayXf>> storage;
58};
59
61{
62 public:
63 // static Inexact_simplifier* initSimplifier();
64 void initUniformPlanes(int hashSize, int inputDim, int numPlanes); // NOTE(review): defined in inexact.cpp; presumably (re)builds `uniformPlanes`, which hash() uses — confirm
65
66 // static void destroy();
67
68 // TODO: move templated stuff to the cpp. Right now it lives in the header
69 // because templated definitions in the cpp did not work with the tests.
70 // This is happening in inexact, constants, variation.
71 template<Brush::ProgramType PT> // Returns a copy of `program` where subtrees are replaced by smaller hashed-equivalent subtrees already recorded in `equivalentExpression`; the result is also written back into `program`.
73 const SearchSpace &ss, const Dataset &d) // NOTE(review): `ss` does not appear to be used in this body — confirm
74 {
75 Program<PT> simplified_program(program); // work on a copy; the caller's program is only updated at the end
76
77 // Walk the tree (TreeIter is a pre-order iterator). For each eligible node: index it, query for a hashed-equivalent subtree, then either record this subtree as the smallest known representative or replace it with a smaller recorded one.
78 TreeIter spot = simplified_program.Tree.begin();
79 while(spot != simplified_program.Tree.end())
80 {
81 // We don't index or simplify fixed stuff (only nodes with get_prob_change() > 0 are considered).
82 // Non-weightable nodes are not simplified. TODO: revisit this and see if they should be (then implement it). NOTE(review): the IsWeighable filter below is commented out, so this restriction is not currently enforced.
83 // This is avoiding using booleans.
84 if (spot.node->data.get_prob_change() > 0
85 // && IsWeighable(spot.node->data.ret_type) && IsWeighable(spot.node->data.node_type)
86 ) {
87 // Index small subtrees or non-constant-terminal nodes. NOTE(review): the two conditions are OR-ed, so every node that is not Constant/MeanLabel is indexed regardless of its subtree size — confirm this is intended.
88 if (simplified_program.size_at(spot) < 10
89 || Isnt<NodeType::Constant, NodeType::MeanLabel>(spot.node->data.node_type)) {
90 index(spot, d);
91 }
92
94 // res will return the closest match within the threshold, so we don't have to check distance here
95 auto res = query(spot, d); // optional<pair<size_t, string>>
96
97 if (res){
98 auto key = res.value(); // table index and hash
99 const tree<Node> branch(spot); // presumably a copy of the subtree rooted at spot (tree.hh iterator constructor)
100
101 if (equivalentExpression.find(key) == equivalentExpression.end()) { // first time this key is seen: record this subtree
102 equivalentExpression[key] = branch;
103 } else if (spot.node->get_size(false) < equivalentExpression[key].begin().node->get_size(false)){ // this subtree is smaller: it becomes the representative
104 equivalentExpression[key] = branch;
105 } else if (spot.node->get_size(false) > equivalentExpression[key].begin().node->get_size(false)){ // stored representative is smaller: replace this subtree with it (equal sizes leave both untouched)
106 const tree<Node> simplified_branch(equivalentExpression[key]);
107 simplified_program.Tree.erase_children(spot);
108 spot = simplified_program.Tree.move_ontop(spot, simplified_branch.begin()); // NOTE(review): the source iterator belongs to the local copy `simplified_branch`, not to simplified_program.Tree — confirm move_ontop supports moving nodes across trees
109 }
110 }
111 }
112 }
113 ++spot; // pre-order advance; after a replacement this presumably continues into the inserted subtree
114 }
115 program.Tree = simplified_program.Tree; // side effect: the caller's program (taken by reference) is overwritten as well
116
117 return simplified_program;
118 }
121
122 // TODO: make index private, and make initUniformPlanes add terminals from the search space
123 void index(TreeIter& spot, const Dataset &d); // defined in inexact.cpp; simplify_tree calls it on each eligible node before query()
124 private:
125
126 vector<string> hash(const ArrayXf& inputPoint); // one string for each plane
127
128
129 // returns the (table index, hash) key of the closest stored entry within the threshold, or an empty optional when there is none
130 optional<pair<size_t, string>> query(TreeIter& spot, const Dataset &d);
131
132 // one storage instance for each datatype/rettype.
133 // the storage will be used to calculate the hash and query the
134 // collection of hashes, returning the closest ones,
135 // and the list will contain equivalent expressions, ordered by size
136 // (or linear complexity). So we don't store pairs in the storage.
137 // TODO: improve how I handle different return types (should I use a map?)
141
142 map<pair<size_t, string>, tree<Node>> equivalentExpression; // (table index, hash) -> smallest equivalent subtree recorded so far by simplify_tree
143
144 vector<MatrixXf> uniformPlanes; // NOTE(review): presumably the random hyperplanes used by hash(); confirm in inexact.cpp
145
146 // (disabled) private static singleton pointer that would be shared by every instance of the class
147 // static Inexact_simplifier* instance;
148};
149
150// static Inexact_simplifier &inexact_simplifier = *Inexact_simplifier::initSimplifier();
151
152} // Simpl
153} // Brush
154
155#endif
holds variable type data.
Definition data.h:51
vector< ArrayXf > getList(const string &key)
Definition inexact.h:35
vector< string > keys()
Definition inexact.h:48
void append(const string &key, const ArrayXf &value)
Definition inexact.h:28
map< string, vector< ArrayXf > > storage
Definition inexact.h:57
Program< PT > simplify_tree(Program< PT > &program, const SearchSpace &ss, const Dataset &d)
Definition inexact.h:72
void initUniformPlanes(int hashSize, int inputDim, int numPlanes)
Definition inexact.cpp:28
vector< MatrixXf > uniformPlanes
Definition inexact.h:144
optional< pair< size_t, string > > query(TreeIter &spot, const Dataset &d)
Definition inexact.cpp:109
void index(TreeIter &spot, const Dataset &d)
Definition inexact.cpp:73
vector< string > hash(const ArrayXf &inputPoint)
Definition inexact.cpp:44
map< pair< size_t, string >, tree< Node > > equivalentExpression
Definition inexact.h:142
nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
DataType
data types.
Definition types.h:143
tree< Node >::pre_order_iterator TreeIter
STL namespace.
class holding the data for a node in a tree.
Definition node.h:84
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...