Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_engines.h
Go to the documentation of this file.
1#include "module.h"
2#include "../engine.h"
3#include "../engine.cpp"
4
5// TODO: figure out why the full implementation (.cpp) files must be included here; without them the build fails with symbol errors.
10#include "../selection/nsga2.h"
14
15#include "../eval/evaluation.h"
17
18#include "../pop/population.cpp"
19#include "../pop/population.h"
20
21#include "../pop/archive.cpp"
22#include "../pop/archive.h"
23
28
29namespace nl = nlohmann;
30namespace br = Brush;
31
32using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
33
34template<typename T>
35void bind_engine(py::module& m, string name)
36{
37 using RetType = std::conditional_t<
38 std::is_same_v<T,Reg>, ArrayXf,
39 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
40 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
41
42 py::class_<T> engine(m, name.data() );
43 engine.def(py::init<>())
44 .def(py::init([](br::Parameters& p){ T e(p);
45 return e; })
46 )
47 .def_property("params", &T::get_params, &T::set_params)
48 .def_property_readonly("is_fitted", &T::get_is_fitted)
49 .def_property_readonly("best_ind", &T::get_best_ind)
50 // .def("run", &T::run, py::call_guard<py::gil_scoped_release>(), "run from brush dataset")
51 .def("fit",
52 static_cast<T &(T::*)(Dataset &d)>(&T::fit),
53 py::call_guard<py::gil_scoped_release>(),
54 "fit from Dataset object")
55 .def("fit",
56 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
57 py::call_guard<py::gil_scoped_release>(),
58 "fit from X,y data")
59 .def("predict",
60 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
61 "predict from Dataset object")
62 .def("predict",
63 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
64 "predict from X data")
65 .def("predict_archive",
66 static_cast<RetType (T::*)(int id, const Dataset &d)>(&T::predict_archive),
67 "predict from individual in archive")
68 .def("predict_archive",
69 static_cast<RetType (T::*)(int id, const Ref<const ArrayXXf> &X)>(&T::predict_archive),
70 "predict from individual in archive")
71 .def("get_archive", &T::get_archive, py::arg("front") = false)
72 .def(py::pickle(
73 [](const T &p) { // __getstate__
74 /* Return a tuple that fully encodes the state of the object */
75 // return py::make_tuple(p.value(), p.extra());
76 nl::json j = p;
77 return j;
78 },
79 [](nl::json j) { // __setstate__
80 T p = j;
81 return p;
82 })
83 )
84 ;
85
86 // specialization for subclasses
87 if constexpr (std::is_same_v<T,Cls>)
88 {
89 engine.def("predict_proba",
90 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
91 "predict from Dataset object")
92 .def("predict_proba",
93 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
94 "predict from X data")
95 .def("predict_proba_archive",
96 static_cast<ArrayXf (T::*)(int id, const Dataset &d)>(&T::predict_proba_archive),
97 "predict from individual in archive")
98 .def("predict_proba_archive",
99 static_cast<ArrayXf (T::*)(int id, const Ref<const ArrayXXf> &X)>(&T::predict_proba_archive),
100 "predict from individual in archive")
101
102 ;
103 }
104}
void bind_engine(py::module &m, string name)
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Holds variable-type data.
Definition data.h:51
The Engine class represents the core engine of the brush library.
Definition engine.h:43
NSGA-II selection operator for getting the front.
Definition data.cpp:12
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
Engine< PT::MulticlassClassifier > MulticlassClassifierEngine
Definition types.h:99
Engine< PT::Regressor > RegressorEngine
Definition types.h:97
Engine< PT::BinaryClassifier > ClassifierEngine
Definition types.h:98
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Engine< PT::Representer > RepresenterEngine
Definition types.h:100