Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_programs.h
Go to the documentation of this file.
1#include "module.h"
3
8
9namespace nl = nlohmann;
10namespace br = Brush;
11
12using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
13
14template<typename T>
15void bind_program(py::module& m, string name)
16{
17 using RetType = std::conditional_t<
18 std::is_same_v<T,Reg>, ArrayXf,
19 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
20 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
21
22 py::class_<T> prog(m, name.data() );
23 prog.def(py::init<>())
24 .def(py::init(
25 [](const json& j){ T p = j; return p; })
26 )
27 .def("fit",
28 static_cast<T &(T::*)(const Dataset &d)>(&T::fit),
29 "fit from Dataset object")
30 .def("fit",
31 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
32 "fit from X,y data")
33 .def("predict",
34 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
35 "predict from Dataset object")
36 .def("predict",
37 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
38 "predict from X data")
39 .def("lock_nodes",
40 &T::lock_nodes,
41 py::arg("end_depth") = 0,
42 py::arg("keep_leaves_unlocked") = true,
44 )
45 .def("get_model",
46 &T::get_model,
47 py::arg("type") = "compact",
48 py::arg("pretty") = false,
50 )
51 .def("get_dot_model", &T::get_dot_model, py::arg("extras")="")
52 .def("get_weights", &T::get_weights)
53 .def("size", &T::size, py::arg("include_weight")=true)
54 .def("complexity", &T::complexity)
55 .def("linear_complexity", &T::linear_complexity)
56 .def("depth", &T::depth)
57 // .def("cross", &T::cross, py::return_value_policy::automatic,
58 // "Performs one attempt to stochastically swap subtrees between two programs and generate a child")
59 // .def("mutate", &T::mutate, py::return_value_policy::automatic,
60 // "Performs one attempt to stochastically mutate the program and generate a child")
61 .def("set_search_space", &T::set_search_space)
62 //.def("copy", &T::copy<>, py::return_value_policy::copy)
63 .def("copy", [](const T& self){ T clone(self); return clone; })
64 .def(py::pickle(
65 [](const T &p) { // __getstate__
66 /* Return a tuple that fully encodes the state of the object */
67 // return py::make_tuple(p.value(), p.extra());
68 nl::json j = p;
69 return j;
70 },
71 [](nl::json j) { // __setstate__
72 T p = j;
73
74 return p;
75 }
76 )
77 )
78 ;
79 if constexpr (std::is_same_v<T,Cls>)
80 {
81 prog.def("predict_proba",
82 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
83 "predict from Dataset object")
84 .def("predict_proba",
85 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
86 "predict from X data")
87 ;
88 }
89
90}
Brush::RepresenterEngine Rep
Brush::MulticlassClassifierEngine MCls
Brush::ClassifierEngine Cls
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Brush::RegressorEngine Reg
void bind_program(py::module &m, string name)
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
holds variable type data.
Definition data.h:51
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
NSGA-II selection operator for getting the front.
Definition bandit.cpp:4
Program< PT::Representer > RepresenterProgram
Definition types.h:81
Program< PT::BinaryClassifier > ClassifierProgram
Definition types.h:79
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Program< PT::Regressor > RegressorProgram
Definition types.h:78
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
Definition types.h:80