Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_programs.h
Go to the documentation of this file.
1#include "module.h"
3
8
9namespace nl = nlohmann;
10namespace br = Brush;
11
12using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
13
14template<typename T>
15void bind_program(py::module& m, string name)
16{
17 using RetType = std::conditional_t<
18 std::is_same_v<T,Reg>, ArrayXf,
19 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
20 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
21
22 py::class_<T> prog(m, name.data() );
23 prog.def(py::init<>())
24 .def(py::init(
25 [](const json& j){ T p = j; return p; })
26 )
27 .def("fit",
28 static_cast<T &(T::*)(const Dataset &d)>(&T::fit),
29 "fit from Dataset object")
30 .def("fit",
31 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
32 "fit from X,y data")
33 .def("predict",
34 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
35 "predict from Dataset object")
36 .def("predict",
37 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
38 "predict from X data")
39 .def("get_model",
40 &T::get_model,
41 py::arg("type") = "compact",
42 py::arg("pretty") = false,
44 )
45 .def("get_dot_model", &T::get_dot_model, py::arg("extras")="")
46 .def("get_weights", &T::get_weights)
47 .def("size", &T::size, py::arg("include_weight")=true)
48 .def("complexity", &T::complexity)
49 .def("depth", &T::depth)
50 // .def("cross", &T::cross, py::return_value_policy::automatic,
51 // "Performs one attempt to stochastically swap subtrees between two programs and generate a child")
52 // .def("mutate", &T::mutate, py::return_value_policy::automatic,
53 // "Performs one attempt to stochastically mutate the program and generate a child")
54 .def("set_search_space", &T::set_search_space)
55 //.def("copy", &T::copy<>, py::return_value_policy::copy)
56 .def("copy", [](const T& self){ T clone(self); return clone; })
57 .def(py::pickle(
58 [](const T &p) { // __getstate__
59 /* Return a tuple that fully encodes the state of the object */
60 // return py::make_tuple(p.value(), p.extra());
61 nl::json j = p;
62 return j;
63 },
64 [](nl::json j) { // __setstate__
65 T p = j;
66
67 return p;
68 }
69 )
70 )
71 ;
72 if constexpr (std::is_same_v<T,Cls>)
73 {
74 prog.def("predict_proba",
75 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
76 "predict from Dataset object")
77 .def("predict_proba",
78 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
79 "predict from X data")
80 ;
81 }
82
83}
void bind_engine(py::module &m, string name)
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
void bind_program(py::module &m, string name)
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Holds data of variable types.
Definition data.h:51
The Engine class represents the core engine of the brush library.
Definition engine.h:43
NSGA-II selection operator for getting the front.
Definition data.cpp:12
Program< PT::Representer > RepresenterProgram
Definition types.h:81
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
Program< PT::BinaryClassifier > ClassifierProgram
Definition types.h:79
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Program< PT::Regressor > RegressorProgram
Definition types.h:78
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
Definition types.h:80