Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bandit.cpp
Go to the documentation of this file.
1#include "bandit.h"
2#include <typeinfo> // FOR DEBUGGING PURPOSES. TODO: remove it later
3
4namespace Brush {
5namespace MAB {
6
// Presumably the default constructor: its signature line (listing line 8) is
// missing from this extraction -- NOTE(review): confirm against bandit.h.
// Configures a "dynamic_thompson" bandit with no arms and an empty probability
// map, then builds the policy with set_bandit().
7template <typename T>
9 set_type("dynamic_thompson");
10 set_arms({});
11 set_probs({});
12 set_bandit();
13}
14
15template <typename T>
16Bandit<T>::Bandit(string type, vector<T> arms) : type(type) {
17 this->set_arms(arms);
18
19 map<T, float> arms_probs;
20 float prob = 1.0 / arms.size();
21 for (const auto& arm : arms) {
22 arms_probs[arm] = prob;
23 }
24 this->set_probs(arms_probs);
25 this->set_bandit();
26}
27
28template <typename T>
29Bandit<T>::Bandit(string type, map<T, float> arms_probs) : type(type) {
30 this->set_probs(arms_probs);
31
32 vector<T> arms_names;
33 for (const auto& pair : arms_probs) {
34 arms_names.push_back(pair.first);
35 }
36 this->set_arms(arms_names);
37 this->set_bandit();
38}
39
40template <typename T>
42 // TODO: a flag that is set to true when this function is called. make all
43 // other methods to raise an error if bandit was not set
44 if (type == "thompson") {
45 pbandit = make_unique<ThompsonSamplingBandit<T>>(probabilities);
46 } else if (type == "dynamic_thompson") {
47 pbandit = make_unique<ThompsonSamplingBandit<T>>(probabilities, true);
48 } else if (type == "linear_thompson") {
49 pbandit = make_unique<LinearThompsonSamplingBandit<T>>(probabilities);
50 } else if (type == "dummy") {
51 pbandit = make_unique<DummyBandit<T>>(probabilities);
52 } else {
53 HANDLE_ERROR_THROW("Undefined Selection Operator " + this->type + "\n");
54 }
55}
56
57template <typename T>
59 return type;
60}
61
62template <typename T>
64 this->type = type;
65}
66
67template <typename T>
69 return arms;
70}
71
72template <typename T>
73void Bandit<T>::set_arms(vector<T> arms) {
74 this->arms = arms;
75}
76
// Gets the probabilities associated with each arm (per the member docs).
// NOTE(review): the body of this getter (listing line 79) is missing from this
// extraction; presumably it is `return probabilities;` -- confirm against the source.
77template <typename T>
78map<T, float> Bandit<T>::get_probs() {
80}
81
82template <typename T>
83void Bandit<T>::set_probs(map<T, float> arms_probs) {
84 probabilities = arms_probs;
85}
86
87template <typename T>
88map<T, float> Bandit<T>::sample_probs(bool update) {
89 return this->pbandit->sample_probs(update);
90}
91
92template <typename T>
93T Bandit<T>::choose(const VectorXf& context) {
94 return this->pbandit->choose(context);
95}
96
97template <typename T>
98void Bandit<T>::update(T arm, float reward, VectorXf& context) {
99 this->pbandit->update(arm, reward, context);
101
/// @brief Builds the context vector passed to contextual bandit policies.
///
/// Active behaviour ("SECOND APPROACH" below): the whole program tree is
/// evaluated on `d` through its root node, and that prediction becomes the
/// context:
///  - Regressor: the ArrayXf prediction as-is;
///  - the branch whose condition is missing from this listing: the ArrayXf
///    prediction, passed through cast<float>();
///  - MulticlassClassifier: ArgMax of the ArrayXXf prediction, cast to float;
///  - Representer: prints a "not implemented" message, and no context is
///    computed (the default-constructed `context` is returned);
///  - any other program type: reported via HANDLE_ERROR_THROW.
/// `spot` and `ss` appear only in commented-out code, so they do not currently
/// influence the result.
///
/// @param program program whose tree is evaluated.
/// @param spot    tree position of interest (unused by the active code).
/// @param ss      search space (unused by the active code).
/// @param d       dataset the tree is evaluated on.
/// @return the context vector.
102template <typename T> template <ProgramType PT>
103VectorXf Bandit<T>::get_context(const Program<PT>& program, Iter spot,
104 const SearchSpace &ss, const Dataset &d) {
105 // TODO: for better performance, get_context should calculate the context only if the
106 // pbandit is of a contextual type. otherwise, return empty stuff
108 // cout << "Inside get_context" << endl;
109 VectorXf context;
110 // -------------------------------------------------------------------------
111 // SECOND APPROACH: prediction vector of the spot node
112 // -------------------------------------------------------------------------
113 if constexpr (PT==ProgramType::Regressor)
114 {
115 // cout << "RegressorProgram detected\n" << endl;
116
117 // use the code below to work with the whole tree prediction -----------
118 ArrayXf out = (*program.Tree.begin().node).template predict<ArrayXf>(d);
119 context = out;
120
121 // predicting the spot node --------------------------------------------
122 // context = (*spot.node).template predict<ArrayXf>(d);
123 }
// NOTE(review): listing line 124 -- the condition that opens the next branch --
// is missing from this extraction. Judging by the comment inside, it presumably
// selects binary-classifier programs; confirm against the source.
125 {
126 // cout << "ClassifierProgram detected\n" << endl;
128 // use the code below to work with the whole tree prediction -----------
129 ArrayXf out = (*program.Tree.begin().node).template predict<ArrayXf>(d);
// NOTE(review): `out` is already ArrayXf, so this cast<float>() looks like a no-op.
130 context = ArrayXf(out.template cast<float>());
131
132 // predicting the spot node --------------------------------------------
133 // ArrayXf logit = (*spot.node).template predict<ArrayXf>(d);
134 // ArrayXb pred = (logit > 0.5);
135 // context = ArrayXf(pred.template cast<float>());
136 }
137 else if constexpr (PT==ProgramType::MulticlassClassifier)
138 {
139 // cout << "MulticlassClassifierProgram detected\n" << endl;
140
141 // use the code below to work with the whole tree prediction -----------
142 ArrayXXf out = (*program.Tree.begin().node).template predict<ArrayXXf>(d);
143 auto argmax = Function<NodeType::ArgMax>{};
144 context = ArrayXf(argmax(out).template cast<float>());
145
146 // predicting the spot node --------------------------------------------
147 }
148 else if constexpr (PT==ProgramType::Representer)
149 {
// NOTE(review): this is the Representer branch, but the message below says
// "MulticlassClassifierProgram" -- looks like a copy-paste slip. `context` is
// not assigned here.
150 cout << "MulticlassClassifierProgram detected, not implemented\n" << endl;
151 }
152 else
153 {
154 HANDLE_ERROR_THROW("No predict available for the class.");
155 }
156
// NOTE: everything from here down to `return context;` is commented-out code
// (an earlier label-encoding approach) and has no effect on the result.
157 // -------------------------------------------------------------------------
158 // FIRST APPROACH: label encoding of nodes above/below/on the spot
159 // -------------------------------------------------------------------------
160 // context is 3 times the number of nodes in the search space.
161 // it represents a label encoding of the Tree structure, where
162 // the first third represents number of nodes above the spot,
163 // the second represents the spot, and the third represents
164 // the number of nodes below the spot.
165 // The vector below works as a reference of the nodes.
166
167 // cout << "Tree: " << std::endl;
168 // for (auto it = Tree.begin(); it != Tree.end(); ++it) {
169 // for (int i = 0; i < Tree.depth(it); ++i) {
170 // std::cout << " ";
171 // }
172 // std::cout << (*it).name << std::endl;
173 // }
174
175 // cout << "Spot name: " << (*spot).name << std::endl;
176
177 // size_t tot_operators = ss.op_names.size(); //NodeTypes::Count;
178 // size_t tot_features = 0;
179
180 // for (const auto& pair : ss.terminal_map)
181 // tot_features += pair.second.size();
182
183 // size_t tot_symbols = tot_operators + tot_features;
184
185 // VectorXf context( 3 * tot_symbols );
186 // context.setZero();
187
188 // for (auto it = Tree.begin(); it != Tree.end(); ++it) {
189 // if (Tree.is_valid(it)) {
190 // cout << "Check succeeded for node: " << (*it).name << std::endl;
191 // cout << "Depth of spot: " << Tree.depth(spot) << std::endl;
192 // cout << "Depth of it: " << Tree.depth(it) << std::endl;
193 // cout << "It is the spot, searching for it " << std::endl;
194
195 // // deciding if it is above or below the spot
196 // size_t pos_shift = 0; // above
197 // if (it == spot) { // spot
198 // pos_shift = 1;
199 // }
200 // else if (Tree.is_in_subTree(it, spot)) // below
201 // pos_shift = 2;
202
203 // cout << "Position shift: " << pos_shift << std::endl;
204 // if (Is<NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>((*it).node_type)){
205 // size_t feature_index = 0;
206
207 // // iterating using terminal_types since it is ordered
208 // for (const auto& terminal : ss.terminal_map.at((*it).ret_type)) {
209 // if (terminal.name == (*it).name) {
210 // // Just one hot encode --------------------------------------
211 // context((tot_operators + feature_index) + pos_shift*tot_symbols) += 1.0;
212
213 // // encode with weights --------------------------------------
214 // // int Tree_complexity = operator_complexities.at((*it).node_type);
215 // // if ((*it).get_is_weighted()
216 // // && Isnt<NodeType::Constant, NodeType::MeanLabel>((*it).node_type) )
217 // // {
218 // // if ((Is<NodeType::OffsetSum>((*it).node_type) && (*it).W != 0.0)
219 // // || ((*it).W != 1.0))
220 // // Tree_complexity = operator_complexities.at(NodeType::Mul) +
221 // // operator_complexities.at(NodeType::Constant) +
222 // // Tree_complexity;
223 // // }
224 // // context((tot_operators + feature_index) + pos_shift*tot_symbols) += static_cast<float>(Tree_complexity);
225
226 // // use recursive evaluation to get the complexity of the subTree
227 // // linear complexity to avoid exponential increase of values
228 // // int complexity = it.node->get_linear_complexity();
229 // // context((tot_operators + feature_index) + pos_shift*tot_symbols) += static_cast<float>(complexity);
230
231 // cout << "Below spot, terminal: " << terminal.name << " at feature index " << feature_index << std::endl;
232 // break;
233 // }
234 // ++feature_index;
235 // }
236 // } else {
237 // auto it_op = std::find(ss.op_names.begin(), ss.op_names.end(), (*it).name);
238 // if (it_op != ss.op_names.end()) {
239 // size_t op_index = std::distance(ss.op_names.begin(), it_op);
240 // context(pos_shift * tot_symbols + op_index) += 1.0;
241 // cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl;
242 // }
243 // else {
244 // HANDLE_ERROR_THROW("Undefined operator " + (*it).name + "\n");
245 // }
246 // }
247 // }
248 // }
249
250 // cout << "Context part 1: " << context.head(tot_symbols).transpose() << std::endl;
251 // cout << "Context part 2: " << context.segment(tot_symbols, tot_symbols).transpose() << std::endl;
252 // cout << "Context part 3: " << context.tail(tot_symbols).transpose() << std::endl;
253 // -------------------------------------------------------------------------
254
255 return context;
256}
257
258} // MAB
259} // Brush
Holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
nsga2 selection operator for getting the front
Definition bandit.cpp:4
ProgramType PT
Definition program.h:40
tree< Node >::pre_order_iterator Iter
Definition bandit.h:35
VectorXf get_context(const Program< PT > &program, Iter spot, const SearchSpace &ss, const Dataset &d)
Definition bandit.cpp:103
T choose(const VectorXf &context)
Selects an arm by passing the given context vector (e.g. one built from the tree by get_context) to the bandit policy.
Definition bandit.cpp:93
std::shared_ptr< BanditOperator< T > > pbandit
A shared pointer to the bandit operator (policy).
Definition bandit.h:41
void set_arms(vector< T > arms)
Sets the arms of the bandit.
Definition bandit.cpp:73
string get_type()
Gets the type of the bandit.
Definition bandit.cpp:58
void set_type(string type)
Sets the type of the bandit.
Definition bandit.cpp:63
vector< T > get_arms()
Gets the arms of the bandit.
Definition bandit.cpp:68
map< T, float > get_probs()
Gets the probabilities associated with each arm.
Definition bandit.cpp:78
std::string type
Definition bandit.h:44
vector< T > arms
Definition bandit.h:45
void update(T arm, float reward, VectorXf &context)
Updates the bandit's state based on the chosen arm and the received reward.
Definition bandit.cpp:98
std::map< T, float > probabilities
Definition bandit.h:47
void set_probs(map< T, float > arms_probs)
Sets the probabilities associated with each arm.
Definition bandit.cpp:83
void set_bandit()
Sets the bandit operator (policy).
Definition bandit.cpp:41
map< T, float > sample_probs(bool update=false)
Samples the probabilities associated with each arm using the policy.
Definition bandit.cpp:88
An individual program, a.k.a. model.
Definition program.h:50
tree< Node > Tree
fitness
Definition program.h:73
Holds a search space, consisting of operations and terminals and functions, and methods to sample that space.