Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
linear_thompson.cpp
Go to the documentation of this file.
1#include "linear_thompson.h"
2
3namespace Brush {
4namespace MAB {
5
// LinearThompsonSamplingBandit constructor from a list of arms.
// NOTE(review): the signature line (listing line 7) is missing from this rendering;
// presumably `LinearThompsonSamplingBandit<T>::LinearThompsonSamplingBandit(vector<T> arms)`
// (the base-class constructor takes `vector< T > arms`) -- confirm against linear_thompson.h.
6template <typename T>
8 : BanditOperator<T>(arms)
9{
10 n_arms = arms.size();
11
12 // Initialize Eigen matrices
// NOTE(review): elsewhere arm_index_to_key is read as key/value pairs (update() uses
// pair.second and it->first) and written through operator[]; if it is declared as a
// std::map it has no resize() member -- clear() would be the map equivalent. Confirm
// the declared type in linear_thompson.h.
13 arm_index_to_key.resize(0);
// The per-arm matrix lists start empty: choose()/update() test B.size()==0 and only
// allocate them once the context size is known.
14 B.resize(0);
15 B_inv.resize(0);
16 B_inv_sqrt.resize(0);
17
// Placeholder (n_arms x 1) shapes; choose()/update() re-create these as
// (n_arms x context_size) when they first initialize the per-arm statistics.
18 m2_r = MatrixXf::Zero(n_arms, 1);
19 mean = MatrixXf::Zero(n_arms, 1);
20
// Arm index i (row i of m2_r/mean and B[i], B_inv[i], B_inv_sqrt[i]) maps back to arms[i].
21 for (int i = 0; i < n_arms; ++i) { // one for each arm
22 arm_index_to_key[i] = arms[i];
23 }
24}
25
// LinearThompsonSamplingBandit constructor from an arm -> probability map.
// NOTE(review): the signature line (listing line 27) is missing from this rendering;
// presumably it takes the map as `arms_probs` (a map keyed by arm) -- confirm against
// linear_thompson.h. Only the keys are used in this body; the whole map is forwarded to
// the BanditOperator base constructor.
26template <typename T>
28 : BanditOperator<T>(arms_probs)
29{
30 n_arms = arms_probs.size();
31
// The per-arm matrix lists start empty; choose()/update() allocate them lazily once the
// context size is known.
32 B.resize(0);
33 B_inv.resize(0);
34 B_inv_sqrt.resize(0);
35
// Placeholder (n_arms x 1) shapes, re-created as (n_arms x context_size) on first use.
36 m2_r = MatrixXf::Zero(n_arms, 1);
37 mean = MatrixXf::Zero(n_arms, 1);
38
// Arm indices follow the iteration order of arms_probs (key order, if it is a std::map).
// NOTE(review): unlike the vector constructor, arm_index_to_key is not reset here first.
39 int index = 0;
40 for (const auto& pair : arms_probs) { // making sure we have the same order
41 arm_index_to_key[index++] = pair.first;
42 }
// NOTE(review): the ';' after the closing brace is a redundant empty declaration.
43};
44
45
// Recompute (when `update` is true) and return the arms' selection probabilities.
// NOTE(review): the signature line (listing line 47) is missing from this rendering; the
// class member list shows it as `std::map< T, float > sample_probs(bool update)`.
// If `update` is false, or no per-arm statistics exist yet (B stays empty until the first
// choose()/update() call), the stored this->probabilities are returned unchanged.
46template <typename T>
48 // cout << "sampling probs started" << endl;
49 if (update && B.size()>0) // must be called after at least one choose
50 {
51 int context_size = B.at(0).rows();
52
53 MatrixXf w(n_arms, context_size);
54 MatrixXf r = MatrixXf::Random(n_arms, context_size); // TODO: use random generator here
// NOTE(review): B_inv_sqrt[i] is (context_size x context_size) and r.row(i) is a
// (1 x context_size) row, so this product is only well-formed when context_size == 1.
// The intended sample is probably (B_inv_sqrt[i] * r.row(i).transpose()).transpose().
55 for (int i = 0; i < n_arms; ++i) {
56 w.row(i) = B_inv_sqrt[i] * r.row(i); // mat mul
57 }
58
59 w = mean + w;
60
61 VectorXf u(n_arms);
62
// NOTE(review): no observed context is available here, so the arm scores are computed
// against a random context drawn from ArrayXf::Random (uniform in [-1, 1]).
63 VectorXf last_context = ArrayXf::Random(context_size);
64
65 u = w * last_context; // mat mul
66
// Exponentiated scores relative to the best arm: prob_i = exp(u_i) / exp(max u), so the
// best-scoring arm gets 1 and every other arm gets a value in (0, 1].
// NOTE(review): exp(u_i) and exp(max u) can each overflow to inf (yielding NaN or 0);
// std::exp(u(i) - u.maxCoeff()) is the numerically stable equivalent. u.maxCoeff() is
// also recomputed on every iteration.
67 float total_prob = 0.0f;
68 for (int i = 0; i < n_arms; ++i) {
69 float prob = std::exp(u(i)) / std::exp(u.maxCoeff());
70 this->probabilities[arm_index_to_key[i]] = prob;
71 total_prob += prob;
72 }
73
74 assert(total_prob > 0 && "Total probability must be greater than zero");
75
// NOTE(review): despite the comment below, nothing is divided by total_prob (the division
// is commented out). The loop only clamps each value to at most 1, which does nothing for
// the values computed above when no overflow occurs. The returned probabilities therefore
// do not, in general, sum to 1.
76 // Normalize probabilities to ensure they sum to 1
77 for (auto& [k, v] : this->probabilities) {
78 this->probabilities[k] = std::min(this->probabilities[k], 1.0f); // / total_prob
79 }
80 }
81
82 return this->probabilities;
83 // cout << "sampling probs finished" << endl;
84}
85
86template <typename T>
87T LinearThompsonSamplingBandit<T>::choose(const VectorXf& context) {
88 int context_size = context.size();
89
90 if (B.size()==0){
91 // cout << "INITIALIZING BANDIT " << endl;
92
93 for (int i = 0; i < n_arms; ++i) { // one for each arm
94 B.push_back( MatrixXf::Identity(context_size, context_size) );
95 B_inv.push_back( MatrixXf::Identity(context_size, context_size) );
96 B_inv_sqrt.push_back( MatrixXf::Identity(context_size, context_size) );
97 }
98
99 m2_r = MatrixXf::Zero(n_arms, context_size);
100 mean = MatrixXf::Zero(n_arms, context_size);
101 }
102
103 MatrixXf w(n_arms, context_size);
104 MatrixXf r = MatrixXf::Random(n_arms, context_size); // TODO: use random generator here
105 for (int i = 0; i < n_arms; ++i) {
106 w.row(i) = B_inv_sqrt[i] * r.row(i); // mat mul
107 }
108
109 w = mean + w;
110
111 // cout << "w: " << w << endl;
112 VectorXf u(n_arms);
113 u = w * context; // mat mul
114 // cout << "u: " << u << endl;
115
116 Eigen::Index max_index;
117 float max_value = u.maxCoeff(&max_index);
118 // cout << "max_index: " << max_index << ", max_value: " << max_value << endl;
119
120 // cout << "choose finished" << endl;
121 return arm_index_to_key[max_index];
122}
123
124template <typename T>
125void LinearThompsonSamplingBandit<T>::update(T arm, float reward, VectorXf& context) {
126 int context_size = context.size();
127
128 if (B.size()==0){
129 // cout << "INITIALIZING BANDIT " << endl;
130 for (int i = 0; i < n_arms; ++i) { // one for each arm
131 B.push_back( MatrixXf::Identity(context_size, context_size) );
132 B_inv.push_back( MatrixXf::Identity(context_size, context_size) );
133 B_inv_sqrt.push_back( MatrixXf::Identity(context_size, context_size) );
134 }
135
136 m2_r = MatrixXf::Zero(n_arms, context_size);
137 mean = MatrixXf::Zero(n_arms, context_size);
138 }
139
140 // TODO: have a more efficient way of doing this
141 // Find the arm index using our mapping
142 auto it = std::find_if(arm_index_to_key.begin(), arm_index_to_key.end(),
143 [&arm](const auto& pair) { return pair.second == arm; });
144
145 if (it == arm_index_to_key.end()) {
146 throw std::invalid_argument("Arm not found in the arm_index_to_key map");
147 }
148
149 int arm_index = it->first;
150
151 // cout << "Arm index: " << arm_index << endl;
152 // cout << "Context: " << context.size() << endl;
153 // cout << "B[arm_index] before update: " << B[arm_index].size() << endl;
154 // cout << "m2_r.row(arm_index) before update: " << m2_r.row(arm_index).size() << endl;
155
156 B[arm_index] += context * context.transpose();
157 // cout << "B[arm_index] after update: " << B[arm_index].size() << endl;
158
159 m2_r.row(arm_index) += (context * reward).transpose();
160 // cout << "m2_r.row(arm_index) after update: " << m2_r.row(arm_index).size() << endl;
161
162 B_inv[arm_index] = B[arm_index].inverse();
163 // cout << "B_inv[arm_index]: " << B_inv[arm_index].size() << endl;
164
165 B_inv_sqrt[arm_index] = B_inv[arm_index].ldlt().matrixL();
166 // cout << "B_inv_sqrt[arm_index]: " << B_inv_sqrt[arm_index].size() << endl;
167
168 mean.row(arm_index) = B_inv[arm_index] * m2_r.row(arm_index).transpose(); // mat mul
169 // cout << "mean.row(arm_index): " << mean.row(arm_index).size() << endl;
170
171 // cout << "update finished" << endl;
172}
173
174} // MAB
175} // Brush
BanditOperator(vector< T > arms)
Constructs a BanditOperator object with a vector of arms.
std::map< T, float > probabilities
void update(T arm, float reward, VectorXf &context)
Updates a specific arm's statistics with an observed reward and the context it was played in.
std::map< T, float > sample_probs(bool update)
Samples the probabilities of the arms.
T choose(const VectorXf &context)
Chooses an arm for the given context vector by sampling a weight vector for each arm and returning the arm with the highest sampled score.
static Rnd & r
Definition rnd.h:174
< nsga2 selection operator for getting the front
Definition bandit.cpp:4