Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
thompson.cpp
Go to the documentation of this file.
1#include "thompson.h"
2
3namespace Brush {
4namespace MAB {
5
6ThompsonSamplingBandit::ThompsonSamplingBandit(vector<string> arms, bool dynamic)
7 : BanditOperator(arms)
8 , dynamic_update(dynamic)
9{
10 for (const auto& arm : arms) {
11 alphas[arm] = 2;
12 betas[arm] = 2;
13 }
14}
15
16ThompsonSamplingBandit::ThompsonSamplingBandit(map<string, float> arms_probs, bool dynamic)
17 : BanditOperator(arms_probs)
18 , dynamic_update(dynamic)
19{
20 for (const auto& pair : arms_probs) {
21 alphas[pair.first] = 2;
22 betas[pair.first] = 2;
23 }
24};
25
26
27std::map<string, float> ThompsonSamplingBandit::sample_probs(bool update) {
28 // gets sampling probabilities using the bandit
29
30 if (update) {
31 // 1. use a beta distribution based on alphas and betas to sample probabilities
32 // 2. normalize probabilities so the sum is 1
33
34 for (const auto& pair : this->probabilities) {
35 string arm = pair.first;
36
37 float prob = r.rnd_alpha_beta(alphas[arm], betas[arm]);
38
39 // avoiding deadlocks when sampling from search space
40 this->probabilities[arm] = std::max(std::min(prob, 1.0f), 0.001f);
41 }
42
43 // assert that the sum is not zero
44 float totalProb = 0.0f;
45 for (const auto& pair : this->probabilities) {
46 totalProb += pair.second;
47 }
48 assert(totalProb != 0.0f && "Sum of probabilities is zero!");
49 }
50
51 return this->probabilities;
52}
53
55 std::map<string, float> probs = this->sample_probs(true);
56
57 return r.random_choice(probs);
58}
59
60void ThompsonSamplingBandit::update(string arm, float reward) {
61 // reward must be either 0 or 1
62
63 alphas[arm] += reward;
64 betas[arm] += 1.0f-reward;
65
66 if (dynamic_update && alphas[arm] + betas[arm] >= C)
67 {
68 alphas[arm] *= C/(C+1.0f) ;
69 betas[arm] *= C/(C+1.0f) ;
70 }
71}
72
73} // MAB
74} // Brush
std::map< string, float > probabilities
BanditOperator(vector< string > arms)
Constructs a BanditOperator object with a vector of arms.
string choose()
Chooses an arm by sampling from the current arm probabilities. Should call sample_probs internally.
Definition thompson.cpp:54
std::map< string, float > sample_probs(bool update)
Samples the probabilities of the arms.
Definition thompson.cpp:27
void update(string arm, float reward)
Updates the reward for a specific arm.
Definition thompson.cpp:60
std::map< string, float > alphas
Definition thompson.h:25
ThompsonSamplingBandit(vector< string > arms, bool dynamic=false)
Definition thompson.cpp:6
std::map< string, float > betas
Definition thompson.h:26
static Rnd & r
Definition rnd.h:176
< nsga2 selection operator for getting the front
Definition bandit.cpp:4