Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_engines.h
Go to the documentation of this file.
1#include "module.h"
2#include "../engine.h"
3
4#include "../bandit/bandit.h"
6#include "../bandit/dummy.h"
8
9#include "../ind/individual.h"
10#include "../vary/variation.h"
11
12#include "../eval/evaluation.h"
13
14#include "../pop/population.h"
15
18#include "../selection/nsga2.h"
20
21#include "../pop/archive.h"
22
27
28namespace nl = nlohmann;
29namespace br = Brush;
30
31using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
32
33template<typename T>
34void bind_engine(py::module& m, string name)
35{
36 using RetType = std::conditional_t<
37 std::is_same_v<T,Reg>, ArrayXf,
38 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
39 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
40
41 py::class_<T> engine(m, name.data() );
42 engine.def(py::init<>())
43 .def(py::init([](br::Parameters& p, br::SearchSpace& s){
44 T e(p, s); return e; })
45 )
46 .def_property("params", &T::get_params, &T::set_params)
47 .def_property("search_space", &T::get_search_space, &T::set_search_space)
48 .def_property_readonly("is_fitted", &T::get_is_fitted)
49 .def_property_readonly("best_ind", &T::get_best_ind)
50 .def("fit",
51 static_cast<T &(T::*)(Dataset &d)>(&T::fit),
52 py::call_guard<py::gil_scoped_release>(),
53 "fit from Dataset object")
54 .def("fit",
55 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
56 py::call_guard<py::gil_scoped_release>(),
57 "fit from X,y data")
58 .def("predict",
59 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
60 "predict from Dataset object")
61 .def("predict",
62 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
63 "predict from X data")
64 .def("get_archive", &T::get_archive)
65 .def("get_population", &T::get_population)
66 .def("set_population", &T::set_population)
67 .def("get_archive_as_json", &T::get_archive_as_json)
68 .def("get_population_as_json", &T::get_population_as_json)
69 .def("set_population_from_json", &T::set_population_from_json)
70 .def("lock_nodes",
71 &T::lock_nodes,
72 py::arg("end_depth") = 0,
73 py::arg("keep_leaves_unlocked") = true,
74 py::arg("keep_current_weights") = false,
76 )
77 .def(py::pickle(
78 [](const T &p) { // __getstate__
79 /* Return a tuple that fully encodes the state of the object */
80 // return py::make_tuple(p.value(), p.extra());
81 nl::json j = p;
82 return j;
83 },
84 [](nl::json j) { // __setstate__
85 T p = j;
86 return p;
87 })
88 )
89 ;
90
91 // specialization for subclasses
92 if constexpr (std::is_same_v<T,Cls>)
93 {
94 engine.def("predict_proba",
95 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
96 "predict from Dataset object")
97 .def("predict_proba",
98 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
99 "predict from X data")
100 ;
101 }
102}
Brush::RepresenterEngine Rep
Brush::MulticlassClassifierEngine MCls
void bind_engine(py::module &m, string name)
Brush::ClassifierEngine Cls
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Brush::RegressorEngine Reg
Holds variable-type data.
Definition data.h:51
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
NSGA2 selection operator for getting the front
Definition bandit.cpp:3
Engine< PT::MulticlassClassifier > MulticlassClassifierEngine
Definition types.h:99
Engine< PT::Regressor > RegressorEngine
Definition types.h:97
Engine< PT::BinaryClassifier > ClassifierEngine
Definition types.h:98
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Engine< PT::Representer > RepresenterEngine
Definition types.h:100
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...