Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
bind_dataset.cpp
Go to the documentation of this file.
1#include "module.h"
2#include "../data/data.h"
3#include "../types.h"
4#include "../data/io.h"
5namespace py = pybind11;
6namespace br = Brush;
7namespace nl = nlohmann;
8
9void bind_dataset(py::module & m)
10{
11 py::class_<br::Data::Dataset>(m, "Dataset")
12 // construct from X, feature names (and optional validation and batch sizes) with constructor 3.
13 .def(py::init([](const Ref<const ArrayXXf>& X,
14 const vector<string>& feature_names=vector<string>(),
15 const bool c=false,
16 const float validation_size=0.0,
17 const float batch_size=1.0){
18 return br::Data::Dataset(
19 X, feature_names, c, validation_size, batch_size);
20 }),
21 py::arg("X"),
22 py::arg("feature_names") = vector<string>(),
23 py::arg("c") = false,
24 py::arg("validation_size") = 0.0,
25 py::arg("batch_size") = 1.0
26 )
27 // construct from X, y, feature names (and optional validation and batch sizes) with constructor 2.
28 .def(py::init([](const Ref<const ArrayXXf>& X,
29 const Ref<const ArrayXf>& y,
30 const vector<string>& feature_names=vector<string>(),
31 const bool c=false,
32 const float validation_size=0.0,
33 const float batch_size=1.0){
34 return br::Data::Dataset(
35 X, y, feature_names, {}, c, validation_size, batch_size);
36 }),
37 py::arg("X"),
38 py::arg("y"),
39 py::arg("feature_names") = vector<string>(),
40 py::arg("c") = false,
41 py::arg("validation_size") = 0.0,
42 py::arg("batch_size") = 1.0
43 )
44 // construct from X, feature names, but copying the feature types from a
45 // reference dataset with constructor 4. Useful for predicting (specially
46 // because the user can provide a single element matrix, or an array with
47 // no feature names).
48 .def(py::init([](const Ref<const ArrayXXf>& X,
49 const br::Data::Dataset& ref_dataset,
50 const vector<string>& feature_names,
51 const bool c=false){
52 return br::Data::Dataset(X, ref_dataset, feature_names, c);
53 }),
54 py::arg("X"),
55 py::arg("ref_dataset"),
56 py::arg("feature_names"),
57 py::arg("c") = false
58 )
59
60 .def_readwrite("y", &br::Data::Dataset::y)
61 // .def_readwrite("features", &br::Data::Dataset::features)
62 .def("get_n_samples", &br::Data::Dataset::get_n_samples)
63 .def("get_n_features", &br::Data::Dataset::get_n_features)
64 .def("print", &br::Data::Dataset::print)
65 .def("get_batch", &br::Data::Dataset::get_batch)
66 .def("get_training_data", &br::Data::Dataset::get_training_data)
67 .def("get_validation_data", &br::Data::Dataset::get_validation_data)
68 .def("get_batch_size", &br::Data::Dataset::get_batch_size)
69 .def("set_batch_size", &br::Data::Dataset::set_batch_size)
70 .def("split", &br::Data::Dataset::split)
71 .def("get_X", &br::Data::Dataset::get_X)
72 ;
73
74 m.def("read_csv", &br::Data::read_csv, py::arg("path"), py::arg("target"), py::arg("sep")=',');
75}
void bind_dataset(py::module &m)
< nsga2 selection operator for getting the front
Definition data.cpp:12