Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bandit.cpp
Go to the documentation of this file.
1#include "bandit.h"
2
3namespace Brush {
4namespace MAB {
5
7 set_type("dynamic_thompson");
8 set_arms({});
9 set_probs({});
10 set_bandit();
11}
12
13Bandit::Bandit(string type, vector<string> arms) : type(type) {
14 this->set_arms(arms);
15
16 map<string, float> arms_probs;
17 float prob = 1.0 / arms.size();
18 for (const auto& arm : arms) {
19 arms_probs[arm] = prob;
20 }
21 this->set_probs(arms_probs);
22 this->set_bandit();
23}
24
25Bandit::Bandit(string type, map<string, float> arms_probs) : type(type) {
26 this->set_probs(arms_probs);
27
28 vector<string> arms_names;
29 for (const auto& pair : arms_probs) {
30 arms_names.push_back(pair.first);
31 }
32 this->set_arms(arms_names);
33 this->set_bandit();
34}
35
37 if (type == "thompson") {
38 pbandit = make_unique<ThompsonSamplingBandit>(probabilities);
39 } else if (type == "dynamic_thompson") {
40 pbandit = make_unique<ThompsonSamplingBandit>(probabilities, true);
41 } else if (type == "dummy") {
42 pbandit = make_unique<DummyBandit>(probabilities);
43 } else {
44 HANDLE_ERROR_THROW("Undefined Selection Operator " + this->type + "\n");
45 }
46
47 bandit_set = true;
48}
49
51 if (!bandit_set || !pbandit) {
52 HANDLE_ERROR_THROW("Bandit operator is not set. Call set_bandit() before use.\n");
53 }
54}
55
57 return type;
58}
59
60void Bandit::set_type(string type) {
61 this->type = type;
62}
63
64vector<string> Bandit::get_arms() {
65 return arms;
66}
67
68void Bandit::set_arms(vector<string> arms) {
69 this->arms = arms;
70}
71
72map<string, float> Bandit::get_probs() {
73 return probabilities;
74}
75
76void Bandit::set_probs(map<string, float> arms_probs) {
77 probabilities = arms_probs;
78}
79
80map<string, float> Bandit::sample_probs(bool update) {
82 map<string, float> new_probs = this->pbandit->sample_probs(update);
83
84 // making all probabilities strictly positive
85 float eps = 1e-6;
86
87 for (auto& pair : new_probs) {
88 if (pair.second <= 0.0f) {
89 pair.second = eps;
90 }
91 }
92
93 return new_probs;
94}
95
98 return this->pbandit->choose();
99}
100
101void Bandit::update(string arm, float reward) {
103 this->pbandit->update(arm, reward);
104}
105
106} // MAB
107} // Brush
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:3
void ensure_bandit_set() const
Definition bandit.cpp:50
void update(string arm, float reward)
Updates the bandit's state based on the chosen arm and the received reward.
Definition bandit.cpp:101
string get_type()
Gets the type of the bandit.
Definition bandit.cpp:56
vector< string > arms
Definition bandit.h:43
vector< string > get_arms()
Gets the arms of the bandit.
Definition bandit.cpp:64
void set_type(string type)
Sets the type of the bandit.
Definition bandit.cpp:60
void set_arms(vector< string > arms)
Sets the arms of the bandit.
Definition bandit.cpp:68
std::map< string, float > probabilities
Definition bandit.h:45
std::string type
Definition bandit.h:42
void set_probs(map< string, float > arms_probs)
Sets the probabilities associated with each arm.
Definition bandit.cpp:76
string choose()
Selects an arm.
Definition bandit.cpp:96
map< string, float > sample_probs(bool update=false)
Samples the probabilities associated with each arm using the policy.
Definition bandit.cpp:80
void set_bandit()
Sets the bandit operator (policy).
Definition bandit.cpp:36
std::shared_ptr< BanditOperator > pbandit
A shared pointer to the bandit operator (policy).
Definition bandit.h:39
map< string, float > get_probs()
Gets the probabilities associated with each arm.
Definition bandit.cpp:72