Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
program.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef PROGRAM_H
7#define PROGRAM_H
8//external includes
9
10
11#include <string>
12#include "assert.h"
13
14// internal includes
15
16// #include "data/data.h"
17#include "../init.h"
18#include "tree_node.h"
19#include "node.h"
21#include "../params.h"
22#include "../util/utils.h"
23#include "functions.h"
24// #include "../variation.h"
25// #include "weight_optimizer.h"
26
27
28using std::cout;
29using std::string;
32
33namespace Brush {
34
35
36typedef tree<Node>::pre_order_iterator Iter;
37typedef tree<Node>::post_order_iterator PostIter;
38
40
41// for unsupervised learning, classification and regression.
42
48template<PT PType> struct Program
49{
51 static constexpr PT program_type = PType;
52
54 using RetType = typename std::conditional_t<PType == PT::Regressor, ArrayXf,
55 std::conditional_t<PType == PT::BinaryClassifier, ArrayXb,
56 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXi,
57 std::conditional_t<PType == PT::Representer, ArrayXXf, ArrayXf
58 >>>>;
59
61 using TreeType = std::conditional_t<PType == PT::BinaryClassifier, ArrayXf,
62 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXXf,
63 RetType>>;
64
66 bool is_fitted_ = false;
67
69 // Fitness fitness;
70
72 tree<Node> Tree;
73
75 std::optional<std::reference_wrapper<SearchSpace>> SSref;
76
77 Program() = default;
78 Program(const std::reference_wrapper<SearchSpace> s, const tree<Node> t)
79 : Tree(t), is_fitted_(false)
80 {
81 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
82 }
83
84 Program<PType> copy() { return Program<PType>(*this); }
85
86 inline void set_search_space(const std::reference_wrapper<SearchSpace> s)
87 {
88 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
89 }
90
93 int complexity() const{
94 auto head = Tree.begin();
95
96 return head.node->get_complexity();
97 }
98
101 int linear_complexity() const{
102 auto head = Tree.begin();
103
104 return head.node->get_linear_complexity();
105 }
106
110 int size(bool include_weight=true) const{
111 auto head = Tree.begin();
112
113 return head.node->get_size(include_weight);
114 }
115
121 int size_at(Iter& top, bool include_weight=true) const{
122
123 return top.node->get_size(include_weight);
124 }
125
128 int depth() const{
129 //tree.hh count the number of edges. We need to ensure that a single-node
130 //tree has depth>0
131 return 1+Tree.max_depth();
132 }
133
138 int depth_at(Iter& top) const{
139 return 1+Tree.max_depth(top);
140 }
141
146 int depth_to_reach(Iter& top) const{
147 return 1+Tree.depth(top);
148 }
149
151 {
152 TreeType out = Tree.begin().node->fit<TreeType>(d);
153 this->is_fitted_ = true;
155 // this->valid = true;
156 return *this;
157 };
158
166 {
167 this->Tree = new_program.Tree;
168 this->is_fitted_ = false;
169
170 // Update search space reference if the new program has one, otherwise keep current
171 if (new_program.SSref.has_value())
172 this->SSref = new_program.SSref;
173 // If new_program doesn't have a search space, keep this->SSref as is
174
175 return *this;
176 };
177
185 {
186 Program<PType> new_program = j;
187 return replace_program(new_program);
188 };
189
190 template <typename R, typename W>
191 R predict_with_weights(const Dataset &d, const W** weights)
192 {
193 if (!is_fitted_)
194 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
195
196 return Tree.begin().node->predict<R>(d, weights);
197 };
198
199 auto predict_with_weights(const Dataset &d, const ArrayXf& weights)
200 {
201 float const * wptr = weights.data();
202 return this->predict_with_weights<RetType>(d, &wptr);
203 };
204
213 template <typename R = RetType>
214 TreeType predict(const Dataset &d) requires(is_same_v<R, TreeType>)
215 {
216 if (!is_fitted_)
217 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
218
219 return Tree.begin().node->predict<TreeType>(d);
220 };
221
226 template <typename R = RetType>
227 ArrayXb predict(const Dataset &d) requires(is_same_v<R, ArrayXb>)
228 {
229 if (!is_fitted_)
230 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
231
232 return (Tree.begin().node->predict<TreeType>(d) > 0.5);
233 };
234
239 template <typename R = RetType>
240 ArrayXi predict(const Dataset &d) requires(is_same_v<R, ArrayXi>)
241 {
242 if (!is_fitted_)
243 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
244
245 TreeType out = Tree.begin().node->predict<TreeType>(d);
246 auto argmax = Function<NodeType::ArgMax>{};
247 return argmax(out);
248 };
249
250 // template <typename R = RetType>
251 template <PT P = PType>
252 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
254 {
255 return predict<TreeType>(d);
256 };
257
262 Program<PType>& fit(const Ref<const ArrayXXf>& X, const Ref<const ArrayXf>& y)
263 {
264 Dataset d(X,y);
265 return fit(d);
266 };
267
271 RetType predict(const Ref<const ArrayXXf>& X)
272 {
273 Dataset d(X);
274 return predict(d);
275 };
276
284 template <PT P = PType>
285 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
286 TreeType predict_proba(const Ref<const ArrayXXf>& X)
287 {
288 Dataset d(X);
289 return predict_proba(d);
290 };
291
297 void update_weights(const Dataset& d);
298
300 int get_n_weights() const
301 {
302 int count=0;
303 // check tree nodes for weights
304 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
305 {
306 const auto& node = i.node->data;
307
308 // we do not want to include fixed weights, because they should not be changed.
309 // get_n_weights, get_weights, and set_weights, are functions used in
310 // parameter optimization --- we can simply make weights invisible to parameter
311 // optimization, and they will not be changed (so they remain fixed).
312 if (node.weight_is_fixed)
313 continue;
314
315 // some nodes cannot have their weights optimized, others must have.
316 // It is important that this condition also matches the condition in
317 // the methods get_weights and set_weights.
318 if (Is<NodeType::OffsetSum>(node.node_type)
319 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
320 ++count;
321 }
322 return count;
323 }
324
330 ArrayXf get_weights()
331 {
332 ArrayXf weights(get_n_weights());
333 int i = 0;
334 for (PostIter t = Tree.begin_post(); t != Tree.end_post(); ++t)
335 {
336 const auto& node = t.node->data;
337
338 // skip fixed weights (this also avoid changing offsetSum weight if is locked)
339 if (node.weight_is_fixed)
340 continue;
341
342 if ( Is<NodeType::OffsetSum>(node.node_type)
343 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
344 {
345 weights(i) = node.W;
346 ++i;
347 }
348 }
349 return weights;
350 }
351
358 void set_weights(const ArrayXf& weights)
359 {
360 // take the weights set them in the tree.
361 // return the weights of the tree as an array
362 if (weights.size() != get_n_weights())
363 HANDLE_ERROR_THROW("Tried to set_weights of incorrect size");
364
365 int j = 0;
366 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
367 {
368 auto& node = i.node->data;
369
370 // skip fixed weights (this also avoid changing offsetSum weight if is locked)
371 if (node.weight_is_fixed)
372 continue;
373
374 if ( Is<NodeType::OffsetSum>(node.node_type)
375 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
376 {
377 node.W = weights(j);
378 ++j;
379 }
380 }
381 }
382
393 void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true, bool keep_current_weights=false)
394 {
395 // This also support unlocking the ndoes by setting 0, true, false.
396 // OBS for unlocking the node: Do not change prob_change here, because some
397 // nodes are meant to never be replaced (e.g. logistic root).
398 // unlocking and changing probabilities are different things.
399 // every `if` performing a lock should have its counterpart `else` performing
400 // unlocking.
401
402 // iterate over the nodes, locking them if their depth does not exceed end_depth.
403 if (end_depth<0) {
404 return;
405 }
406
407 // we need the iterator to calculate the depth, but
408 // the lambda below iterate using nodes. So we are creating an iterator
409 // and using it to access depth.
410 auto tree_iter = Tree.begin();
411
412 std::for_each(Tree.begin(), Tree.end(),
413 [&](auto& n){
414 auto d = Tree.depth(tree_iter);
415 std::advance(tree_iter, 1);
416
417 // weights (this will work for all nodes)
418 if (n.get_is_weighted() && d<end_depth) // this will lock and unlock weights!
419 n.weight_is_fixed = keep_current_weights;
420 else{
421 n.weight_is_fixed = false;
422 }
423
424 // leaves (terminals, constants, and splitbest variables)
425 if (IsLeaf(n.node_type)) {
426 if (d<end_depth && !keep_leaves_unlocked)
427 n.node_is_fixed = true; // leaves should be locked based on depth
428 else {
429 // either we are outside end_depth or leaves should be unlocked.
430 // this `else` also helps unlocking th enode
431 n.node_is_fixed = false;
432 }
433
434 return; // stop here
435 }
436
437 // non-terminal nodes
438 if (d<end_depth) // This if-else could be a single line, but im trying to make it readable
439 n.node_is_fixed = true;
440 else // this else clausure helps unlocking the node
441 n.node_is_fixed = false;
442
443 // special case - split best nodes.
444 // If we are skipping leaves, then the split feature is unlocked;
445 // Otherwise, then we lock based on depth.
446 if (n.node_type==NodeType::SplitBest)
447 {
448 if (keep_leaves_unlocked)
449 {
450 n.set_keep_split_feature(false);
451 }
452 else // leaves can be locked
453 {
454 // check if we should lock based on depth
455 n.set_keep_split_feature(d+1<end_depth);
456 }
457 }
458 }
459 );
460 }
461
474 string get_model(string fmt="compact", bool pretty=false) const
475 {
476 auto head = Tree.begin();
477 if (fmt=="tree")
478 return head.node->get_tree_model(pretty);
479 else if (fmt=="dot")
480 return get_dot_model(); ;
481 return head.node->get_model(pretty);
482 }
483
490 string get_dot_model(string extras="") const
491 {
492 // TODO: make node IDs stable (hash or index) and labels reflect nodetype names.
493 // ref: https://stackoverflow.com/questions/10579041/graphviz-create-new-node-with-this-same-label#10579155
494 string out = "digraph G {\n";
495 if (! extras.empty())
496 out += fmt::format("{}\n", extras);
497
498 auto get_id = [](const auto& n){
499 if (Is<NodeType::Terminal>(n->data.node_type))
500 return n->data.get_name(false);
501
502 return fmt::format("{}",fmt::ptr(n)).substr(2);
503 };
504 // bool first = true;
505 std::map<string, unsigned int> node_count;
506 int i = 0;
507 for (Iter iter = Tree.begin(); iter!=Tree.end(); iter++)
508 {
509 const auto& parent = iter.node;
510 // const auto& parent_data = iter.node->data;
511
512 string parent_id = get_id(parent);
513 // if (Is<NodeType::Terminal>(parent_data.node_type))
514 // parent_id = parent_data.get_name(false);
515 // else{
516 // parent_id = fmt::format("{}",fmt::ptr(iter.node)).substr(2);
517 // }
518 // // parent_id = parent_id.substr(2);
519
520 // This is for the root --------------------------------------------
521 // if the first node is weighted, make a dummy output node so that the
522 // first node's weight can be shown
523 if (i==0 && parent->data.get_is_weighted())
524 {
525 // making the weight red if fixed
526 string font_color = "";
527 if (parent->data.weight_is_fixed) {
528 font_color = ", fontcolor=lightcoral";
529 }
530
531 out += "y [shape=box];\n";
532 out += fmt::format("y -> \"{}\" [label=\"{:.2f}\"{}];\n",
533 // parent_data.get_name(false),
534 parent_id,
535 parent->data.W,
536 font_color
537 );
538 }
539
540 // add the node
541 bool is_constant = Is<NodeType::Constant, NodeType::MeanLabel>(parent->data.node_type);
542 string node_label = parent->data.get_name(is_constant);
543
544 if (Is<NodeType::SplitBest>(parent->data.node_type)) {
545 std::string feature = parent->data.get_feature();
546 std::string threshold = fmt::format("{:.2f}", parent->data.W);
547
548 // Append markers for fixed flags
549 if (parent->data.keep_split_feature)
550 feature += "^"; // split feature fixed
551 if (parent->data.weight_is_fixed)
552 threshold += "*"; // split weight fixed
553
554 node_label = fmt::format("{} >= {}?", feature, threshold);
555 }
556 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
557 node_label = fmt::format("Add");
558 }
559
560 string node_style = parent->data.get_prob_change() >0.0 ? "" : ", style=filled, fillcolor=lightcoral";
561 out += fmt::format("\"{}\" [label=\"{}\"{}];\n", parent_id, node_label, node_style);
562
563 // add edges to the node's children
564 auto kid = iter.node->first_child;
565 for (int j = 0; j < iter.number_of_children(); ++j)
566 {
567 string edge_label="";
568 string head_label="";
569 string tail_label="";
570 bool use_head_tail_labels = false;
571
572 string kid_id = get_id(kid);
573 // string kid_id = fmt::format("{}",fmt::ptr(kid));
574 // kid_id = kid_id.substr(2);
575
576 if (kid->data.get_is_weighted()
578 NodeType::OffsetSum, NodeType::SplitBest>(kid->data.node_type))
579 {
580 edge_label = fmt::format("{:.2f}",kid->data.W);
581 }
582
583 if (Is<NodeType::SplitOn>(parent->data.node_type)){
584 use_head_tail_labels=true;
585 if (j == 0)
586 tail_label = fmt::format(">= {:.2f}",parent->data.W);
587 else if (j==1)
588 tail_label = "Y";
589 else
590 tail_label = "N";
591
592 head_label=edge_label;
593 }
594 else if (Is<NodeType::SplitBest>(parent->data.node_type)){
595 use_head_tail_labels=true;
596 if (j == 0){
597 tail_label = "Y";
598 }
599 else
600 tail_label = "N";
601
602 head_label = edge_label;
603 }
604
605 // drawing the edges
606 string font_color = "";
607 if (kid->data.weight_is_fixed) {
608 font_color = ", fontcolor=lightcoral";
609 }
610
611 if (use_head_tail_labels){
612 out += fmt::format("\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"{}];\n",
613 parent_id,
614 kid_id,
615 head_label,
616 tail_label,
617 font_color
618 );
619 }
620 else{
621 out += fmt::format("\"{}\" -> \"{}\" [label=\"{}\"{}];\n",
622 parent_id,
623 kid_id,
624 edge_label,
625 font_color
626 );
627 }
628 kid = kid->next_sibling;
629 }
630
631 // adding the offset as the last child
632 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
633 // drawing the edge
634 out += fmt::format("\"{}\" -> \"{}\" [label=\"\"];\n",
635 parent_id,
636 parent_id+"Offset"
637 );
638
639 // drawing the node
640 out += fmt::format("\"{}\" [label=\"{:.2f}\"{}];\n",
641 parent_id+"Offset",
642 parent->data.W,
643 node_style
644 );
645 }
646
647 ++i;
648 }
649
650 out += "label=\"^ split feature fixed, * split threshold fixed\";\nlabelloc=bottom;\nfontsize=10;";
651 out += "}\n";
652
653 return out;
654 }
655
658 vector<Node> linearize() const {
659 vector<Node> linear_program;
660 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
661 linear_program.push_back(i.node->data);
662 return linear_program;
663 }
664}; // Program
665} // Brush
666
668// weight optimization
670// #include "../variation.h"
671namespace Brush{
672
673template<ProgramType PType>
675{
676 // Updates the weights within a tree.
677 // make an optimizer
678 auto WO = WeightOptimizer();
679 // get new weights from optimization.
680 WO.update((*this), d);
681};
682
683
685// serialization
686// serialization for program
// Serialize a Program: only the tree and the fitted flag are persisted
// (the search-space reference is intentionally not serialized).
template<ProgramType PType>
void to_json(json &j, const Program<PType> &p)
{
    j = json{{"Tree",p.Tree}, {"is_fitted_", p.is_fitted_}};
}
692
// Deserialize a Program saved by to_json. Throws (json::out_of_range)
// if either key is missing. SSref is left empty; call set_search_space
// afterwards if the program needs one.
template<ProgramType PType>
void from_json(const json &j, Program<PType>& p)
{
    j.at("Tree").get_to(p.Tree);
    j.at("is_fitted_").get_to(p.is_fitted_);
}
699
700}//namespace Brush
701
702
703
704#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:3
auto Isnt(DataType dt) -> bool
Definition node.h:48
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
auto IsWeighable() noexcept -> bool
Definition node.h:51
ProgramType PT
Definition program.h:39
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto IsLeaf(NodeType nt) noexcept -> bool
Definition node.h:320
auto Is(NodeType nt) -> bool
Definition node.h:313
tree< Node >::pre_order_iterator Iter
Definition program.h:36
tree< Node >::post_order_iterator PostIter
Definition program.h:37
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
ProgramType
Definition types.h:70
An individual program, a.k.a. model.
Definition program.h:49
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Definition program.h:78
Program< PType > & replace_program(const json &j)
Replace the current program from a JSON representation, invalidating fitness.
Definition program.h:184
TreeType predict_proba(const Dataset &d)
Definition program.h:253
void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true, bool keep_current_weights=false)
Iterates over the program, locking the nodes until it reaches a certain depth. If the node is a Split...
Definition program.h:393
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
Definition program.h:674
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
Definition program.h:54
std::optional< std::reference_wrapper< SearchSpace > > SSref
Definition program.h:75
int linear_complexity() const
count the linear complexity of the program.
Definition program.h:101
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
Definition program.h:271
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
Definition program.h:358
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:214
static constexpr PT program_type
Definition program.h:51
R predict_with_weights(const Dataset &d, const W **weights)
Definition program.h:191
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
Definition program.h:61
int complexity() const
count the (recursive) complexity of the program.
Definition program.h:93
Program< PType > & replace_program(const Program< PType > &new_program)
Replace the current program with a new program, invalidating fitness.
Definition program.h:165
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
ArrayXf get_weights()
Get the weights of the tree as an array.
Definition program.h:330
Program< PType > & fit(const Dataset &d)
Definition program.h:150
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:128
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
Definition program.h:199
int get_n_weights() const
returns the number of weights in the program.
Definition program.h:300
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
Definition program.h:138
vector< Node > linearize() const
turns program tree into a linear program.
Definition program.h:658
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
Definition program.h:240
void set_search_space(const std::reference_wrapper< SearchSpace > s)
Definition program.h:86
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
Definition program.h:227
string get_dot_model(string extras="") const
Definition program.h:490
Program()=default
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
Definition program.h:474
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
Definition program.h:286
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:146
Program< PType > copy()
Definition program.h:84
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
Definition program.h:262
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Holds a search space, consisting of operations and terminals and functions, and methods to sample that space.