Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
program.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef PROGRAM_H
7#define PROGRAM_H
8//external includes
9
10//
11
12#include <string>
13#include "assert.h"
14
15// internal includes
16
17// #include "data/data.h"
18#include "../init.h"
19#include "tree_node.h"
20#include "node.h"
22#include "../params.h"
23#include "../util/utils.h"
24#include "functions.h"
25// #include "../variation.h"
26// #include "weight_optimizer.h"
27
28
29using std::cout;
30using std::string;
33
34namespace Brush {
35
36
37typedef tree<Node>::pre_order_iterator Iter;
38typedef tree<Node>::post_order_iterator PostIter;
39
41
42// for unsupervised learning, classification and regression.
43
49template<PT PType> struct Program
50{
52 static constexpr PT program_type = PType;
53
55 using RetType = typename std::conditional_t<PType == PT::Regressor, ArrayXf,
56 std::conditional_t<PType == PT::BinaryClassifier, ArrayXb,
57 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXi,
58 std::conditional_t<PType == PT::Representer, ArrayXXf, ArrayXf
59 >>>>;
60
62 using TreeType = std::conditional_t<PType == PT::BinaryClassifier, ArrayXf,
63 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXXf,
64 RetType>>;
65
68
70 // Fitness fitness;
71
75 std::optional<std::reference_wrapper<SearchSpace>> SSref;
76
77 Program() = default;
78 Program(const std::reference_wrapper<SearchSpace> s, const tree<Node> t)
79 : Tree(t)
80 {
81 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
82 }
83
84 Program<PType> copy() { return Program<PType>(*this); }
85
86 inline void set_search_space(const std::reference_wrapper<SearchSpace> s)
87 {
88 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
89 }
90
93 int complexity() const{
94 auto head = Tree.begin();
95
96 return head.node->get_complexity();
97 }
98
102 int size(bool include_weight=true) const{
103 auto head = Tree.begin();
104
105 return head.node->get_size(include_weight);
106 }
107
113 int size_at(Iter& top, bool include_weight=true) const{
114
115 return top.node->get_size(include_weight);
116 }
117
120 int depth() const{
121 //tree.hh count the number of edges. We need to ensure that a single-node
122 //tree has depth>0
123 return 1+Tree.max_depth();
124 }
125
130 int depth_at(Iter& top) const{
131 return 1+Tree.max_depth(top);
132 }
133
139 return 1+Tree.depth(top);
140 }
141
143 {
144 TreeType out = Tree.begin().node->fit<TreeType>(d);
145 this->is_fitted_ = true;
147 // this->valid = true;
148 return *this;
149 };
150
151 template <typename R, typename W>
152 R predict_with_weights(const Dataset &d, const W** weights)
153 {
154 if (!is_fitted_)
155 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
156
157 return Tree.begin().node->predict<R>(d, weights);
158 };
159
160 auto predict_with_weights(const Dataset &d, const ArrayXf& weights)
161 {
162 float const * wptr = weights.data();
163 return this->predict_with_weights<RetType>(d, &wptr);
164 };
165
174 template <typename R = RetType>
176 {
177 if (!is_fitted_)
178 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
179
180 return Tree.begin().node->predict<TreeType>(d);
181 };
182
187 template <typename R = RetType>
189 {
190 if (!is_fitted_)
191 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
192 return (Tree.begin().node->predict<TreeType>(d) > 0.5);
193 };
194
199 template <typename R = RetType>
201 {
202 if (!is_fitted_)
203 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
204 TreeType out = Tree.begin().node->predict<TreeType>(d);
206 return argmax(out);
207 };
208
209 // template <typename R = RetType>
210 template <PT P = PType>
211 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
213 {
214 return predict<TreeType>(d);
215 };
216
222 {
223 Dataset d(X,y);
224 return fit(d);
225 };
226
231 {
232 Dataset d(X);
233 return predict(d);
234 };
235
243 template <PT P = PType>
244 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
246 {
247 Dataset d(X);
248 return predict_proba(d);
249 };
250
257
259 int get_n_weights() const
260 {
261 int count=0;
262 // check tree nodes for weights
263 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
264 {
265 const auto& node = i.node->data;
266 if (node.get_is_weighted())
267 ++count;
268 }
269 return count;
270 }
271
277 ArrayXf get_weights()
278 {
279 ArrayXf weights(get_n_weights());
280 int i = 0;
281 for (PostIter t = Tree.begin_post(); t != Tree.end_post(); ++t)
282 {
283 const auto& node = t.node->data;
284 if (node.get_is_weighted())
285 {
286 weights(i) = node.W;
287 ++i;
288 }
289 }
290 return weights;
291 }
292
299 void set_weights(const ArrayXf& weights)
300 {
301 // take the weights set them in the tree.
302 // return the weights of the tree as an array
303 if (weights.size() != get_n_weights())
304 HANDLE_ERROR_THROW("Tried to set_weights of incorrect size");
305 int j = 0;
306 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
307 {
308 auto& node = i.node->data;
309 if (node.get_is_weighted())
310 {
311 node.W = weights(j);
312 ++j;
313 }
314 }
315 }
328 string get_model(string fmt="compact", bool pretty=false) const
329 {
330 auto head = Tree.begin();
331 if (fmt=="tree")
332 return head.node->get_tree_model(pretty);
333 else if (fmt=="dot")
334 return get_dot_model(); ;
335 return head.node->get_model(pretty);
336 }
337
344 string get_dot_model(string extras="") const
345 {
346 // TODO: make the node names their hash or index, and the node label the nodetype name.
347 // ref: https://stackoverflow.com/questions/10579041/graphviz-create-new-node-with-this-same-label#10579155
348 string out = "digraph G {\n";
349 if (! extras.empty())
350 out += fmt::format("{}\n", extras);
351
352 auto get_id = [](const auto& n){
353 if (Is<NodeType::Terminal>(n->data.node_type))
354 return n->data.get_name(false);
355
356 return fmt::format("{}",fmt::ptr(n)).substr(2);
357 };
358 // bool first = true;
359 std::map<string, unsigned int> node_count;
360 int i = 0;
361 for (Iter iter = Tree.begin(); iter!=Tree.end(); iter++)
362 {
363 const auto& parent = iter.node;
364 // const auto& parent_data = iter.node->data;
365
366 string parent_id = get_id(parent);
367 // if (Is<NodeType::Terminal>(parent_data.node_type))
368 // parent_id = parent_data.get_name(false);
369 // else{
370 // parent_id = fmt::format("{}",fmt::ptr(iter.node)).substr(2);
371 // }
372 // // parent_id = parent_id.substr(2);
373
374 // if the first node is weighted, make a dummy output node so that the
375 // first node's weight can be shown
376 if (i==0 && parent->data.get_is_weighted())
377 {
378 out += "y [shape=box];\n";
379 out += fmt::format("y -> \"{}\" [label=\"{:.2f}\"];\n",
380 // parent_data.get_name(false),
381 parent_id,
382 parent->data.W
383 );
384 }
385
386 // add the node
387 bool is_constant = Is<NodeType::Constant, NodeType::MeanLabel>(parent->data.node_type);
388 string node_label = parent->data.get_name(is_constant);
389
390 if (Is<NodeType::SplitBest>(parent->data.node_type)){
391 node_label = fmt::format("{}>{:.2f}?", parent->data.get_feature(), parent->data.W);
392 }
393 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
394 node_label = fmt::format("Add");
395 }
396 out += fmt::format("\"{}\" [label=\"{}\"];\n", parent_id, node_label);
397
398 // add edges to the node's children
399 auto kid = iter.node->first_child;
400 for (int j = 0; j < iter.number_of_children(); ++j)
401 {
402 string edge_label="";
403 string head_label="";
404 string tail_label="";
405 bool use_head_tail_labels = false;
406
407 string kid_id = get_id(kid);
408 // string kid_id = fmt::format("{}",fmt::ptr(kid));
409 // kid_id = kid_id.substr(2);
410
411 if (kid->data.get_is_weighted()
413 edge_label = fmt::format("{:.2f}",kid->data.W);
414 }
415
416 if (Is<NodeType::SplitOn>(parent->data.node_type)){
418 if (j == 0)
419 tail_label = fmt::format(">{:.2f}",parent->data.W);
420 else if (j==1)
421 tail_label = "Y";
422 else
423 tail_label = "N";
424
426 }
427 else if (Is<NodeType::SplitBest>(parent->data.node_type)){
429 if (j == 0){
430 tail_label = "Y";
431 }
432 else
433 tail_label = "N";
434
436 }
437
439 out += fmt::format("\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"];\n",
440 parent_id,
441 kid_id,
444 );
445 }
446 else{
447 out += fmt::format("\"{}\" -> \"{}\" [label=\"{}\"];\n",
448 parent_id,
449 kid_id,
451 );
452 }
453 kid = kid->next_sibling;
454 }
455
456 // adding the offset as the last child
457 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
458 // drawing the edge
459 out += fmt::format("\"{}\" -> \"{}\" [label=\"\"];\n",
460 parent_id,
461 parent_id+"Offset"
462 );
463
464 // drawing the node
465 out += fmt::format("\"{}\" [label=\"{}\"];\n",
466 parent_id+"Offset",
467 parent->data.W
468 );
469 }
470
471 ++i;
472 }
473 out += "}\n";
474 return out;
475 }
476
479 vector<Node> linearize() const {
480 vector<Node> linear_program;
481 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
482 linear_program.push_back(i.node->data);
483 return linear_program;
484 }
485}; // Program
486} // Brush
487
489// weight optimization
491// #include "../variation.h"
492namespace Brush{
493
494template<ProgramType PType>
496{
497 // Updates the weights within a tree.
498 // make an optimizer
499 auto WO = WeightOptimizer();
500 // get new weights from optimization.
501 WO.update((*this), d);
502};
503
504
506// serialization
507// serialization for program
508template<ProgramType PType>
509void to_json(json &j, const Program<PType> &p)
510{
511 j = json{{"Tree",p.Tree}, {"is_fitted_", p.is_fitted_}};
512}
513
514template<ProgramType PType>
515void from_json(const json &j, Program<PType>& p)
516{
517 j.at("Tree").get_to(p.Tree);
518 j.at("is_fitted_").get_to(p.is_fitted_);
519}
520
521}//namespace Brush
522
523
524
525#endif
void bind_engine(py::module &m, string name)
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
nsga2 selection operator for getting the front
Definition data.cpp:12
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:24
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::post_order_iterator PostIter
Definition program.h:38
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
ProgramType
Definition types.h:70
An individual program, a.k.a. model.
Definition program.h:50
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Definition program.h:78
TreeType predict_proba(const Dataset &d)
Definition program.h:212
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
Definition program.h:495
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
Definition program.h:55
std::optional< std::reference_wrapper< SearchSpace > > SSref
reference to search space
Definition program.h:75
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
Definition program.h:230
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
Definition program.h:299
bool is_fitted_
whether fit has been called
Definition program.h:67
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:175
static constexpr PT program_type
an enum storing the program type.
Definition program.h:52
R predict_with_weights(const Dataset &d, const W **weights)
Definition program.h:152
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
Definition program.h:62
int complexity() const
count the complexity of the program.
Definition program.h:93
tree< Node > Tree
fitness
Definition program.h:73
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:113
ArrayXf get_weights()
Get the weights of the tree as an array.
Definition program.h:277
Program< PType > & fit(const Dataset &d)
Definition program.h:142
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:120
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
Definition program.h:160
int get_n_weights() const
returns the number of weights in the program.
Definition program.h:259
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
Definition program.h:130
vector< Node > linearize() const
turns program tree into a linear program.
Definition program.h:479
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
Definition program.h:200
void set_search_space(const std::reference_wrapper< SearchSpace > s)
Definition program.h:86
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
Definition program.h:188
string get_dot_model(string extras="") const
Get the model as a dot object.
Definition program.h:344
Program()=default
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
Definition program.h:328
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
Definition program.h:245
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:138
Program< PType > copy()
Definition program.h:84
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
Definition program.h:221
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:102
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...