Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
program.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef PROGRAM_H
7#define PROGRAM_H
8//external includes
9
10//
11
12#include <string>
13#include "assert.h"
14
15// internal includes
16
17// #include "data/data.h"
18#include "../init.h"
19#include "tree_node.h"
20#include "node.h"
22#include "../params.h"
23#include "../util/utils.h"
24#include "functions.h"
25// #include "../variation.h"
26// #include "weight_optimizer.h"
27
28
29using std::cout;
30using std::string;
33
34namespace Brush {
35
36
37typedef tree<Node>::pre_order_iterator Iter;
38typedef tree<Node>::post_order_iterator PostIter;
39
41
42// for unsupervised learning, classification and regression.
43
49template<PT PType> struct Program
50{
52 static constexpr PT program_type = PType;
53
55 using RetType = typename std::conditional_t<PType == PT::Regressor, ArrayXf,
56 std::conditional_t<PType == PT::BinaryClassifier, ArrayXb,
57 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXi,
58 std::conditional_t<PType == PT::Representer, ArrayXXf, ArrayXf
59 >>>>;
60
62 using TreeType = std::conditional_t<PType == PT::BinaryClassifier, ArrayXf,
63 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXXf,
64 RetType>>;
65
67 bool is_fitted_ = false;
68
70 // Fitness fitness;
71
73 tree<Node> Tree;
74
76 std::optional<std::reference_wrapper<SearchSpace>> SSref;
77
78 Program() = default;
79 Program(const std::reference_wrapper<SearchSpace> s, const tree<Node> t)
80 : Tree(t), is_fitted_(false)
81 {
82 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
83 }
84
85 Program<PType> copy() { return Program<PType>(*this); }
86
87 inline void set_search_space(const std::reference_wrapper<SearchSpace> s)
88 {
89 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
90 }
91
94 int complexity() const{
95 auto head = Tree.begin();
96
97 return head.node->get_complexity();
98 }
99
102 int linear_complexity() const{
103 auto head = Tree.begin();
104
105 return head.node->get_linear_complexity();
106 }
107
111 int size(bool include_weight=true) const{
112 auto head = Tree.begin();
113
114 return head.node->get_size(include_weight);
115 }
116
122 int size_at(Iter& top, bool include_weight=true) const{
123
124 return top.node->get_size(include_weight);
125 }
126
129 int depth() const{
130 //tree.hh count the number of edges. We need to ensure that a single-node
131 //tree has depth>0
132 return 1+Tree.max_depth();
133 }
134
139 int depth_at(Iter& top) const{
140 return 1+Tree.max_depth(top);
141 }
142
147 int depth_to_reach(Iter& top) const{
148 return 1+Tree.depth(top);
149 }
150
152 {
153 TreeType out = Tree.begin().node->fit<TreeType>(d);
154 this->is_fitted_ = true;
156 // this->valid = true;
157 return *this;
158 };
159
167 {
168 this->Tree = new_program.Tree;
169 this->is_fitted_ = false;
170
171 // Update search space reference if the new program has one, otherwise keep current
172 if (new_program.SSref.has_value())
173 this->SSref = new_program.SSref;
174 // If new_program doesn't have a search space, keep this->SSref as is
175
176 return *this;
177 };
178
186 {
187 Program<PType> new_program = j;
188 return replace_program(new_program);
189 };
190
    /// Predict using an externally supplied weight buffer instead of the
    /// weights stored in the tree (used during weight optimization so that
    /// candidate weights can be evaluated without committing them).
    /// @tparam R the prediction's return type.
    /// @tparam W the scalar type of the weights.
    /// @param d the dataset to predict on.
    /// @param weights pointer-to-pointer into the weight buffer — presumably
    ///        advanced by the nodes in traversal order; confirm in tree_node.
    /// @throws if the program has not been fitted yet.
    template <typename R, typename W>
    R predict_with_weights(const Dataset &d, const W** weights)
    {
        if (!is_fitted_)
            HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");

        return Tree.begin().node->predict<R>(d, weights);
    };
199
200 auto predict_with_weights(const Dataset &d, const ArrayXf& weights)
201 {
202 float const * wptr = weights.data();
203 return this->predict_with_weights<RetType>(d, &wptr);
204 };
205
    /// The standard predict function: returns the tree's raw output.
    /// @param d the dataset to predict on.
    /// @return the tree's output (TreeType).
    /// @throws if the program has not been fitted yet.
    template <typename R = RetType>
    TreeType predict(const Dataset &d) requires(is_same_v<R, TreeType>)
    {
        if (!is_fitted_)
            HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");

        return Tree.begin().node->predict<TreeType>(d);
    };
222
    /// Specialized predict function for binary classification.
    /// Thresholds the tree's continuous output at 0.5.
    /// @param d the dataset to predict on.
    /// @return boolean class labels (tree output > 0.5).
    /// @throws if the program has not been fitted yet.
    template <typename R = RetType>
    ArrayXb predict(const Dataset &d) requires(is_same_v<R, ArrayXb>)
    {
        if (!is_fitted_)
            HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");

        return (Tree.begin().node->predict<TreeType>(d) > 0.5);
    };
235
    /// Specialized predict function for multiclass classification.
    /// Takes the argmax over the tree's per-class outputs.
    /// @param d the dataset to predict on.
    /// @return integer class labels.
    /// @throws if the program has not been fitted yet.
    template <typename R = RetType>
    ArrayXi predict(const Dataset &d) requires(is_same_v<R, ArrayXi>)
    {
        if (!is_fitted_)
            HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");

        // multiclass TreeType is ArrayXXf: one column/row per class score
        TreeType out = Tree.begin().node->predict<TreeType>(d);
        auto argmax = Function<NodeType::ArgMax>{};
        return argmax(out);
    };
250
251 // template <typename R = RetType>
252 template <PT P = PType>
253 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
255 {
256 return predict<TreeType>(d);
257 };
258
263 Program<PType>& fit(const Ref<const ArrayXXf>& X, const Ref<const ArrayXf>& y)
264 {
265 Dataset d(X,y);
266 return fit(d);
267 };
268
272 RetType predict(const Ref<const ArrayXXf>& X)
273 {
274 Dataset d(X);
275 return predict(d);
276 };
277
    /// Predict probabilities from X (classifier program types only).
    /// @param X input feature matrix, wrapped into a Dataset.
    /// @return probabilities as TreeType (ArrayXf for binary,
    ///         ArrayXXf for multiclass — see the TreeType alias above).
    template <PT P = PType>
    requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
    TreeType predict_proba(const Ref<const ArrayXXf>& X)
    {
        Dataset d(X);
        return predict_proba(d);
    };
292
298 void update_weights(const Dataset& d);
299
301 int get_n_weights() const
302 {
303 int count=0;
304 // check tree nodes for weights
305 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
306 {
307 const auto& node = i.node->data;
308
309 // we do not want to include fixed weights, because they should not be changed.
310 // get_n_weights, get_weights, and set_weights, are functions used in
311 // parameter optimization --- we can simply make weights invisible to parameter
312 // optimization, and they will not be changed (so they remain fixed).
313 if (node.weight_is_fixed)
314 continue;
315
316 // some nodes cannot have their weights optimized, others must have.
317 // It is important that this condition also matches the condition in
318 // the methods get_weights and set_weights.
319 if (Is<NodeType::OffsetSum>(node.node_type)
320 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
321 ++count;
322 }
323 return count;
324 }
325
331 ArrayXf get_weights()
332 {
333 ArrayXf weights(get_n_weights());
334 int i = 0;
335 for (PostIter t = Tree.begin_post(); t != Tree.end_post(); ++t)
336 {
337 const auto& node = t.node->data;
338
339 // skip fixed weights (this also avoid changing offsetSum weight if is locked)
340 if (node.weight_is_fixed)
341 continue;
342
343 if ( Is<NodeType::OffsetSum>(node.node_type)
344 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
345 {
346 weights(i) = node.W;
347 ++i;
348 }
349 }
350 return weights;
351 }
352
359 void set_weights(const ArrayXf& weights)
360 {
361 // take the weights set them in the tree.
362 // return the weights of the tree as an array
363 if (weights.size() != get_n_weights())
364 HANDLE_ERROR_THROW("Tried to set_weights of incorrect size");
365
366 int j = 0;
367 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
368 {
369 auto& node = i.node->data;
370
371 // skip fixed weights (this also avoid changing offsetSum weight if is locked)
372 if (node.weight_is_fixed)
373 continue;
374
375 if ( Is<NodeType::OffsetSum>(node.node_type)
376 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
377 {
378 node.W = weights(j);
379 ++j;
380 }
381 }
382 }
383
    /// Iterates over the program, locking nodes (and optionally weights and
    /// leaves) whose depth is less than end_depth.
    ///
    /// Calling with (0, true, false) unlocks everything: every `if` that
    /// performs a lock has a matching `else` that performs the unlock, so
    /// re-invoking with a smaller depth undoes earlier locks. prob_change is
    /// deliberately NOT touched here, because some nodes are meant to never
    /// be replaced (e.g. a logistic root) — unlocking and changing
    /// replacement probabilities are different concerns.
    ///
    /// @param end_depth nodes at depth < end_depth get locked; negative is a no-op.
    /// @param keep_leaves_unlocked if true, leaves (and SplitBest split
    ///        features) are left unlocked regardless of depth.
    /// @param keep_current_weights if true, in-range weights are marked fixed.
    void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true, bool keep_current_weights=false)
    {
        // negative end_depth: nothing to do
        if (end_depth<0) {
            return;
        }

        // Depth is a property of tree iterators, but the lambda below
        // receives node payloads. Keep a parallel iterator and advance it
        // in lock-step so each node's depth can be queried.
        auto tree_iter = Tree.begin();

        std::for_each(Tree.begin(), Tree.end(),
            [&](auto& n){
                auto d = Tree.depth(tree_iter);
                std::advance(tree_iter, 1);

                // weights (applies to every node kind): this both locks
                // in-range weights and unlocks everything else
                if (n.get_is_weighted() && d<end_depth)
                    n.weight_is_fixed = keep_current_weights;
                else{
                    n.weight_is_fixed = false;
                }

                // leaves (terminals, constants, and SplitBest variables)
                if (IsLeaf(n.node_type)) {
                    if (d<end_depth && !keep_leaves_unlocked)
                        n.node_is_fixed = true; // leaves are locked based on depth
                    else {
                        // either outside end_depth or leaves should stay
                        // unlocked; this `else` also unlocks the node
                        n.node_is_fixed = false;
                    }

                    return; // leaves have no children; stop here
                }

                // non-terminal nodes: lock strictly by depth
                if (d<end_depth) // kept as if-else so the unlock path stays explicit
                    n.node_is_fixed = true;
                else // the else branch performs the matching unlock
                    n.node_is_fixed = false;

                // special case - SplitBest nodes.
                // If we are skipping leaves, the split feature stays unlocked;
                // otherwise it is locked based on the children's depth.
                if (n.node_type==NodeType::SplitBest)
                {
                    if (keep_leaves_unlocked)
                    {
                        n.set_keep_split_feature(false);
                    }
                    else // leaves can be locked
                    {
                        // lock the split feature if the children are in range
                        n.set_keep_split_feature(d+1<end_depth);
                    }
                }
            }
        );
    }
462
475 string get_model(string fmt="compact", bool pretty=false) const
476 {
477 auto head = Tree.begin();
478 if (fmt=="tree")
479 return head.node->get_tree_model(pretty);
480 else if (fmt=="dot")
481 return get_dot_model(); ;
482 return head.node->get_model(pretty);
483 }
484
491 string get_dot_model(string extras="") const
492 {
493 // TODO: make the node names their hash or index, and the node label the nodetype name.
494 // ref: https://stackoverflow.com/questions/10579041/graphviz-create-new-node-with-this-same-label#10579155
495 string out = "digraph G {\n";
496 if (! extras.empty())
497 out += fmt::format("{}\n", extras);
498
499 auto get_id = [](const auto& n){
500 if (Is<NodeType::Terminal>(n->data.node_type))
501 return n->data.get_name(false);
502
503 return fmt::format("{}",fmt::ptr(n)).substr(2);
504 };
505 // bool first = true;
506 std::map<string, unsigned int> node_count;
507 int i = 0;
508 for (Iter iter = Tree.begin(); iter!=Tree.end(); iter++)
509 {
510 const auto& parent = iter.node;
511 // const auto& parent_data = iter.node->data;
512
513 string parent_id = get_id(parent);
514 // if (Is<NodeType::Terminal>(parent_data.node_type))
515 // parent_id = parent_data.get_name(false);
516 // else{
517 // parent_id = fmt::format("{}",fmt::ptr(iter.node)).substr(2);
518 // }
519 // // parent_id = parent_id.substr(2);
520
521 // This is for the root --------------------------------------------
522 // if the first node is weighted, make a dummy output node so that the
523 // first node's weight can be shown
524 if (i==0 && parent->data.get_is_weighted())
525 {
526 // making the weight red if fixed
527 string font_color = "";
528 if (parent->data.weight_is_fixed) {
529 font_color = ", fontcolor=lightcoral";
530 }
531
532 out += "y [shape=box];\n";
533 out += fmt::format("y -> \"{}\" [label=\"{:.2f}\"{}];\n",
534 // parent_data.get_name(false),
535 parent_id,
536 parent->data.W,
537 font_color
538 );
539 }
540
541 // add the node
542 bool is_constant = Is<NodeType::Constant, NodeType::MeanLabel>(parent->data.node_type);
543 string node_label = parent->data.get_name(is_constant);
544
545 if (Is<NodeType::SplitBest>(parent->data.node_type)) {
546 std::string feature = parent->data.get_feature();
547 std::string threshold = fmt::format("{:.2f}", parent->data.W);
548
549 // Append markers for fixed flags
550 if (parent->data.keep_split_feature)
551 feature += "^"; // split feature fixed
552 if (parent->data.weight_is_fixed)
553 threshold += "*"; // split weight fixed
554
555 node_label = fmt::format("{} >= {}?", feature, threshold);
556 }
557 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
558 node_label = fmt::format("Add");
559 }
560
561 string node_style = parent->data.get_prob_change() >0.0 ? "" : ", style=filled, fillcolor=lightcoral";
562 out += fmt::format("\"{}\" [label=\"{}\"{}];\n", parent_id, node_label, node_style);
563
564 // add edges to the node's children
565 auto kid = iter.node->first_child;
566 for (int j = 0; j < iter.number_of_children(); ++j)
567 {
568 string edge_label="";
569 string head_label="";
570 string tail_label="";
571 bool use_head_tail_labels = false;
572
573 string kid_id = get_id(kid);
574 // string kid_id = fmt::format("{}",fmt::ptr(kid));
575 // kid_id = kid_id.substr(2);
576
577 if (kid->data.get_is_weighted()
579 NodeType::OffsetSum, NodeType::SplitBest>(kid->data.node_type))
580 {
581 edge_label = fmt::format("{:.2f}",kid->data.W);
582 }
583
584 if (Is<NodeType::SplitOn>(parent->data.node_type)){
585 use_head_tail_labels=true;
586 if (j == 0)
587 tail_label = fmt::format(">= {:.2f}",parent->data.W);
588 else if (j==1)
589 tail_label = "Y";
590 else
591 tail_label = "N";
592
593 head_label=edge_label;
594 }
595 else if (Is<NodeType::SplitBest>(parent->data.node_type)){
596 use_head_tail_labels=true;
597 if (j == 0){
598 tail_label = "Y";
599 }
600 else
601 tail_label = "N";
602
603 head_label = edge_label;
604 }
605
606 // drawing the edges
607 string font_color = "";
608 if (kid->data.weight_is_fixed) {
609 font_color = ", fontcolor=lightcoral";
610 }
611
612 if (use_head_tail_labels){
613 out += fmt::format("\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"{}];\n",
614 parent_id,
615 kid_id,
616 head_label,
617 tail_label,
618 font_color
619 );
620 }
621 else{
622 out += fmt::format("\"{}\" -> \"{}\" [label=\"{}\"{}];\n",
623 parent_id,
624 kid_id,
625 edge_label,
626 font_color
627 );
628 }
629 kid = kid->next_sibling;
630 }
631
632 // adding the offset as the last child
633 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
634 // drawing the edge
635 out += fmt::format("\"{}\" -> \"{}\" [label=\"\"];\n",
636 parent_id,
637 parent_id+"Offset"
638 );
639
640 // drawing the node
641 out += fmt::format("\"{}\" [label=\"{:.2f}\"{}];\n",
642 parent_id+"Offset",
643 parent->data.W,
644 node_style
645 );
646 }
647
648 ++i;
649 }
650
651 out += "label=\"^ split feature fixed, * split threshold fixed\";\nlabelloc=bottom;\nfontsize=10;";
652 out += "}\n";
653
654 return out;
655 }
656
659 vector<Node> linearize() const {
660 vector<Node> linear_program;
661 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
662 linear_program.push_back(i.node->data);
663 return linear_program;
664 }
665}; // Program
666} // Brush
667
669// weight optimization
671// #include "../variation.h"
672namespace Brush{
673
674template<ProgramType PType>
676{
677 // Updates the weights within a tree.
678 // make an optimizer
679 auto WO = WeightOptimizer();
680 // get new weights from optimization.
681 WO.update((*this), d);
682};
683
684
686// serialization
687// serialization for program
/// Serialize a Program to JSON.
/// Only the expression tree and the fitted flag are stored; the SSref
/// search-space reference is not serialized.
template<ProgramType PType>
void to_json(json &j, const Program<PType> &p)
{
    j = json{{"Tree",p.Tree}, {"is_fitted_", p.is_fitted_}};
}
693
/// Deserialize a Program from JSON (restores the tree and the fitted flag).
/// NOTE(review): SSref is not restored here — callers presumably must
/// re-attach a search space via set_search_space before operations that
/// need one; confirm against usage sites.
template<ProgramType PType>
void from_json(const json &j, Program<PType>& p)
{
    j.at("Tree").get_to(p.Tree);
    j.at("is_fitted_").get_to(p.is_fitted_);
}
700
701}//namespace Brush
702
703
704
705#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:48
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
auto IsWeighable() noexcept -> bool
Definition node.h:51
ProgramType PT
Definition program.h:40
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:320
auto Is(NodeType nt) -> bool
Definition node.h:313
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::post_order_iterator PostIter
Definition program.h:38
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
ProgramType
Definition types.h:70
An individual program, a.k.a. model.
Definition program.h:50
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Definition program.h:79
Program< PType > & replace_program(const json &j)
Replace the current program from a JSON representation, invalidating fitness.
Definition program.h:185
TreeType predict_proba(const Dataset &d)
Definition program.h:254
void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true, bool keep_current_weights=false)
Iterates over the program, locking the nodes until it reaches a certain depth. If the node is a Split...
Definition program.h:394
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
Definition program.h:675
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
Definition program.h:55
std::optional< std::reference_wrapper< SearchSpace > > SSref
Definition program.h:76
int linear_complexity() const
count the linear complexity of the program.
Definition program.h:102
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
Definition program.h:272
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
Definition program.h:359
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:215
static constexpr PT program_type
Definition program.h:52
R predict_with_weights(const Dataset &d, const W **weights)
Definition program.h:192
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
Definition program.h:62
int complexity() const
count the (recursive) complexity of the program.
Definition program.h:94
Program< PType > & replace_program(const Program< PType > &new_program)
Replace the current program with a new program, invalidating fitness.
Definition program.h:166
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:122
ArrayXf get_weights()
Get the weights of the tree as an array.
Definition program.h:331
Program< PType > & fit(const Dataset &d)
Definition program.h:151
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:129
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
Definition program.h:200
int get_n_weights() const
returns the number of weights in the program.
Definition program.h:301
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
Definition program.h:139
vector< Node > linearize() const
turns program tree into a linear program.
Definition program.h:659
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
Definition program.h:241
void set_search_space(const std::reference_wrapper< SearchSpace > s)
Definition program.h:87
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
Definition program.h:228
string get_dot_model(string extras="") const
Definition program.h:491
Program()=default
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
Definition program.h:475
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
Definition program.h:287
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:147
Program< PType > copy()
Definition program.h:85
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
Definition program.h:263
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:111
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...