Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
inexact.h
Go to the documentation of this file.
1#ifndef INEXACT_H
2#define INEXACT_H
3
4#include "../init.h"
5#include "../types.h"
8#include "../util/utils.h"
9
10using namespace std;
11using Brush::Node;
12using Brush::DataType;
13
14namespace Brush { namespace Simpl{
15
17public:
18 HashStorage(int numPlanes=10) {
19 storage.clear();
20 storage.reserve(numPlanes);
21 for (int i = 0; i < numPlanes; ++i)
22 storage.push_back(map<size_t, vector<tree<Node>>>());
23 };
25
26 void append(const int& storage_n, const size_t& key, const tree<Brush::Node> Tree) {
27 // we initialize the list of equivalent vectors if it does not exist
28
29 if (storage[storage_n].find(key) == storage[storage_n].end())
30 storage[storage_n][key] = vector<tree<Node>>();
31
32 // if it is smaller we add to the front, otherwise we add to the back.
33 // This way we know the first element is the smallest one, and we don't care
34 // about the order of the rest of the elements.
35 auto& storage_it = storage[storage_n][key];
36
37 // calculating incoming tree's attributes to compare
38 size_t new_size = Tree.begin().node->get_size();
39
40 auto it = storage_it.begin();
41 for (; it != storage_it.end(); ++it) {
42 size_t curr_size = it->begin().node->get_size();
43
44 if (curr_size > new_size) {
45 // found insertion point, we dont need to look beyond this point
46 break;
47 } else if (curr_size == new_size) {
48 // Compare structure + contents
49 auto it1 = it->begin();
50 auto it2 = Tree.begin();
51
52 auto end1 = it->end();
53 auto end2 = Tree.end();
54
55 bool trees_equal = true;
56 for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
57 if (it1.node->data.get_node_hash(false)
58 != it2.node->data.get_node_hash(false) ){
59 trees_equal = false;
60 break;
61 }
62 }
63
64 if (trees_equal && it1 == end1 && it2 == end2) {
65 // both finished at same time and the look was not interrupted earlier.
66 // it means we already have the same exact tree (but maybe with a different coeff).
67 // lets pretend we inserted and just return
68 return;
69 }
70
71 // else keep scanning; insertion will be after last equal-size element
72 }
73 }
74
75 // Insert Tree in order by size, smallest first
76 storage_it.insert(it, Tree);
77 }
78
79 vector<tree<Node>> getList(const int& storage_n, const size_t& key) {
80 auto it = storage[storage_n].find(key);
81 if (it != storage[storage_n].end())
82 return it->second;
83
84 return {};
85 }
86
87 void clear() {
88 int numPlanes = storage.size();
89 storage.clear();
90 storage.resize(numPlanes);
91 }
92
93 vector<size_t> keys(const int& storage_n) {
94 vector<size_t> result;
95 for (const auto& pair : storage[storage_n])
96 result.push_back(pair.first);
97
98 return result;
99 }
100
101 void print(const string& prefix, std::ofstream& log) const {
102 for (size_t plane_idx = 0; plane_idx < storage.size(); ++plane_idx) {
103 for (const auto& kv : storage[plane_idx]) {
104 size_t key = kv.first;
105 const auto& trees = kv.second;
106
107 for (const auto& t : trees) {
108 log << prefix
109 << plane_idx << ","
110 << key << ","
111 << t.begin().node->get_model()
112 << "\n";
113 }
114 }
115 }
116 }
117
118private:
119 // one storage instance for each plane
120 vector<map<size_t, vector<tree<Node>>>> storage;
121};
122
124{
125 public:
126 // static Inexact_simplifier* initSimplifier();
127 void init(int hashSize, const Dataset &data, int numPlanes);
128
129 // static void destroy();
130
131 // iterates through the tree, indexing it's nodes
// Post-order walk over `program`'s tree; every mutable node whose subtree
// is small enough (size_at <= 30) is handed to index<P>() so it can later
// serve as a replacement candidate in simplify_tree().
// NOTE(review): this listing is missing source line 133, the start of the
// signature --- per the reference index it is
// `void analyze_tree(Program<P>& program, const SearchSpace& ss, const Dataset& d)`.
// `ss` appears unused in the visible body --- confirm against the original header.
132 template<ProgramType P>
134 const SearchSpace &ss, const Dataset &d)
135 {
136 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
137 TreeIter spot = program.Tree.begin_post();
138 while(spot != program.Tree.end_post())
139 {
140 // we dont index or simplify fixed stuff.
141 // non-wheightable nodes are not simplified. TODO: revisit this and see if they should (then implement it)
142 // This is avoiding using booleans.
143 if (spot.node->data.get_prob_change() > 0
144 // && IsWeighable(spot.node->data.ret_type) && IsWeighable(spot.node->data.node_type)
145 ) {
146 // indexing only small subtrees. We don't index constants (the constant simplifier will take
147 // care of them), but we index terminals, as they are weighted and may be added to different
148 // hash collections
149 if (program.size_at(spot, true) <= 30
// NOTE(review): source line 150 (the remainder of this condition) is
// missing from this listing --- the condition is truncated here.
151 {
152 index<P>(spot, d);
153 // terminals are already indexed on initialization
154 }
155 }
156 ++spot;
157 }
158 }
159
160 template<ProgramType P>
// Tries to shrink `program` by splicing in smaller, previously-indexed
// equivalent subtrees whose root predictions differ from the original's
// by an MSE below 1e-5. Mutates program.Tree and also returns the
// simplified copy.
// NOTE(review): this listing is missing source line 161, the start of the
// signature --- per the reference index it is
// `Program<P> simplify_tree(Program<P> &program, const SearchSpace &ss, const Dataset &d)`.
162 const SearchSpace &ss, const Dataset &d)
163 {
164 // using RetType =
165 // typename // std::conditional_t<P == PT::Regressor, ArrayXf,
166 // // std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
167 // >>;
168
// index the incoming tree first so its own subtrees become candidates
169 analyze_tree(program, ss, d);
170
// work on a copy so a failed trial never corrupts the caller's program
171 Program<P> simplified_program(program);
172
173 // prediction at the root already performs template cast and always returns a float
174 auto original_predictions = simplified_program.predict(d);
175
176 // iterate over the tree, trying to replace each node with a constant, and keeping the change if the pred does not change.
177 // notice it is a post order iterator.
178 TreeIter spot = simplified_program.Tree.begin_post();
179 while(spot != simplified_program.Tree.end_post())
180 {
181 // we dont index or simplify fixed stuff.
182 // non-wheightable nodes are not simplified. TODO: revisit this and see if they should (then implement it)
183 // This is avoiding using booleans.
184 // we do not simplify branches with fixed weights, because the simplification ignores the weight (it uses normalized predictions)
185 if (spot.node->data.get_prob_change() > 0 && !spot.node->data.weight_is_fixed
186 // && IsWeighable(spot.node->data.ret_type) && IsWeighable(spot.node->data.node_type)
187 ) {
188 // TODO: use IsLeaf here instead of checking for each possible nodetype. also search throughout the code and replace it
// NOTE(review): source line 189 is missing from this listing; the brace
// count of the closing lines below (243-245) implies it opened another
// condition block here --- confirm against the original header.
190
191 // res will return the closest within the threshold, so we dont have to check distance here
192
193 auto res = query<P>(spot, d); // optional<vector<tree<Node>>>
194
195 if (res){
196 // for each res we replace the subtree and pick the one with the smallest error.
197 // we know they will be smaller because query only returns smaller trees. We also include
198 // the current node in the list of candidates so the model does not get worse.
199
200 float threshold = 1e-5;
201
// best_distance starts at the threshold, so only candidates strictly
// better than 1e-5 are ever recorded
202 float best_distance = threshold;
203 tree<Node> best_branch;
204 for (const auto& cand : res.value()) {
205
// snapshot the current subtree so the trial can be rolled back
206 const tree<Node> original_branch(spot);
207 const tree<Node> simplified_branch(cand);
208
209 // auto original_predictions = simplified_program.predict(d);
210 // auto spot_pred = spot.node->template predict<spot.node->data.ret_type>(d);
211 // using RetType = decltype(spot_pred);
212
// trial: splice the candidate in place of the current subtree
213 simplified_program.Tree.erase_children(spot);
214
215 spot = simplified_program.Tree.move_ontop(spot, simplified_branch.begin());
216
217 auto new_predictions = simplified_program.predict(d);
218
// mean squared difference between original and trial root predictions
219 float diff = (original_predictions.template cast<float>() - new_predictions.template cast<float>()).square().mean();
220
221 if (diff < best_distance) {
222 best_distance = diff;
223 best_branch = cand;
224 }
225
226 // rollback
227 simplified_program.Tree.erase_children(spot);
228 spot = simplified_program.Tree.move_ontop(spot, original_branch.begin());
229 }
// commit the best candidate only if it actually beat the threshold
230 if (best_distance < threshold) {
231
232 // cout << "replacing " << spot.node->get_model();
233 simplified_program.Tree.erase_children(spot);
234
235 const tree<Node> best_branch_copy(best_branch);
236 // cout << " with " << best_branch_copy.begin().node->get_model() << endl;
237
238 spot = simplified_program.Tree.move_ontop(spot, best_branch_copy.begin());
239
240 // learning the simplifications made here
241 analyze_tree(simplified_program, ss, d);
242 }
243 }
244 }
245 }
246 ++spot;
247 }
// propagate the simplified tree back into the caller's program
248 program.Tree = simplified_program.Tree;
249
250 return simplified_program;
251 }
252
255
256 template<ProgramType P>
257 void index(TreeIter& spot, const Dataset &d)
258 {
259 const tree<Node> tree_copy(spot);
260
261 // cout << "indexing ...";
262 // cout << tree_copy.begin().node->get_model(true)
263 // << " with datatype " << static_cast<int>(tree_copy.begin().node->data.ret_type) << endl;
264
265 auto hashes = hash<P>(spot, d);
266
267 for (size_t i = 0; i < hashes.size(); ++i)
268 {
269 // hash() will clip the prediction to the inputDim, but here we store the full
270 // predictions so we can calculate the distance to the query point later in query()
271 equivalentExpressions[spot.node->data.ret_type].append(i, hashes[i], tree_copy);
272 }
273 }
274
275 // wrapper to print all equivalentExpressions
276 inline void log_simplification_table(std::ofstream& log) {
277 // print header
278 log << "DataType,Plane,Key,Tree\n";
279
280 for (const auto& kv : equivalentExpressions) {
281 DataType dt = kv.first;
282 const HashStorage& hs = kv.second;
283
284 // prefix is the DataType name + a comma
285 std::string prefix = dt_to_string(dt) + ",";
286 hs.print(prefix, log);
287 }
288 }
289
290 int inputDim = 1000; // default value
291 private:
292 template<ProgramType P>
293 vector<size_t> hash(TreeIter& spot, const Dataset &d)
294 {
295 // returns one hash for each plane
296
297 using RetType =
298 typename std::conditional_t<P == PT::Regressor, ArrayXf,
299 std::conditional_t<P == PT::Representer, ArrayXXf, ArrayXf
300 >>;
301
302 // we cast to float because hash and query are based on matrix multiplications,
303 // but we will store the hash only on the corresponding storage instance
304 ArrayXf floatClippedInput;
305
306 if constexpr (P == PT::Representer) {
307 ArrayXXf inputPoint = (*spot.node).template predict<ArrayXXf>(d);
308 floatClippedInput = Eigen::Map<ArrayXf>(inputPoint.data(), inputPoint.size()).head(inputDim).template cast<float>();
309 } else {
310 if (spot.node->data.ret_type == DataType::ArrayB) {
311 ArrayXb inputPointB = (*spot.node).template predict<ArrayXb>(d);
312 floatClippedInput = inputPointB.template cast<float>();
313 }
314 else if (spot.node->data.ret_type == DataType::ArrayI) {
315 ArrayXi inputPointI = (*spot.node).template predict<ArrayXi>(d);
316 floatClippedInput = inputPointI.template cast<float>();
317 }
318 else {
319 floatClippedInput = (*spot.node).template predict<ArrayXf>(d);
320 }
321 }
322
323 // assert(floatClippedInput.size() >= inputDim &&
324 // "data must have at least inputDim elements");
325
326 // floatClippedInput = floatClippedInput.head(inputDim);
327
328 assert(floatClippedInput.size() == inputDim &&
329 "You need to pass a dataset with inputDim samples to the simplification.");
330
331 // Equalize floatClippedInput
332 float floatClippedInput_mean = floatClippedInput.mean();
333
334 // Check for NaN/Inf in mean - if present, skip simplification for this node.
335 // Otherwise, we are at changes of having wrong simplifications and terrible
336 // replacements due to numeric error.
337 if (std::isnan(floatClippedInput_mean) || std::isinf(floatClippedInput_mean)) {
338 return {}; // Return empty hashes to skip this node
339 }
340
341 floatClippedInput = floatClippedInput - floatClippedInput_mean;
342
343 // No need to check for NaN/Inf in the normalized predictions ---
344 // the mean is already a valid numeric value.
345
346 vector<size_t> hashes;
347 for (size_t planeIdx = 0; planeIdx < uniformPlanes.size(); ++planeIdx)
348 {
349 const auto& plane = uniformPlanes[planeIdx];
350 Eigen::ArrayXf projection = plane * floatClippedInput.matrix();
351 Eigen::Array<bool, Eigen::Dynamic, 1> comparison = (projection > 0);
352
353 size_t input_hash = 0;
354 for (int i = 0; i < comparison.size(); ++i) {
355 input_hash <<= 1;
356 input_hash |= comparison(i) ? 1 : 0;
357 }
358
359 hashes.push_back(input_hash);
360 }
361
362 return hashes;
363 }
364
365 template<ProgramType P>
366 optional<vector<tree<Node>>> query(TreeIter& spot, const Dataset &d)
367 {
368 // will return the hash and the distance to the queryPoint.
369
370 int spot_size = spot.node->get_size();
371
372 // first argument is the index of the plane, second is the hash
373 vector<tree<Node>> matches = {};
374
375 vector<size_t> hashes = hash<P>(spot, d);
376
377 for (int i = 0; i < hashes.size(); ++i){
378 // cout << "querying hashes index " << i
379 // << " with datatype " << static_cast<int>(spot.node->data.ret_type) << endl;
380
381 vector<tree<Node>> newCandidates = equivalentExpressions[spot.node->data.ret_type].getList(i, hashes[i]);
382
383 if (newCandidates.size() == 0)
384 continue;
385
386 int count = 0;
387 for (const auto& cand : newCandidates) {
388 if (cand.begin().node->get_size() < spot_size) {
389 matches.push_back(cand);
390 if (++count >= 25) break; // returning only top 10
391 } else {
392 // Since candidates are ordered by size, we can break early
393 break;
394 }
395 }
396 }
397
398 if (matches.size() > 0)
399 return matches;
400 return std::nullopt;
401 }
402
403 inline string dt_to_string(DataType dt) {
404 switch (dt) {
405 case DataType::ArrayB: return "ArrayB";
406 case DataType::ArrayI: return "ArrayI";
407 case DataType::ArrayF: return "ArrayF";
408 case DataType::MatrixB: return "MatrixB";
409 case DataType::MatrixI: return "MatrixI";
410 case DataType::MatrixF: return "MatrixF";
411 case DataType::TimeSeriesB: return "TimeSeriesB";
412 case DataType::TimeSeriesI: return "TimeSeriesI";
413 case DataType::TimeSeriesF: return "TimeSeriesF";
414 case DataType::ArrayBJet: return "ArrayBJet";
415 case DataType::ArrayIJet: return "ArrayIJet";
416 case DataType::ArrayFJet: return "ArrayFJet";
417 case DataType::MatrixBJet: return "MatrixBJet";
418 case DataType::MatrixIJet: return "MatrixIJet";
419 case DataType::MatrixFJet: return "MatrixFJet";
420 case DataType::TimeSeriesBJet: return "TimeSeriesBJet";
421 case DataType::TimeSeriesIJet: return "TimeSeriesIJet";
422 case DataType::TimeSeriesFJet: return "TimeSeriesFJet";
423 }
424 return "Unknown";
425 }
426
427 // one storage instance for each datatype/rettype.
428 // the storage will be used to calculate the hash and query the
429 // collection of hashes, returning the closest ones,
430 // and the list will contain equivalent expressions, ordered by size
431 // (or linear complexity). So we dont store pairs in the storage
432 std::unordered_map<DataType, HashStorage> equivalentExpressions;
433
434 vector<MatrixXf> uniformPlanes;
435};
436
437} // Simply
438} // Brush
439
440#endif
holds variable type data.
Definition data.h:51
vector< map< size_t, vector< tree< Node > > > > storage
Definition inexact.h:120
vector< tree< Node > > getList(const int &storage_n, const size_t &key)
Definition inexact.h:79
void append(const int &storage_n, const size_t &key, const tree< Brush::Node > Tree)
Definition inexact.h:26
void print(const string &prefix, std::ofstream &log) const
Definition inexact.h:101
vector< size_t > keys(const int &storage_n)
Definition inexact.h:93
HashStorage(int numPlanes=10)
Definition inexact.h:18
string dt_to_string(DataType dt)
Definition inexact.h:403
optional< vector< tree< Node > > > query(TreeIter &spot, const Dataset &d)
Definition inexact.h:366
vector< size_t > hash(TreeIter &spot, const Dataset &d)
Definition inexact.h:293
vector< MatrixXf > uniformPlanes
Definition inexact.h:434
void index(TreeIter &spot, const Dataset &d)
Definition inexact.h:257
Program< P > simplify_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition inexact.h:161
void analyze_tree(Program< P > &program, const SearchSpace &ss, const Dataset &d)
Definition inexact.h:133
void init(int hashSize, const Dataset &data, int numPlanes)
Definition inexact.cpp:19
std::unordered_map< DataType, HashStorage > equivalentExpressions
Definition inexact.h:432
void log_simplification_table(std::ofstream &log)
Definition inexact.h:276
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:48
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
DataType
data types.
Definition types.h:143
tree< Node >::pre_order_iterator TreeIter
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
STL namespace.
class holding the data for a node in a tree.
Definition node.h:89
An individual program, a.k.a. model.
Definition program.h:50
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:215
tree< Node > Tree
fitness
Definition program.h:73
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:122
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...