Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
engine.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4*/
5
6#ifndef Engine_H
7#define Engine_H
8
9#include "util/rnd.h"
10#include "init.h"
11#include "params.h"
12#include "eval/evaluation.h"
13#include "vary/variation.h"
14#include "pop/population.h"
15#include "pop/archive.h"
16#include "selection/selection.h"
17
18#include "taskflow/taskflow.hpp"
19#include <taskflow/algorithm/for_each.hpp>
20
21namespace Brush
22{
23
24using namespace Pop;
25using namespace Sel;
26using namespace Eval;
27using namespace Var;
28using namespace nlohmann;
29
30template <ProgramType T>
43class Engine{
44public:
46 {
47 this->params = Parameters();
48 this->ss = SearchSpace();
49 };
50
52 {
53 this->params = p;
54 this->ss = s;
55 // TODO: make variation to have a default constructor
56 // this->variator(Variation<T>(params, ss)) ;
57 };
58
60
/// Print a progress bar, filled according to `percentage`.
void print_progress(float percentage);

/// Print statistics about the run.
/// NOTE(review): `fraction` presumably is the fraction of the run completed,
/// and the exact use of `log` is defined in engine.cpp — confirm there.
void print_stats(std::ofstream& log, float fraction);

/// Presumably writes run statistics to `log` (format defined in engine.cpp).
void log_stats(std::ofstream& log);
66
67 // all hyperparameters are controlled by the parameter class. please refer to that to change something
68 inline Parameters& get_params(){return params;}
69 inline void set_params(Parameters& p){params=p;}
70
71 inline SearchSpace& get_search_space() { return ss; }
72 inline void set_search_space(SearchSpace& space) { ss = space; }
73
74 inline bool get_is_fitted(){return is_fitted;}
75
78
80
82 run(data);
83 return *this;
84 };
85 Engine<T> &fit(const Ref<const ArrayXXf>& X, const Ref<const ArrayXf>& y)
86 {
87 // Using constructor 2 to create the dataset
88 Dataset d(X,y,params.feature_names,{},params.feature_types,
89 params.classification,params.validation_size,
90 params.batch_size, params.shuffle_split);
91 return fit(d);
92 };
93
94 auto predict(const Dataset& data) { return this->best_ind.predict(data); };
95 auto predict(const Ref<const ArrayXXf>& X)
96 {
97 Dataset d(X);
98 return predict(d);
99 };
100
101 template <ProgramType P = T>
102 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
103 auto predict_proba(const Dataset &d) { return this->best_ind.predict_proba(d); };
104 template <ProgramType P = T>
105 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
106 auto predict_proba(const Ref<const ArrayXXf>& X)
107 {
108 Dataset d(X);
109 return predict_proba(d);
110 };
111
113 int get_archive_size(){ return this->archive.individuals.size(); };
114
116 vector<json> get_archive(bool front);
117 vector<json> get_population();
118
119 void set_population(vector<json> pop_vector);
120
121 // locking and unlocking parts of the solutions
122 void lock_nodes(int end_depth=0, bool skip_leaves=true);
123 void unlock_nodes(int start_depth=0);
124
126 auto predict_archive(int id, const Dataset& data);
127 auto predict_archive(int id, const Ref<const ArrayXXf>& X);
128
129 template <ProgramType P = T>
130 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
131 auto predict_proba_archive(int id, const Dataset& data);
132 template <ProgramType P = T>
133 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
134 auto predict_proba_archive(int id, const Ref<const ArrayXXf>& X);
135
136 // TODO: predict/predict_proba/archive with longitudinal data
137
139 void run(Dataset &d);
140
141 // TODO: should params and ss be private? (that would require better json handling)
144
148
149 bool is_fitted = false;
150private:
154
156
158
159 void init();
160
162 inline void set_is_fitted(bool f){is_fitted=f;}
163};
164
165// TODO: should I serialize data and search space as well?
166// Only stuff to make new predictions should be serialized
167NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::Regressor>, params, best_ind, archive, pop, ss, is_fitted);
168NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::BinaryClassifier>, params, best_ind, archive, pop, ss, is_fitted);
169NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::MulticlassClassifier>, params, best_ind, archive, pop, ss, is_fitted);
170NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::Representer>, params, best_ind, archive, pop, ss, is_fitted);
171
172} // Brush
173#endif
holds variable type data.
Definition data.h:51
The Engine class represents the core engine of the Brush library.
Definition engine.h:43
auto predict(const Dataset &data)
Definition engine.h:94
void calculate_stats()
Definition engine.cpp:52
auto predict_proba(const Dataset &d)
Definition engine.h:103
SearchSpace & get_search_space()
Definition engine.h:71
Individual< T > best_ind
Definition engine.h:145
void print_progress(float percentage)
Definition engine.cpp:36
vector< json > get_population()
Definition engine.cpp:197
void run(Dataset &d)
train the model
Definition engine.cpp:403
Engine(Parameters &p, SearchSpace &s)
Definition engine.h:51
void lock_nodes(int end_depth=0, bool skip_leaves=true)
Definition engine.cpp:326
Engine< T > & fit(Dataset &data)
Definition engine.h:81
auto predict(const Ref< const ArrayXXf > &X)
Definition engine.h:95
void set_search_space(SearchSpace &space)
Definition engine.h:72
void set_population(vector< json > pop_vector)
Definition engine.cpp:217
Parameters & get_params()
Definition engine.h:68
auto predict_archive(int id, const Ref< const ArrayXXf > &X)
Definition engine.cpp:277
Individual< T > & get_best_ind()
Definition engine.h:79
void set_params(Parameters &p)
Definition engine.h:69
auto predict_proba(const Ref< const ArrayXXf > &X)
Definition engine.h:106
void set_is_fitted(bool f)
set flag indicating whether fit has been called
Definition engine.h:162
int get_archive_size()
return archive size
Definition engine.h:113
void unlock_nodes(int start_depth=0)
Definition engine.cpp:341
Evaluation< T > evaluator
Definition engine.h:152
bool get_is_fitted()
Definition engine.h:74
Engine< T > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Definition engine.h:85
void init()
initialize the Engine for fitting.
Definition engine.cpp:16
auto predict_proba_archive(int id, const Dataset &data)
Definition engine.cpp:286
auto predict_archive(int id, const Dataset &data)
predict on unseen data from the archive
Definition engine.cpp:246
bool update_best()
updates best score by searching in the population for the individual that best fits the given data
Definition engine.cpp:357
void log_stats(std::ofstream &log)
Definition engine.cpp:119
vector< json > get_archive(bool front)
return the archive/population as a vector of JSON objects
Definition engine.cpp:182
void print_stats(std::ofstream &log, float fraction)
Definition engine.cpp:149
auto predict_proba_archive(int id, const Ref< const ArrayXXf > &X)
Definition engine.cpp:319
Class for evaluating the fitness of individuals in a population.
Definition evaluation.h:27
class for timing things.
Definition utils.h:270
nsga2 selection operator for getting the front
Definition bandit.cpp:4
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine< PT::Regressor >, params, best_ind, archive, pop, ss, is_fitted)
The Archive struct represents a collection of individual programs.
Definition archive.h:26
Holds a search space, consisting of operations and terminals and functions, and methods to sample that space.
interfaces with selection operators.
Definition selection.h:25