Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
inexact.h
Go to the documentation of this file.
1#ifndef INEXACT_H
2#define INEXACT_H
3
4#include "../init.h"
5#include "../types.h"
8#include "../util/utils.h"
9
10using namespace std;
11using Brush::Node;
12using Brush::DataType;
13
14namespace Brush { namespace Simpl{
15
17public:
18 HashStorage(int numPlanes=10) {
19 storage.clear();
20 storage.reserve(numPlanes);
21 for (int i = 0; i < numPlanes; ++i)
22 storage.push_back(map<size_t, vector<tree<Node>>>());
23 };
25
26 void append(const int& storage_n, const size_t& key, const tree<Brush::Node> Tree) {
27 // we initialize the list of equivalent vectors if it does not exist
28
29 if (storage[storage_n].find(key) == storage[storage_n].end())
30 storage[storage_n][key] = vector<tree<Node>>();
31
32 // if it is smaller we add to the front, otherwise we add to the back.
33 // This way we know the first element is the smallest one, and we don't care
34 // about the order of the rest of the elements.
35 auto& storage_it = storage[storage_n][key];
36
37 // calculating incoming tree's attributes to compare
38 size_t new_size = Tree.begin().node->get_size();
39
40 auto it = storage_it.begin();
41 for (; it != storage_it.end(); ++it) {
42 size_t curr_size = it->begin().node->get_size();
43
44 if (curr_size > new_size) {
45 // found insertion point, we dont need to look beyond this point
46 break;
47 } else if (curr_size == new_size) {
48 // Compare structure + contents
49 auto it1 = it->begin();
50 auto it2 = Tree.begin();
51
52 auto end1 = it->end();
53 auto end2 = Tree.end();
54
55 bool trees_equal = true;
56 for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
57 if (it1.node->data.get_node_hash(false)
58 != it2.node->data.get_node_hash(false) ){
59 trees_equal = false;
60 break;
61 }
62 }
63
64 if (trees_equal && it1 == end1 && it2 == end2) {
65 // both finished at same time and the look was not interrupted earlier.
66 // it means we already have the same exact tree (but maybe with a different coeff).
67 // lets pretend we inserted and just return
68 return;
69 }
70
71 // else keep scanning; insertion will be after last equal-size element
72 }
73 }
74
75 // Insert Tree in order by size, smallest first
76 storage_it.insert(it, Tree);
77 }
78
79 vector<tree<Node>> getList(const int& storage_n, const size_t& key) {
80 auto it = storage[storage_n].find(key);
81 if (it != storage[storage_n].end())
82 return it->second;
83
84 return {};
85 }
86
87 void clear() {
88 int numPlanes = storage.size();
89 storage.clear();
90 storage.resize(numPlanes);
91 }
92
93 vector<size_t> keys(const int& storage_n) {
94 vector<size_t> result;
95 for (const auto& pair : storage[storage_n])
96 result.push_back(pair.first);
97
98 return result;
99 }
100
101 void print(const string& prefix, std::ofstream& log) const {
102 for (size_t plane_idx = 0; plane_idx < storage.size(); ++plane_idx) {
103 for (const auto& kv : storage[plane_idx]) {
104 size_t key = kv.first;
105 const auto& trees = kv.second;
106
107 for (const auto& t : trees) {
108 log << prefix
109 << plane_idx << ","
110 << key << ","
111 << t.begin().node->get_model()
112 << "\n";
113 }
114 }
115 }
116 }
117
118private:
119 // one storage instance for each plane
120 vector<map<size_t, vector<tree<Node>>>> storage;
121};
122
124{
125 public:
126 // static Inexact_simplifier* initSimplifier();
127 void init(int hashSize, const Dataset &data, int numPlanes);
128
129 // static void destroy();
130
131 // iterates through the tree, indexing it's nodes
132 template<ProgramType P>
134 const SearchSpace &ss, const Dataset &d)
135 {
136 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
137 TreeIter spot = program.Tree.begin_post();
138 while(spot != program.Tree.end_post())
139 {
140 // we dont index or simplify fixed stuff.
141 // non-wheightable nodes are not simplified. TODO: revisit this and see if they should (then implement it)
142 // This is avoiding using booleans.
143 if (spot.node->data.get_prob_change() > 0
144 // && IsWeighable(spot.node->data.ret_type) && IsWeighable(spot.node->data.node_type)
145 ) {
146 // indexing only small subtrees. We don't index constants (the constant simplifier will take
147 // care of them), but we index terminals, as they are weighted and may be added to different
148 // hash collections
149 if (program.size_at(spot, true) <= 30
151 {
152 index<P>(spot, d);
153 // terminals are indexed on initialization
154 }
155 }
156 ++spot;
157 }
158 }
159
160 template<ProgramType P>
162 const SearchSpace &ss, const Dataset &d)
163 {
164 // using RetType =
165 // typename // std::conditional_t<P == PT::Regressor, ArrayXf,
166 // // std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
167 // >>;
168
169 analyze_tree(program, ss, d);
170
171 Program<P> simplified_program(program);
172
173 // prediction at the root already performs template cast and always returns a float
174 auto original_predictions = simplified_program.predict(d);
175
176 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
177 // notice it is a post order iterator.
178 TreeIter spot = simplified_program.Tree.begin_post();
179 while(spot != simplified_program.Tree.end_post())
180 {
181 // we dont index or simplify fixed stuff.
182 // non-wheightable nodes are not simplified. TODO: revisit this and see if they should (then implement it)
183 // This is avoiding using booleans.
184 if (spot.node->data.get_prob_change() > 0
185 // && IsWeighable(spot.node->data.ret_type) && IsWeighable(spot.node->data.node_type)
186 ) {
187 // TODO: use IsLeaf here instead of checking for each possible nodetype. also search throughout the code and replace it
189
190 // res will return the closest within the threshold, so we dont have to check distance here
191
192 auto res = query<P>(spot, d); // optional<pair<int, size_T>>
193
194 if (res){
195 // for each res we replace the subtree and pick the one with the smallest error.
196 // we know they will be smaller because query only returns smaller trees. We also include
197 // the current node in the list of candidates so the model does not get worse.
198
199 float threshold = 1e-5;
200
201 float best_distance = threshold;
202 tree<Node> best_branch;
203 for (const auto& cand : res.value()) {
204
205 const tree<Node> original_branch(spot);
206 const tree<Node> simplified_branch(cand);
207
208 // auto original_predictions = simplified_program.predict(d);
209 // auto spot_pred = spot.node->template predict<spot.node->data.ret_type>(d);
210 // using RetType = decltype(spot_pred);
211
212 simplified_program.Tree.erase_children(spot);
213
214 spot = simplified_program.Tree.move_ontop(spot, simplified_branch.begin());
215
216 auto new_predictions = simplified_program.predict(d);
217
218 float diff = (original_predictions.template cast<float>() - new_predictions.template cast<float>()).square().mean();
219
220 if (diff < best_distance) {
221 best_distance = diff;
222 best_branch = cand;
223 }
224
225 // rollback
226 simplified_program.Tree.erase_children(spot);
227 spot = simplified_program.Tree.move_ontop(spot, original_branch.begin());
228 }
229 if (best_distance < threshold) {
230
231 // cout << "replacing " << spot.node->get_model();
232 simplified_program.Tree.erase_children(spot);
233
234 const tree<Node> best_branch_copy(best_branch);
235 // cout << " with " << best_branch_copy.begin().node->get_model() << endl;
236
237 spot = simplified_program.Tree.move_ontop(spot, best_branch_copy.begin());
238
239 // learning the simplifications made here
240 analyze_tree(simplified_program, ss, d);
241 }
242 }
243 }
244 }
245 ++spot;
246 }
247 program.Tree = simplified_program.Tree;
248
249 return simplified_program;
250 }
251
254
255 template<ProgramType P>
256 void index(TreeIter& spot, const Dataset &d)
257 {
258 const tree<Node> tree_copy(spot);
259
260 // cout << "indexing ...";
261 // cout << tree_copy.begin().node->get_model(true)
262 // << " with datatype " << static_cast<int>(tree_copy.begin().node->data.ret_type) << endl;
263
264 auto hashes = hash<P>(spot, d);
265
266 for (size_t i = 0; i < hashes.size(); ++i)
267 {
268 // hash() will clip the prediction to the inputDim, but here we store the full
269 // predictions so we can calculate the distance to the query point later in query()
270 equivalentExpressions[spot.node->data.ret_type].append(i, hashes[i], tree_copy);
271 }
272 }
273
274 // wrapper to print all equivalentExpressions
275 inline void log_simplification_table(std::ofstream& log) {
276 // print header
277 log << "DataType,Plane,Key,Tree\n";
278
279 for (const auto& kv : equivalentExpressions) {
280 DataType dt = kv.first;
281 const HashStorage& hs = kv.second;
282
283 // prefix is the DataType name + a comma
284 std::string prefix = dt_to_string(dt) + ",";
285 hs.print(prefix, log);
286 }
287 }
288
289 int inputDim = 1000; // default value
290 private:
291 template<ProgramType P>
292 vector<size_t> hash(TreeIter& spot, const Dataset &d)
293 {
294 // returns one hash for each plane
295
296 using RetType =
297 typename std::conditional_t<P == PT::Regressor, ArrayXf,
298 std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
299 >>;
300
301 // we cast to float because hash and query are based on matrix multiplications,
302 // but we will store the hash only on the corresponding storage instance
303 ArrayXf floatClippedInput;
304
305 if constexpr (P == PT::Representer) {
306 ArrayXXf inputPoint = (*spot.node).template predict<ArrayXXf>(d);
307 floatClippedInput = Eigen::Map<ArrayXf>(inputPoint.data(), inputPoint.size()).head(inputDim).template cast<float>();
308 } else {
309 if (spot.node->data.ret_type == DataType::ArrayB) {
310 ArrayXb inputPointB = (*spot.node).template predict<ArrayXb>(d);
311 floatClippedInput = inputPointB.template cast<float>();
312 }
313 else if (spot.node->data.ret_type == DataType::ArrayI) {
314 ArrayXi inputPointI = (*spot.node).template predict<ArrayXi>(d);
315 floatClippedInput = inputPointI.template cast<float>();
316 }
317 else {
318 floatClippedInput = (*spot.node).template predict<ArrayXf>(d);
319 }
320 }
321
322 // assert(floatClippedInput.size() >= inputDim &&
323 // "data must have at least inputDim elements");
324
325 // floatClippedInput = floatClippedInput.head(inputDim);
326
327 assert(floatClippedInput.size() == inputDim &&
328 "You need to pass a dataset with inputDim samples to the simplification.");
329
330 // Equalize floatClippedInput
331 float floatClippedInput_mean = floatClippedInput.mean();
332 floatClippedInput = floatClippedInput - floatClippedInput_mean;
333
334 vector<size_t> hashes;
335 for (size_t planeIdx = 0; planeIdx < uniformPlanes.size(); ++planeIdx)
336 {
337 // TODO: handle nan predictions?
338
339 const auto& plane = uniformPlanes[planeIdx];
340 Eigen::ArrayXf projection = plane * floatClippedInput.matrix();
341 Eigen::Array<bool, Eigen::Dynamic, 1> comparison = (projection > 0);
342
343 size_t input_hash = 0;
344 for (int i = 0; i < comparison.size(); ++i) {
345 input_hash <<= 1;
346 input_hash |= comparison(i) ? 1 : 0;
347 }
348
349 hashes.push_back(input_hash);
350 }
351
352 return hashes;
353 }
354
355 template<ProgramType P>
356 optional<vector<tree<Node>>> query(TreeIter& spot, const Dataset &d)
357 {
358 // will return the hash and the distance to the queryPoint.
359
360 int spot_size = spot.node->get_size();
361
362 // first argument is the index of the plane, second is the hash
363 vector<tree<Node>> matches = {};
364
365 vector<size_t> hashes = hash<P>(spot, d);
366
367 for (int i = 0; i < hashes.size(); ++i){
368 // cout << "querying hashes index " << i
369 // << " with datatype " << static_cast<int>(spot.node->data.ret_type) << endl;
370
371 vector<tree<Node>> newCandidates = equivalentExpressions[spot.node->data.ret_type].getList(i, hashes[i]);
372
373 if (newCandidates.size() == 0)
374 continue;
375
376 int count = 0;
377 for (const auto& cand : newCandidates) {
378 if (cand.begin().node->get_size() < spot_size) {
379 matches.push_back(cand);
380 if (++count >= 25) break; // returning only top 10
381 } else {
382 // Since candidates are ordered by size, we can break early
383 break;
384 }
385 }
386 }
387
388 if (matches.size() > 0)
389 return matches;
390 return std::nullopt;
391 }
392
393 inline string dt_to_string(DataType dt) {
394 switch (dt) {
395 case DataType::ArrayB: return "ArrayB";
396 case DataType::ArrayI: return "ArrayI";
397 case DataType::ArrayF: return "ArrayF";
398 case DataType::MatrixB: return "MatrixB";
399 case DataType::MatrixI: return "MatrixI";
400 case DataType::MatrixF: return "MatrixF";
401 case DataType::TimeSeriesB: return "TimeSeriesB";
402 case DataType::TimeSeriesI: return "TimeSeriesI";
403 case DataType::TimeSeriesF: return "TimeSeriesF";
404 case DataType::ArrayBJet: return "ArrayBJet";
405 case DataType::ArrayIJet: return "ArrayIJet";
406 case DataType::ArrayFJet: return "ArrayFJet";
407 case DataType::MatrixBJet: return "MatrixBJet";
408 case DataType::MatrixIJet: return "MatrixIJet";
409 case DataType::MatrixFJet: return "MatrixFJet";
410 case DataType::TimeSeriesBJet: return "TimeSeriesBJet";
411 case DataType::TimeSeriesIJet: return "TimeSeriesIJet";
412 case DataType::TimeSeriesFJet: return "TimeSeriesFJet";
413 }
414 return "Unknown";
415 }
416
417 // one storage instance for each datatype/rettype.
418 // the storage will be used to calculate the hash and query the
419 // collection of hashes, returning the closest ones,
420 // and the list will contain equivalent expressions, ordered by size
421 // (or linear complexity). So we dont store pairs in the storage
422 std::unordered_map<DataType, HashStorage> equivalentExpressions;
423
424 vector<MatrixXf> uniformPlanes;
425};
426
427} // Simply
428} // Brush
429
430#endif
holds variable type data.
Definition data.h:51
vector< map< size_t, vector< tree< Node > > > > storage
Definition inexact.h:120
vector< tree< Node > > getList(const int &storage_n, const size_t &key)
Definition inexact.h:79
void append(const int &storage_n, const size_t &key, const tree< Brush::Node > Tree)
Definition inexact.h:26
void print(const string &prefix, std::ofstream &log) const
Definition inexact.h:101
vector< size_t > keys(const int &storage_n)
Definition inexact.h:93
HashStorage(int numPlanes=10)
Definition inexact.h:18
string dt_to_string(DataType dt)
Definition inexact.h:393
optional< vector< tree< Node > > > query(TreeIter &spot, const Dataset &d)
Definition inexact.h:356
vector< size_t > hash(TreeIter &spot, const Dataset &d)
Definition inexact.h:292
vector< MatrixXf > uniformPlanes
Definition inexact.h:424
void index(TreeIter &spot, const Dataset &d)
Definition inexact.h:256
Program< P > simplify_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition inexact.h:161
void analyze_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition inexact.h:133
void init(int hashSize, const Dataset &data, int numPlanes)
Definition inexact.cpp:17
std::unordered_map< DataType, HashStorage > equivalentExpressions
Definition inexact.h:422
void log_simplification_table(std::ofstream &log)
Definition inexact.h:275
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
DataType
data types.
Definition types.h:143
tree< Node >::pre_order_iterator TreeIter
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
STL namespace.
class holding the data for a node in a tree.
Definition node.h:84
An individual program, a.k.a. model.
Definition program.h:50
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:183
tree< Node > Tree
fitness
Definition program.h:73
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...