Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
inexact.cpp
Go to the documentation of this file.
1#include "inexact.h"
2
3
4// simplification maps are based on training data
5// should ignore fixed nodes ---> does not change subtrees if they contain fixed nodes
6// should I implement json serialization?
7
8namespace Brush { namespace Simpl{
9
10// Inexact_simplifier* Inexact_simplifier::instance = NULL;
11
15
16// Inexact_simplifier* Inexact_simplifier::initSimplifier()
17// {
18// // creates the static random generator by calling the constructor
19// if (!instance)
20// {
21// instance = new Inexact_simplifier();
22// }
23
24// return instance;
25// }
26
27
28void Inexact_simplifier::initUniformPlanes(int hashSize, int inputDim, int numPlanes)
29{
30 // TODO: inputDim cutoff at 100 datapoints?
31 // int numPlanes = 1? The bigger the number of planes, the more accurate the hash, but the slower the search
32
33 uniformPlanes.clear();
34
35 // create random planes
36 for (int i=0; i<numPlanes; ++i)
37 {
38 MatrixXf plane = MatrixXf::Random(hashSize, inputDim);
39 // plane /= plane.norm();
40 uniformPlanes.push_back(plane);
41 }
42}
43
44vector<string> Inexact_simplifier::hash(const ArrayXf& inputPoint)
45{
46 vector<string> hashes;
47 for (size_t planeIdx = 0; planeIdx < uniformPlanes.size(); ++planeIdx)
48 {
49 // TODO: handle nan predictions here and on the index functions
50 // cout << "Processing plane " << planeIdx << std::endl;
51 const auto& plane = uniformPlanes[planeIdx];
52 ArrayXf projections = (plane * inputPoint.matrix());
53 // cout << "Projections: " << projections.transpose() << std::endl;
54
55 ArrayXb comparison = (projections.array() > 0);
56 // cout << "Comparisons: " << comparison.transpose() << std::endl;
57
58 string hashString = ""; // TODO: size_t instead of string
59 // hashString.reserve(hashSize);
60
61 for (bool v : comparison){
62 // cout << v << ", ";
63 hashString += v ? "1" : "0";
64 }
65
66 // cout << std::endl << "Generated hash string: " << hashString << std::endl;
67 hashes.push_back(hashString);
68 }
69 // cout << "Returning hashes" << std::endl;
70 return hashes;
71}
72
74{
75 // we cast to float because hash and query are based on matrix multiplications,
76 // but we will store the hash only on the corresponding storage instance
77 // cout << "Predicting node value" << std::endl;
78 ArrayXf v_float; // = (*spot.node).template predict<ArrayXf>(d);
79 if (spot.node->data.ret_type==DataType::ArrayB) { // TODO: make this function templated?
80 auto temp = (*spot.node).predict<ArrayXb>(d);
81 v_float = temp.template cast<float>();
82 }
83 else if (spot.node->data.ret_type==DataType::ArrayI) {
84 auto temp = (*spot.node).predict<ArrayXi>(d);
85 v_float = temp.template cast<float>();
86 } else { // otherwise we store it as floats
87 v_float = (*spot.node).template predict<ArrayXf>(d);
88 }
89
90 // cout << "Hashing node value" << std::endl;
91 auto hashes = hash(v_float);
92 for (size_t i = 0; i < hashes.size(); ++i)
93 {
94 // cout << "Processing hash " << i << ": " << hashes[i] << std::endl;
95 if (spot.node->data.ret_type==DataType::ArrayB) {
96 // cout << "Appending to storageBool" << std::endl;
97 storageBool.append(hashes[i], v_float);
98 } else if (spot.node->data.ret_type==DataType::ArrayI) {
99 // cout << "Appending to storageInt" << std::endl;
100 storageInt.append(hashes[i], v_float);
101 } else { // otherwise we store it as floats
102 // cout << "Appending to storageFloat" << std::endl;
103 storageFloat.append(hashes[i], v_float); // TODO: should throw an error
104 }
105 }
106}
107
108// will return the hash and the distance to the queryPoint
109optional<pair<size_t, string>> Inexact_simplifier::query(TreeIter& spot, const Dataset &d)
110{
111 float threshold = 1e-8; // TODO: calculate threshold based on variance of dataset
112
113 // TODO: this block below filling v_float is repeated in the function above. Maybe I should implement it in a separate function?
114 ArrayXf v_float; // = (*spot.node).template predict<ArrayXf>(d);
115 if (spot.node->data.ret_type==DataType::ArrayB) { // TODO: make this function templated?
116 auto temp = (*spot.node).predict<ArrayXb>(d);
117 v_float = temp.template cast<float>();
118 }
119 else if (spot.node->data.ret_type==DataType::ArrayI) {
120 auto temp = (*spot.node).predict<ArrayXi>(d);
121 v_float = temp.template cast<float>();
122 } else { // otherwise we store it as floats
123 v_float = (*spot.node).template predict<ArrayXf>(d);
124 }
125
126 vector<pair<size_t, string>> candidates;
127 vector<float> distances;
128
129 HashStorage *storage;
130 if (spot.node->data.ret_type==DataType::ArrayB) {
131 storage = (&storageBool);
132 // cout << "Using storageBool" << std::endl;
133 } else if (spot.node->data.ret_type==DataType::ArrayI) {
134 storage = (&storageInt);
135 // cout << "Using storageInt" << std::endl;
136 } else { // otherwise we store it as floats
137 storage = (&storageFloat);
138 // cout << "Using storageFloat" << std::endl;
139 }
140 // TODO: should throw an error if no storage matches
141
142 vector<string> hashes = hash(v_float);
143 // cout << "Hashes: ";
144 for (const auto& h : hashes) {
145 // cout << h << " ";
146 }
147 // cout << std::endl;
148
149 for (size_t i = 0; i < hashes.size(); ++i){
150 auto newCandidates = storage->getList(hashes[i]);
151 // cout << "Candidates for hash " << hashes[i] << ": " << newCandidates.size() << std::endl;
152
153 for (const auto& cand : newCandidates) {
154 float d = (v_float - cand).array().pow(2).mean();
155 if (std::isnan(d) || std::isinf(d))
156 d = MAX_FLT;
157
158 // cout << "Distance: " << d << std::endl;
159
160 if (d<threshold){
161 candidates.push_back(make_pair(i, hashes[i]));
162 distances.push_back(d);
163 // cout << "Candidate added with distance: " << d << std::endl;
164 }
165 }
166 }
167
168 if (distances.size() > 0){
169 auto min_idx = std::distance(std::begin(distances),
170 std::min_element(std::begin(distances), std::end(distances)));
171 // cout << "Minimum distance index: " << min_idx << std::endl;
172 return candidates[min_idx];
173 } else {
174 // cout << "No candidates found within threshold" << std::endl;
175 }
176
177 return std::nullopt;
178}
179
180// void Inexact_simplifier::destroy()
181// {
182// if (instance)
183// delete instance;
184
185// instance = NULL;
186// }
187
189
190} // Simpl
191} // Brush
holds variable type data.
Definition data.h:51
vector< ArrayXf > getList(const string &key)
Definition inexact.h:35
void initUniformPlanes(int hashSize, int inputDim, int numPlanes)
Definition inexact.cpp:28
vector< MatrixXf > uniformPlanes
Definition inexact.h:144
optional< pair< size_t, string > > query(TreeIter &spot, const Dataset &d)
Definition inexact.cpp:109
void index(TreeIter &spot, const Dataset &d)
Definition inexact.cpp:73
vector< string > hash(const ArrayXf &inputPoint)
Definition inexact.cpp:44
static float MAX_FLT
Definition init.h:61
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
tree< Node >::pre_order_iterator TreeIter