Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_engines.h
Go to the documentation of this file.
1#include "module.h"
2#include "../engine.h"
3#include "../engine.cpp"
4
5// TODO: figure out why do I need to include the whole thing (otherwise it gives me symbol errors)
6#include "../bandit/bandit.h"
8#include "../bandit/dummy.h"
10
11#include "../ind/individual.h"
12#include "../ind/individual.cpp"
13#include "../vary/variation.h"
14#include "../vary/variation.cpp"
15
16#include "../eval/evaluation.h"
18
19#include "../pop/population.cpp"
20#include "../pop/population.h"
21
26#include "../selection/nsga2.h"
30
31#include "../pop/archive.cpp"
32#include "../pop/archive.h"
33
38
39namespace nl = nlohmann;
40namespace br = Brush;
41
42using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
43
44template<typename T>
45void bind_engine(py::module& m, string name)
46{
47 using RetType = std::conditional_t<
48 std::is_same_v<T,Reg>, ArrayXf,
49 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
50 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
51
52 py::class_<T> engine(m, name.data() );
53 engine.def(py::init<>())
54 .def(py::init([](br::Parameters& p, br::SearchSpace& s){
55 T e(p, s); return e; })
56 )
57 .def_property("params", &T::get_params, &T::set_params)
58 .def_property("search_space", &T::get_search_space, &T::set_search_space)
59 .def_property_readonly("is_fitted", &T::get_is_fitted)
60 .def_property_readonly("best_ind", &T::get_best_ind)
61 .def("fit",
62 static_cast<T &(T::*)(Dataset &d)>(&T::fit),
63 py::call_guard<py::gil_scoped_release>(),
64 "fit from Dataset object")
65 .def("fit",
66 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
67 py::call_guard<py::gil_scoped_release>(),
68 "fit from X,y data")
69 .def("predict",
70 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
71 "predict from Dataset object")
72 .def("predict",
73 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
74 "predict from X data")
75 .def("get_archive", &T::get_archive)
76 .def("get_population", &T::get_population)
77 .def("set_population", &T::set_population)
78 .def("get_archive_as_json", &T::get_archive_as_json)
79 .def("get_population_as_json", &T::get_population_as_json)
80 .def("set_population_from_json", &T::set_population_from_json)
81 .def("lock_nodes",
82 &T::lock_nodes,
83 py::arg("end_depth") = 0,
84 py::arg("keep_leaves_unlocked") = true,
86 )
87 .def(py::pickle(
88 [](const T &p) { // __getstate__
89 /* Return a tuple that fully encodes the state of the object */
90 // return py::make_tuple(p.value(), p.extra());
91 nl::json j = p;
92 return j;
93 },
94 [](nl::json j) { // __setstate__
95 T p = j;
96 // TODO: do I need to get the data and ss reference, then call init for this new instance?
97 return p;
98 })
99 )
100 ;
101
102 // specialization for subclasses
103 if constexpr (std::is_same_v<T,Cls>)
104 {
105 engine.def("predict_proba",
106 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
107 "predict from Dataset object")
108 .def("predict_proba",
109 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
110 "predict from X data")
111 ;
112 }
113}
Brush::RepresenterEngine Rep
Brush::MulticlassClassifierEngine MCls
void bind_engine(py::module &m, string name)
Brush::ClassifierEngine Cls
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Brush::RegressorEngine Reg
Holds variable-type data.
Definition data.h:51
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
NSGA-II selection operator for getting the front
Definition bandit.cpp:4
Engine< PT::MulticlassClassifier > MulticlassClassifierEngine
Definition types.h:99
Engine< PT::Regressor > RegressorEngine
Definition types.h:97
Engine< PT::BinaryClassifier > ClassifierEngine
Definition types.h:98
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Engine< PT::Representer > RepresenterEngine
Definition types.h:100
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...