Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_engines.h
Go to the documentation of this file.
1#include "module.h"
2#include "../engine.h"
3#include "../engine.cpp"
4
5// TODO: figure out why I need to include the whole thing, .cpp files included (otherwise I get symbol errors)
6#include "../bandit/bandit.h"
10#include "../bandit/dummy.h"
11#include "../bandit/dummy.cpp"
12#include "../bandit/thompson.h"
14
15#include "../ind/individual.h"
16#include "../ind/individual.cpp"
17#include "../vary/variation.h"
18#include "../vary/variation.cpp"
19
20#include "../eval/evaluation.h"
22
23#include "../pop/population.cpp"
24#include "../pop/population.h"
25
30#include "../selection/nsga2.h"
34
35#include "../pop/archive.cpp"
36#include "../pop/archive.h"
37
42
43namespace nl = nlohmann;
44namespace br = Brush;
45
// Call guard that bundles pybind11's scoped_ostream_redirect and
// scoped_estream_redirect, so a binding can request both with one name.
// NOTE(review): not referenced by the bind_engine definition visible in this
// listing — presumably used by other binding code; confirm before removing.
46using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
47
// Registers the Python bindings for one engine type.
//
// Creates a pybind11 class for `T` on module `m` under the name `name`, then
// attaches constructors, properties, fit/predict overloads, archive and
// population accessors, node locking helpers and pickle support. When `T` is
// `Cls`, the probability-prediction methods are bound as well.
//
// Template parameter T: engine type being bound; the code below compares it
//   against Reg, Cls and MCls.
// Parameter m: module that receives the new class.
// Parameter name: class name handed to py::class_.
48template<typename T>
49void bind_engine(py::module& m, string name)
50{
// Return type of the predict overloads for this engine:
// Reg -> ArrayXf, Cls -> ArrayXb, MCls -> ArrayXi, anything else -> ArrayXXf.
51 using RetType = std::conditional_t<
52 std::is_same_v<T,Reg>, ArrayXf,
53 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
54 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
55
56 py::class_<T> engine(m, name.data() );
// Two constructors: default, and one built from a Parameters / SearchSpace pair.
57 engine.def(py::init<>())
58 .def(py::init([](br::Parameters& p, br::SearchSpace& s){
59 T e(p, s); return e; })
60 )
// Read/write properties backed by getter/setter pairs, then read-only ones.
61 .def_property("params", &T::get_params, &T::set_params)
62 .def_property("search_space", &T::get_search_space, &T::set_search_space)
63 .def_property_readonly("is_fitted", &T::get_is_fitted)
64 .def_property_readonly("best_ind", &T::get_best_ind)
// Both fit overloads are bound with a gil_scoped_release call guard; the
// static_cast picks the specific overload of T::fit being bound.
65 .def("fit",
66 static_cast<T &(T::*)(Dataset &d)>(&T::fit),
67 py::call_guard<py::gil_scoped_release>(),
68 "fit from Dataset object")
69 .def("fit",
70 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
71 py::call_guard<py::gil_scoped_release>(),
72 "fit from X,y data")
// predict / predict_archive overloads: bound without a call guard.
73 .def("predict",
74 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
75 "predict from Dataset object")
76 .def("predict",
77 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
78 "predict from X data")
79 .def("predict_archive",
80 static_cast<RetType (T::*)(int id, const Dataset &d)>(&T::predict_archive),
81 "predict from individual in archive")
82 .def("predict_archive",
83 static_cast<RetType (T::*)(int id, const Ref<const ArrayXXf> &X)>(&T::predict_archive),
84 "predict from individual in archive")
// get_archive exposes a single keyword argument, "front", defaulting to false.
85 .def("get_archive", &T::get_archive, py::arg("front") = false)
86 .def("get_population", &T::get_population)
87 .def("set_population", &T::set_population)
// NOTE(review): the listing numbering jumps from 91 to 93 here, so at least
// one argument line of this binding appears to be missing from this view
// (hence the dangling comma) — check the original header before editing.
88 .def("lock_nodes",
89 &T::lock_nodes,
90 py::arg("end_depth") = 0,
91 py::arg("skip_leaves") = true,
93 )
// NOTE(review): same situation as above — the numbering jumps from 96 to 98.
94 .def("unlock_nodes",
95 &T::unlock_nodes,
96 py::arg("start_depth") = 0,
98 )
// Pickle support: the state is the engine converted to nlohmann::json, and
// restoring converts that json back into a T.
// NOTE(review): relies on json conversions for T and on a json <-> Python
// caster being available from elsewhere — neither is visible here; confirm.
99 .def(py::pickle(
100 [](const T &p) { // __getstate__
101 /* Return a JSON object that encodes the state of the object */
102 // return py::make_tuple(p.value(), p.extra());
103 nl::json j = p;
104 return j;
105 },
106 [](nl::json j) { // __setstate__
107 T p = j;
108 // TODO: do I need to get the data and ss reference, then call init for this new instance?
109 return p;
110 })
111 )
112 ;
113
114 // specialization for subclasses
// Only the Cls engine gets predict_proba / predict_proba_archive; all four
// overloads return ArrayXf. `if constexpr` keeps these member lookups from
// being instantiated for the other engine types.
115 if constexpr (std::is_same_v<T,Cls>)
116 {
117 engine.def("predict_proba",
118 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
119 "predict from Dataset object")
120 .def("predict_proba",
121 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
122 "predict from X data")
123 .def("predict_proba_archive",
124 static_cast<ArrayXf (T::*)(int id, const Dataset &d)>(&T::predict_proba_archive),
125 "predict from individual in archive")
126 .def("predict_proba_archive",
127 static_cast<ArrayXf (T::*)(int id, const Ref<const ArrayXXf> &X)>(&T::predict_proba_archive),
128 "predict from individual in archive")
129
130 ;
131 }
132}
Brush::RepresenterEngine Rep
Brush::MulticlassClassifierEngine MCls
void bind_engine(py::module &m, string name)
Brush::ClassifierEngine Cls
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Brush::RegressorEngine Reg
holds variable type data.
Definition data.h:51
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
nsga2 selection operator for getting the front
Definition bandit.cpp:4
Engine< PT::MulticlassClassifier > MulticlassClassifierEngine
Definition types.h:99
Engine< PT::Regressor > RegressorEngine
Definition types.h:97
Engine< PT::BinaryClassifier > ClassifierEngine
Definition types.h:98
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Engine< PT::Representer > RepresenterEngine
Definition types.h:100
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...