Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
program.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef PROGRAM_H
7#define PROGRAM_H
8//external includes
9
10//
11
12#include <string>
13#include "assert.h"
14
15// internal includes
16
17// #include "data/data.h"
18#include "../init.h"
19#include "tree_node.h"
20#include "node.h"
22#include "../params.h"
23#include "../util/utils.h"
24#include "functions.h"
25// #include "../variation.h"
26// #include "weight_optimizer.h"
27
28
29using std::cout;
30using std::string;
33
34namespace Brush {
35
36
37typedef tree<Node>::pre_order_iterator Iter;
38typedef tree<Node>::post_order_iterator PostIter;
39
41
42// for unsupervised learning, classification and regression.
43
49template<PT PType> struct Program
50{
52 static constexpr PT program_type = PType;
53
55 using RetType = typename std::conditional_t<PType == PT::Regressor, ArrayXf,
56 std::conditional_t<PType == PT::BinaryClassifier, ArrayXb,
57 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXi,
58 std::conditional_t<PType == PT::Representer, ArrayXXf, ArrayXf
59 >>>>;
60
62 using TreeType = std::conditional_t<PType == PT::BinaryClassifier, ArrayXf,
63 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXXf,
64 RetType>>;
65
68
70 // Fitness fitness;
71
73 tree<Node> Tree;
75 std::optional<std::reference_wrapper<SearchSpace>> SSref;
76
77 Program() = default;
78 Program(const std::reference_wrapper<SearchSpace> s, const tree<Node> t)
79 : Tree(t)
80 {
81 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
82 }
83
84 Program<PType> copy() { return Program<PType>(*this); }
85
86 inline void set_search_space(const std::reference_wrapper<SearchSpace> s)
87 {
88 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
89 }
90
93 int complexity() const{
94 auto head = Tree.begin();
95
96 return head.node->get_complexity();
97 }
98
101 int linear_complexity() const{
102 auto head = Tree.begin();
103
104 return head.node->get_linear_complexity();
105 }
106
110 int size(bool include_weight=true) const{
111 auto head = Tree.begin();
112
113 return head.node->get_size(include_weight);
114 }
115
121 int size_at(Iter& top, bool include_weight=true) const{
122
123 return top.node->get_size(include_weight);
124 }
125
128 int depth() const{
129 //tree.hh count the number of edges. We need to ensure that a single-node
130 //tree has depth>0
131 return 1+Tree.max_depth();
132 }
133
138 int depth_at(Iter& top) const{
139 return 1+Tree.max_depth(top);
140 }
141
146 int depth_to_reach(Iter& top) const{
147 return 1+Tree.depth(top);
148 }
149
151 {
152 TreeType out = Tree.begin().node->fit<TreeType>(d);
153 this->is_fitted_ = true;
155 // this->valid = true;
156 return *this;
157 };
158
159 template <typename R, typename W>
160 R predict_with_weights(const Dataset &d, const W** weights)
161 {
162 if (!is_fitted_)
163 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
164
165 return Tree.begin().node->predict<R>(d, weights);
166 };
167
168 auto predict_with_weights(const Dataset &d, const ArrayXf& weights)
169 {
170 float const * wptr = weights.data();
171 return this->predict_with_weights<RetType>(d, &wptr);
172 };
173
182 template <typename R = RetType>
183 TreeType predict(const Dataset &d) requires(is_same_v<R, TreeType>)
184 {
185 if (!is_fitted_)
186 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
187
188 return Tree.begin().node->predict<TreeType>(d);
189 };
190
195 template <typename R = RetType>
196 ArrayXb predict(const Dataset &d) requires(is_same_v<R, ArrayXb>)
197 {
198 if (!is_fitted_)
199 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
200
201 return (Tree.begin().node->predict<TreeType>(d) > 0.5);
202 };
203
208 template <typename R = RetType>
209 ArrayXi predict(const Dataset &d) requires(is_same_v<R, ArrayXi>)
210 {
211 if (!is_fitted_)
212 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
213
214 TreeType out = Tree.begin().node->predict<TreeType>(d);
215 auto argmax = Function<NodeType::ArgMax>{};
216 return argmax(out);
217 };
218
219 // template <typename R = RetType>
220 template <PT P = PType>
221 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
223 {
224 return predict<TreeType>(d);
225 };
226
231 Program<PType>& fit(const Ref<const ArrayXXf>& X, const Ref<const ArrayXf>& y)
232 {
233 Dataset d(X,y);
234 return fit(d);
235 };
236
240 RetType predict(const Ref<const ArrayXXf>& X)
241 {
242 Dataset d(X);
243 return predict(d);
244 };
245
253 template <PT P = PType>
254 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
255 TreeType predict_proba(const Ref<const ArrayXXf>& X)
256 {
257 Dataset d(X);
258 return predict_proba(d);
259 };
260
266 void update_weights(const Dataset& d);
267
269 int get_n_weights() const
270 {
271 int count=0;
272 // check tree nodes for weights
273 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
274 {
275 const auto& node = i.node->data;
276 // some nodes cannot have their weights optimized, others must have.
277 // It is important that this condition also matches the condition in
278 // the methods get_weights, set_weights, .
279 if ( Is<NodeType::OffsetSum>(node.node_type)
280 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
281 ++count;
282 }
283 return count;
284 }
285
291 ArrayXf get_weights()
292 {
293 ArrayXf weights(get_n_weights());
294 int i = 0;
295 for (PostIter t = Tree.begin_post(); t != Tree.end_post(); ++t)
296 {
297 const auto& node = t.node->data;
298 if ( Is<NodeType::OffsetSum>(node.node_type)
299 || (node.get_is_weighted() && IsWeighable(node.ret_type)) )
300 {
301 weights(i) = node.W;
302 ++i;
303 }
304 }
305 return weights;
306 }
307
314 void set_weights(const ArrayXf& weights)
315 {
316 // take the weights set them in the tree.
317 // return the weights of the tree as an array
318 if (weights.size() != get_n_weights())
319 HANDLE_ERROR_THROW("Tried to set_weights of incorrect size");
320 int j = 0;
321 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
322 {
323 auto& node = i.node->data;
324 if ( Is<NodeType::OffsetSum>(node.node_type)
325 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
326 {
327 node.W = weights(j);
328 ++j;
329 }
330 }
331 }
332
342 void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true)
343 {
344 // iterate over the nodes, locking them if their depth does not exceed end_depth.
345 if (end_depth<=0)
346 return;
347
348 // we need the iterator to calculate the depth, but
349 // the lambda below iterate using nodes. So we are creating an iterator
350 // and using it to access depth.
351 auto tree_iter = Tree.begin();
352
353 std::for_each(Tree.begin(), Tree.end(),
354 [&](auto& n){
355 auto d = Tree.depth(tree_iter);
356 std::advance(tree_iter, 1);
357
358 if (keep_leaves_unlocked && IsLeaf(n.node_type))
359 return;
360
361 // If we are skipping leaves, then the split feature is unlocked;
362 // Otherwise, then we lock based on depth.
363 if (n.node_type==NodeType::SplitBest)
364 {
365 if (keep_leaves_unlocked)
366 {
367 n.set_keep_split_feature(false);
368 }
369 else // leaves can be locked
370 {
371 // check if we should lock based on depth
372 n.set_keep_split_feature(d+1<=end_depth);
373 }
374 }
375
376 if (d<=end_depth)
377 n.fixed = true;
378 // n.set_prob_change(0.0f);
379 }
380 );
381 }
382
395 string get_model(string fmt="compact", bool pretty=false) const
396 {
397 auto head = Tree.begin();
398 if (fmt=="tree")
399 return head.node->get_tree_model(pretty);
400 else if (fmt=="dot")
401 return get_dot_model(); ;
402 return head.node->get_model(pretty);
403 }
404
411 string get_dot_model(string extras="") const
412 {
413 // TODO: make the node names their hash or index, and the node label the nodetype name.
414 // ref: https://stackoverflow.com/questions/10579041/graphviz-create-new-node-with-this-same-label#10579155
415 string out = "digraph G {\n";
416 if (! extras.empty())
417 out += fmt::format("{}\n", extras);
418
419 auto get_id = [](const auto& n){
420 if (Is<NodeType::Terminal>(n->data.node_type))
421 return n->data.get_name(false);
422
423 return fmt::format("{}",fmt::ptr(n)).substr(2);
424 };
425 // bool first = true;
426 std::map<string, unsigned int> node_count;
427 int i = 0;
428 for (Iter iter = Tree.begin(); iter!=Tree.end(); iter++)
429 {
430 const auto& parent = iter.node;
431 // const auto& parent_data = iter.node->data;
432
433 string parent_id = get_id(parent);
434 // if (Is<NodeType::Terminal>(parent_data.node_type))
435 // parent_id = parent_data.get_name(false);
436 // else{
437 // parent_id = fmt::format("{}",fmt::ptr(iter.node)).substr(2);
438 // }
439 // // parent_id = parent_id.substr(2);
440
441 // if the first node is weighted, make a dummy output node so that the
442 // first node's weight can be shown
443 if (i==0 && parent->data.get_is_weighted())
444 {
445 out += "y [shape=box];\n";
446 out += fmt::format("y -> \"{}\" [label=\"{:.2f}\"];\n",
447 // parent_data.get_name(false),
448 parent_id,
449 parent->data.W
450 );
451 }
452
453 // add the node
454 bool is_constant = Is<NodeType::Constant, NodeType::MeanLabel>(parent->data.node_type);
455 string node_label = parent->data.get_name(is_constant);
456
457 if (Is<NodeType::SplitBest>(parent->data.node_type)){
458 node_label = fmt::format("{}>={:.2f}?", parent->data.get_feature(), parent->data.W);
459 }
460 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
461 node_label = fmt::format("Add");
462 }
463
464 string node_style = parent->data.get_prob_change() >0.0 ? "" : ", style=filled, fillcolor=lightcoral";
465 out += fmt::format("\"{}\" [label=\"{}\"{}];\n", parent_id, node_label, node_style);
466
467 // add edges to the node's children
468 auto kid = iter.node->first_child;
469 for (int j = 0; j < iter.number_of_children(); ++j)
470 {
471 string edge_label="";
472 string head_label="";
473 string tail_label="";
474 bool use_head_tail_labels = false;
475
476 string kid_id = get_id(kid);
477 // string kid_id = fmt::format("{}",fmt::ptr(kid));
478 // kid_id = kid_id.substr(2);
479
480 if (kid->data.get_is_weighted()
482 NodeType::OffsetSum, NodeType::SplitBest>(kid->data.node_type))
483 {
484 edge_label = fmt::format("{:.2f}",kid->data.W);
485 }
486
487 if (Is<NodeType::SplitOn>(parent->data.node_type)){
488 use_head_tail_labels=true;
489 if (j == 0)
490 tail_label = fmt::format(">={:.2f}",parent->data.W);
491 else if (j==1)
492 tail_label = "Y";
493 else
494 tail_label = "N";
495
496 head_label=edge_label;
497 }
498 else if (Is<NodeType::SplitBest>(parent->data.node_type)){
499 use_head_tail_labels=true;
500 if (j == 0){
501 tail_label = "Y";
502 }
503 else
504 tail_label = "N";
505
506 head_label = edge_label;
507 }
508
509 if (use_head_tail_labels){
510 out += fmt::format("\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"];\n",
511 parent_id,
512 kid_id,
513 head_label,
514 tail_label
515 );
516 }
517 else{
518 out += fmt::format("\"{}\" -> \"{}\" [label=\"{}\"];\n",
519 parent_id,
520 kid_id,
521 edge_label
522 );
523 }
524 kid = kid->next_sibling;
525 }
526
527 // adding the offset as the last child
528 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
529 // drawing the edge
530 out += fmt::format("\"{}\" -> \"{}\" [label=\"\"];\n",
531 parent_id,
532 parent_id+"Offset"
533 );
534
535 // drawing the node
536 out += fmt::format("\"{}\" [label=\"{:.2f}\"{}];\n",
537 parent_id+"Offset",
538 parent->data.W,
539 node_style
540 );
541 }
542
543 ++i;
544 }
545 out += "}\n";
546 return out;
547 }
548
551 vector<Node> linearize() const {
552 vector<Node> linear_program;
553 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
554 linear_program.push_back(i.node->data);
555 return linear_program;
556 }
557}; // Program
558} // Brush
559
561// weight optimization
563// #include "../variation.h"
564namespace Brush{
565
566template<ProgramType PType>
568{
569 // Updates the weights within a tree.
570 // make an optimizer
571 auto WO = WeightOptimizer();
572 // get new weights from optimization.
573 WO.update((*this), d);
574};
575
576
578// serialization
579// serialization for program
580template<ProgramType PType>
581void to_json(json &j, const Program<PType> &p)
582{
583 j = json{{"Tree",p.Tree}, {"is_fitted_", p.is_fitted_}};
584}
585
586template<ProgramType PType>
587void from_json(const json &j, Program<PType>& p)
588{
589 j.at("Tree").get_to(p.Tree);
590 j.at("is_fitted_").get_to(p.is_fitted_);
591}
592
593}//namespace Brush
594
595
596
597#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
auto IsWeighable() noexcept -> bool
Definition node.h:46
ProgramType PT
Definition program.h:40
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:31
auto Is(NodeType nt) -> bool
Definition node.h:291
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::post_order_iterator PostIter
Definition program.h:38
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
ProgramType
Definition types.h:70
An individual program, a.k.a. model.
Definition program.h:50
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Definition program.h:78
TreeType predict_proba(const Dataset &d)
Definition program.h:222
void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true)
Iterates over the program, locking the nodes until it reaches a certain depth. If the node is a Split...
Definition program.h:342
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
Definition program.h:567
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
Definition program.h:55
std::optional< std::reference_wrapper< SearchSpace > > SSref
Definition program.h:75
int linear_complexity() const
count the linear complexity of the program.
Definition program.h:101
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
Definition program.h:240
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
Definition program.h:314
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:183
static constexpr PT program_type
Definition program.h:52
R predict_with_weights(const Dataset &d, const W **weights)
Definition program.h:160
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
Definition program.h:62
int complexity() const
count the (recursive) complexity of the program.
Definition program.h:93
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
ArrayXf get_weights()
Get the weights of the tree as an array.
Definition program.h:291
Program< PType > & fit(const Dataset &d)
Definition program.h:150
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:128
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
Definition program.h:168
int get_n_weights() const
returns the number of weights in the program.
Definition program.h:269
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
Definition program.h:138
vector< Node > linearize() const
turns program tree into a linear program.
Definition program.h:551
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
Definition program.h:209
void set_search_space(const std::reference_wrapper< SearchSpace > s)
Definition program.h:86
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
Definition program.h:196
string get_dot_model(string extras="") const
Definition program.h:411
Program()=default
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
Definition program.h:395
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
Definition program.h:255
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:146
Program< PType > copy()
Definition program.h:84
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
Definition program.h:231
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...