Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bandit.cpp
Go to the documentation of this file.
1#include "bandit.h"
2#include <typeinfo> // FOR DEBUGGING PURPOSES. TODO: remove it later
3
4namespace Brush {
5namespace MAB {
6
// NOTE(review): the signature line (bandit.cpp:7) is missing from this listing;
// presumably this is the default constructor Bandit::Bandit() — confirm against bandit.cpp.
// Configures a bandit with the "dynamic_thompson" policy, no arms and no
// probabilities, then builds the policy object via set_bandit() (which, for
// "dynamic_thompson", creates a ThompsonSamplingBandit(probabilities, true)).
8 set_type("dynamic_thompson");
9 set_arms({});
10 set_probs({});
11 set_bandit();
12}
13
14Bandit::Bandit(string type, vector<string> arms) : type(type) {
15 this->set_arms(arms);
16
17 map<string, float> arms_probs;
18 float prob = 1.0 / arms.size();
19 for (const auto& arm : arms) {
20 arms_probs[arm] = prob;
21 }
22 this->set_probs(arms_probs);
23 this->set_bandit();
24}
25
26Bandit::Bandit(string type, map<string, float> arms_probs) : type(type) {
27 this->set_probs(arms_probs);
28
29 vector<string> arms_names;
30 for (const auto& pair : arms_probs) {
31 arms_names.push_back(pair.first);
32 }
33 this->set_arms(arms_names);
34 this->set_bandit();
35}
36
38 // TODO: a flag that is set to true when this function is called. make all
39 // other methods to raise an error if bandit was not set
40 if (type == "thompson") {
41 pbandit = make_unique<ThompsonSamplingBandit>(probabilities);
42 } else if (type == "dynamic_thompson") {
43 pbandit = make_unique<ThompsonSamplingBandit>(probabilities, true);
44 } else if (type == "dummy") {
45 pbandit = make_unique<DummyBandit>(probabilities);
46 } else {
47 HANDLE_ERROR_THROW("Undefined Selection Operator " + this->type + "\n");
48 }
49}
50
52 return type;
53}
54
55void Bandit::set_type(string type) {
56 this->type = type;
57}
58
59vector<string> Bandit::get_arms() {
60 return arms;
61}
62
63void Bandit::set_arms(vector<string> arms) {
64 this->arms = arms;
65}
66
67map<string, float> Bandit::get_probs() {
68 return probabilities;
69}
70
71void Bandit::set_probs(map<string, float> arms_probs) {
72 probabilities = arms_probs;
73}
74
75map<string, float> Bandit::sample_probs(bool update) {
76 map<string, float> new_probs = this->pbandit->sample_probs(update);
77
78 // making all probabilities strictly positive
79 float eps = 1e-6;
80
81 for (auto& pair : new_probs) {
82 if (pair.second <= 0.0f) {
83 pair.second = eps;
84 }
85 }
86
87 return new_probs;
88}
89
91 return this->pbandit->choose();
92}
93
94void Bandit::update(string arm, float reward) {
95 this->pbandit->update(arm, reward);
96}
97
98} // MAB
99} // Brush
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
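Brush namespace. (Note: the brief previously shown here, "nsga2 selection operator for getting the front", does not describe this namespace; it appears to be a trailing doc comment misattached by Doxygen.)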
Definition bandit.cpp:4
void update(string arm, float reward)
Updates the bandit's state based on the chosen arm and the received reward.
Definition bandit.cpp:94
string get_type()
Gets the type of the bandit.
Definition bandit.cpp:51
vector< string > arms
Definition bandit.h:43
vector< string > get_arms()
Gets the arms of the bandit.
Definition bandit.cpp:59
void set_type(string type)
Sets the type of the bandit.
Definition bandit.cpp:55
void set_arms(vector< string > arms)
Sets the arms of the bandit.
Definition bandit.cpp:63
std::map< string, float > probabilities
Definition bandit.h:45
std::string type
Definition bandit.h:42
void set_probs(map< string, float > arms_probs)
Sets the probabilities associated with each arm.
Definition bandit.cpp:71
string choose()
Selects an arm.
Definition bandit.cpp:90
map< string, float > sample_probs(bool update=false)
Samples the probabilities associated with each arm using the policy.
Definition bandit.cpp:75
void set_bandit()
Sets the bandit operator (policy).
Definition bandit.cpp:37
std::shared_ptr< BanditOperator > pbandit
A shared pointer to the bandit operator (policy).
Definition bandit.h:39
map< string, float > get_probs()
Gets the probabilities associated with each arm.
Definition bandit.cpp:67