Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_programs.h
Go to the documentation of this file.
1#include "module.h"
3
8
// Short namespace aliases: `nl` for nlohmann (its json type is used below as
// nl::json) and `br` for Brush.
9namespace nl = nlohmann;
10namespace br = Brush;
11
// Call-guard type bundling pybind11's scoped_ostream_redirect and
// scoped_estream_redirect, for use as an extra argument to .def(...).
12using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;
13
// Registers the pybind11 class bindings for program type T.
//
// Template parameter T: the program type to expose. It is compared against
//   Reg, Cls and MCls below to select the return type of predict(), and
//   against Cls to decide whether predict_proba() is bound.
// m:    pybind11 module that receives the class.
// name: class name handed to py::class_ (via name.data()).
//
// NOTE(review): the leading numbers are the listing's own line numbers. Lines
// 52 and 58 are absent from this extract, so the "lock_nodes" and "get_model"
// bindings each appear to be missing their final argument here — confirm
// against the original header before relying on this text.
14template<typename T>
15void bind_program(py::module& m, string name)
16{
// Return type used to select the predict() overloads:
// ArrayXf for Reg, ArrayXb for Cls, ArrayXi for MCls, ArrayXXf otherwise.
17 using RetType = std::conditional_t<
18 std::is_same_v<T,Reg>, ArrayXf,
19 std::conditional_t<std::is_same_v<T,Cls>, ArrayXb,
20 std::conditional_t<std::is_same_v<T,MCls>, ArrayXi, ArrayXXf>>>;
21
// Create the class object; every .def below attaches to it.
22 py::class_<T> prog(m, name.data() );
// Constructors: default, and from a json value converted to T.
23 prog.def(py::init<>())
24 .def(py::init(
25 [](const json& j){ T p = j; return p; })
26 )
// The static_casts pick one specific overload of an overloaded member
// function so that each can be bound under the same Python name.
27 .def("fit",
28 static_cast<T &(T::*)(const Dataset &d)>(&T::fit),
29 "fit from Dataset object")
30 .def("fit",
31 static_cast<T &(T::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&T::fit),
32 "fit from X,y data")
33 .def("replace_program",
34 static_cast<T &(T::*)(const T&)>(&T::replace_program),
35 py::arg("new_program"),
36 "Replace the current program with a new program, invalidating fitness")
37 .def("replace_program",
38 static_cast<T &(T::*)(const json&)>(&T::replace_program),
39 py::arg("json_program"),
40 "Replace the current program from a JSON representation, invalidating fitness")
41 .def("predict",
42 static_cast<RetType (T::*)(const Dataset &d)>(&T::predict),
43 "predict from Dataset object")
44 .def("predict",
45 static_cast<RetType (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict),
46 "predict from X data")
// Keyword arguments with defaults for lock_nodes.
47 .def("lock_nodes",
48 &T::lock_nodes,
49 py::arg("end_depth") = 0,
50 py::arg("keep_leaves_unlocked") = true,
51 py::arg("keep_current_weights") = false,
// NOTE(review): listing line 52 is not present in this extract.
53 )
// Keyword arguments with defaults for get_model.
54 .def("get_model",
55 &T::get_model,
56 py::arg("type") = "compact",
57 py::arg("pretty") = false,
// NOTE(review): listing line 58 is not present in this extract.
59 )
// Model export and simple accessors bound directly to the member functions.
60 .def("get_dot_model", &T::get_dot_model, py::arg("extras")="")
61 .def("get_weights", &T::get_weights)
62 .def("size", &T::size, py::arg("include_weight")=true)
63 .def("complexity", &T::complexity)
64 .def("linear_complexity", &T::linear_complexity)
65 .def("depth", &T::depth)
// The cross/mutate bindings below are disabled (commented out).
66 // .def("cross", &T::cross, py::return_value_policy::automatic,
67 // "Performs one attempt to stochastically swap subtrees between two programs and generate a child")
68 // .def("mutate", &T::mutate, py::return_value_policy::automatic,
69 // "Performs one attempt to stochastically mutate the program and generate a child")
70 .def("set_search_space", &T::set_search_space)
// "copy" is bound to a lambda that copy-constructs a new T from self and
// returns it by value; the member-template binding is left disabled.
71 //.def("copy", &T::copy<>, py::return_value_policy::copy)
72 .def("copy", [](const T& self){ T clone(self); return clone; })
// Pickle support: the state is the program converted to nl::json, and
// restoring converts that json back into a T.
73 .def(py::pickle(
74 [](const T &p) { // __getstate__
75 /* Return a tuple that fully encodes the state of the object */
76 // return py::make_tuple(p.value(), p.extra());
77 nl::json j = p;
78 return j;
79 },
80 [](nl::json j) { // __setstate__
81 T p = j;
82
83 return p;
84 }
85 )
86 )
87 ;
// predict_proba is only bound when T is Cls; both overloads return ArrayXf.
// NOTE(review): the docstrings below repeat the "predict ..." wording used
// for predict() — confirm whether that is intended.
88 if constexpr (std::is_same_v<T,Cls>)
89 {
90 prog.def("predict_proba",
91 static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
92 "predict from Dataset object")
93 .def("predict_proba",
94 static_cast<ArrayXf (T::*)(const Ref<const ArrayXXf> &X)>(&T::predict_proba),
95 "predict from X data")
96 ;
97 }
98
99}
Brush::RepresenterEngine Rep
Brush::MulticlassClassifierEngine MCls
Brush::ClassifierEngine Cls
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
Brush::RegressorEngine Reg
void bind_program(py::module &m, string name)
py::call_guard< py::scoped_ostream_redirect, py::scoped_estream_redirect > stream_redirect
holds variable type data.
Definition data.h:51
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition functions.h:25
NSGA-II selection operator for getting the front
Definition bandit.cpp:4
Program< PT::Representer > RepresenterProgram
Definition types.h:81
Program< PT::BinaryClassifier > ClassifierProgram
Definition types.h:79
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
Program< PT::Regressor > RegressorProgram
Definition types.h:78
Program< PT::MulticlassClassifier > MulticlassClassifierProgram
Definition types.h:80