88 int context_size = context.size();
93 for (
int i = 0; i <
n_arms; ++i) {
94 B.push_back( MatrixXf::Identity(context_size, context_size) );
95 B_inv.push_back( MatrixXf::Identity(context_size, context_size) );
96 B_inv_sqrt.push_back( MatrixXf::Identity(context_size, context_size) );
103 MatrixXf w(
n_arms, context_size);
104 MatrixXf
r = MatrixXf::Random(
n_arms, context_size);
105 for (
int i = 0; i <
n_arms; ++i) {
116 Eigen::Index max_index;
117 float max_value = u.maxCoeff(&max_index);
126 int context_size = context.size();
130 for (
int i = 0; i <
n_arms; ++i) {
131 B.push_back( MatrixXf::Identity(context_size, context_size) );
132 B_inv.push_back( MatrixXf::Identity(context_size, context_size) );
133 B_inv_sqrt.push_back( MatrixXf::Identity(context_size, context_size) );
143 [&arm](
const auto& pair) { return pair.second == arm; });
146 throw std::invalid_argument(
"Arm not found in the arm_index_to_key map");
149 int arm_index = it->first;
156 B[arm_index] += context * context.transpose();
159 m2_r.row(arm_index) += (context * reward).transpose();
162 B_inv[arm_index] =
B[arm_index].inverse();
168 mean.row(arm_index) =
B_inv[arm_index] *
m2_r.row(arm_index).transpose();