Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
program.h
Go to the documentation of this file.
1/* Brush
2
3copyright 2020 William La Cava
4license: GNU/GPL v3
5*/
6#ifndef PROGRAM_H
7#define PROGRAM_H
8//external includes
9
10//
11
12#include <string>
13#include "assert.h"
14
15// internal includes
16
17// #include "data/data.h"
18#include "../init.h"
19#include "tree_node.h"
20#include "node.h"
22#include "../params.h"
23#include "../util/utils.h"
24#include "functions.h"
25// #include "../variation.h"
26// #include "weight_optimizer.h"
27
28
29using std::cout;
30using std::string;
33
34namespace Brush {
35
36
37typedef tree<Node>::pre_order_iterator Iter;
38typedef tree<Node>::post_order_iterator PostIter;
39
41
42// for unsupervised learning, classification and regression.
43
49template<PT PType> struct Program
50{
52 static constexpr PT program_type = PType;
53
55 using RetType = typename std::conditional_t<PType == PT::Regressor, ArrayXf,
56 std::conditional_t<PType == PT::BinaryClassifier, ArrayXb,
57 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXi,
58 std::conditional_t<PType == PT::Representer, ArrayXXf, ArrayXf
59 >>>>;
60
62 using TreeType = std::conditional_t<PType == PT::BinaryClassifier, ArrayXf,
63 std::conditional_t<PType == PT::MulticlassClassifier, ArrayXXf,
64 RetType>>;
65
68
70 // Fitness fitness;
71
73 tree<Node> Tree;
75 std::optional<std::reference_wrapper<SearchSpace>> SSref;
76
77 Program() = default;
78 Program(const std::reference_wrapper<SearchSpace> s, const tree<Node> t)
79 : Tree(t)
80 {
81 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
82 }
83
84 Program<PType> copy() { return Program<PType>(*this); }
85
86 inline void set_search_space(const std::reference_wrapper<SearchSpace> s)
87 {
88 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
89 }
90
93 int complexity() const{
94 auto head = Tree.begin();
95
96 return head.node->get_complexity();
97 }
98
101 int linear_complexity() const{
102 auto head = Tree.begin();
103
104 return head.node->get_linear_complexity();
105 }
106
110 int size(bool include_weight=true) const{
111 auto head = Tree.begin();
112
113 return head.node->get_size(include_weight);
114 }
115
121 int size_at(Iter& top, bool include_weight=true) const{
122
123 return top.node->get_size(include_weight);
124 }
125
128 int depth() const{
129 //tree.hh count the number of edges. We need to ensure that a single-node
130 //tree has depth>0
131 return 1+Tree.max_depth();
132 }
133
138 int depth_at(Iter& top) const{
139 return 1+Tree.max_depth(top);
140 }
141
146 int depth_to_reach(Iter& top) const{
147 return 1+Tree.depth(top);
148 }
149
151 {
152 TreeType out = Tree.begin().node->fit<TreeType>(d);
153 this->is_fitted_ = true;
155 // this->valid = true;
156 return *this;
157 };
158
159 template <typename R, typename W>
160 R predict_with_weights(const Dataset &d, const W** weights)
161 {
162 if (!is_fitted_)
163 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
164
165 return Tree.begin().node->predict<R>(d, weights);
166 };
167
168 auto predict_with_weights(const Dataset &d, const ArrayXf& weights)
169 {
170 float const * wptr = weights.data();
171 return this->predict_with_weights<RetType>(d, &wptr);
172 };
173
182 template <typename R = RetType>
183 TreeType predict(const Dataset &d) requires(is_same_v<R, TreeType>)
184 {
185 if (!is_fitted_)
186 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
187
188 return Tree.begin().node->predict<TreeType>(d);
189 };
190
195 template <typename R = RetType>
196 ArrayXb predict(const Dataset &d) requires(is_same_v<R, ArrayXb>)
197 {
198 if (!is_fitted_)
199 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
200
201 return (Tree.begin().node->predict<TreeType>(d) > 0.5);
202 };
203
208 template <typename R = RetType>
209 ArrayXi predict(const Dataset &d) requires(is_same_v<R, ArrayXi>)
210 {
211 if (!is_fitted_)
212 HANDLE_ERROR_THROW("Program is not fitted. Call 'fit' first.\n");
213
214 TreeType out = Tree.begin().node->predict<TreeType>(d);
215 auto argmax = Function<NodeType::ArgMax>{};
216 return argmax(out);
217 };
218
219 // template <typename R = RetType>
220 template <PT P = PType>
221 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
223 {
224 return predict<TreeType>(d);
225 };
226
231 Program<PType>& fit(const Ref<const ArrayXXf>& X, const Ref<const ArrayXf>& y)
232 {
233 Dataset d(X,y);
234 return fit(d);
235 };
236
240 RetType predict(const Ref<const ArrayXXf>& X)
241 {
242 Dataset d(X);
243 return predict(d);
244 };
245
253 template <PT P = PType>
254 requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
255 TreeType predict_proba(const Ref<const ArrayXXf>& X)
256 {
257 Dataset d(X);
258 return predict_proba(d);
259 };
260
266 void update_weights(const Dataset& d);
267
269 int get_n_weights() const
270 {
271 int count=0;
272 // check tree nodes for weights
273 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
274 {
275 const auto& node = i.node->data;
276 // some nodes cannot have their weights optimized, others must have
277 if ( Is<NodeType::OffsetSum>(node.node_type)
278 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
279 ++count;
280 }
281 return count;
282 }
283
289 ArrayXf get_weights()
290 {
291 ArrayXf weights(get_n_weights());
292 int i = 0;
293 for (PostIter t = Tree.begin_post(); t != Tree.end_post(); ++t)
294 {
295 const auto& node = t.node->data;
296 if ( Is<NodeType::OffsetSum>(node.node_type)
297 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
298 {
299 weights(i) = node.W;
300 ++i;
301 }
302 }
303 return weights;
304 }
305
312 void set_weights(const ArrayXf& weights)
313 {
314 // take the weights set them in the tree.
315 // return the weights of the tree as an array
316 if (weights.size() != get_n_weights())
317 HANDLE_ERROR_THROW("Tried to set_weights of incorrect size");
318 int j = 0;
319 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
320 {
321 auto& node = i.node->data;
322 if ( Is<NodeType::OffsetSum>(node.node_type)
323 || (node.get_is_weighted() && IsWeighable(node.node_type)) )
324 {
325 node.W = weights(j);
326 ++j;
327 }
328 }
329 }
330
339 void lock_nodes(int end_depth=0, bool skip_leaves=true)
340 {
341 // iterate over the nodes, locking them if their depth does not exceed end_depth.
342 if (end_depth<=0)
343 return;
344
345 // we need the iterator to calculate the depth, but
346 // the lambda below iterate using nodes. So we are creating an iterator
347 // and using it to access depth.
348 auto tree_iter = Tree.begin();
349
350 std::for_each(Tree.begin(), Tree.end(),
351 [&](auto& n){
352 auto d = Tree.depth(tree_iter);
353 std::advance(tree_iter, 1);
354
355 if (skip_leaves && IsLeaf(n.node_type))
356 return;
357
358 if (d<=end_depth)
359 n.fixed = true;
360 // n.set_prob_change(0.0f);
361 }
362 );
363 }
364
372 void unlock_nodes(int start_depth=0)
373 {
374 auto tree_iter = Tree.begin();
375
376 std::for_each(Tree.begin(), Tree.end(),
377 [&](auto& n){
378 auto d = Tree.depth(tree_iter);
379 std::advance(tree_iter, 1);
380
381 if (d>=start_depth)
382 n.fixed = false;
383 // n.set_prob_change(1.0f);
384 }
385 );
386 }
387
400 string get_model(string fmt="compact", bool pretty=false) const
401 {
402 auto head = Tree.begin();
403 if (fmt=="tree")
404 return head.node->get_tree_model(pretty);
405 else if (fmt=="dot")
406 return get_dot_model(); ;
407 return head.node->get_model(pretty);
408 }
409
416 string get_dot_model(string extras="") const
417 {
418 // TODO: make the node names their hash or index, and the node label the nodetype name.
419 // ref: https://stackoverflow.com/questions/10579041/graphviz-create-new-node-with-this-same-label#10579155
420 string out = "digraph G {\n";
421 if (! extras.empty())
422 out += fmt::format("{}\n", extras);
423
424 auto get_id = [](const auto& n){
425 if (Is<NodeType::Terminal>(n->data.node_type))
426 return n->data.get_name(false);
427
428 return fmt::format("{}",fmt::ptr(n)).substr(2);
429 };
430 // bool first = true;
431 std::map<string, unsigned int> node_count;
432 int i = 0;
433 for (Iter iter = Tree.begin(); iter!=Tree.end(); iter++)
434 {
435 const auto& parent = iter.node;
436 // const auto& parent_data = iter.node->data;
437
438 string parent_id = get_id(parent);
439 // if (Is<NodeType::Terminal>(parent_data.node_type))
440 // parent_id = parent_data.get_name(false);
441 // else{
442 // parent_id = fmt::format("{}",fmt::ptr(iter.node)).substr(2);
443 // }
444 // // parent_id = parent_id.substr(2);
445
446 // if the first node is weighted, make a dummy output node so that the
447 // first node's weight can be shown
448 if (i==0 && parent->data.get_is_weighted())
449 {
450 out += "y [shape=box];\n";
451 out += fmt::format("y -> \"{}\" [label=\"{:.2f}\"];\n",
452 // parent_data.get_name(false),
453 parent_id,
454 parent->data.W
455 );
456 }
457
458 // add the node
459 bool is_constant = Is<NodeType::Constant, NodeType::MeanLabel>(parent->data.node_type);
460 string node_label = parent->data.get_name(is_constant);
461
462 if (Is<NodeType::SplitBest>(parent->data.node_type)){
463 node_label = fmt::format("{}>{:.2f}?", parent->data.get_feature(), parent->data.W);
464 }
465 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
466 node_label = fmt::format("Add");
467 }
468 out += fmt::format("\"{}\" [label=\"{}\"];\n", parent_id, node_label);
469
470 // add edges to the node's children
471 auto kid = iter.node->first_child;
472 for (int j = 0; j < iter.number_of_children(); ++j)
473 {
474 string edge_label="";
475 string head_label="";
476 string tail_label="";
477 bool use_head_tail_labels = false;
478
479 string kid_id = get_id(kid);
480 // string kid_id = fmt::format("{}",fmt::ptr(kid));
481 // kid_id = kid_id.substr(2);
482
483 if (kid->data.get_is_weighted()
485 NodeType::OffsetSum, NodeType::SplitBest>(kid->data.node_type))
486 {
487 edge_label = fmt::format("{:.2f}",kid->data.W);
488 }
489
490 if (Is<NodeType::SplitOn>(parent->data.node_type)){
491 use_head_tail_labels=true;
492 if (j == 0)
493 tail_label = fmt::format(">{:.2f}",parent->data.W);
494 else if (j==1)
495 tail_label = "Y";
496 else
497 tail_label = "N";
498
499 head_label=edge_label;
500 }
501 else if (Is<NodeType::SplitBest>(parent->data.node_type)){
502 use_head_tail_labels=true;
503 if (j == 0){
504 tail_label = "Y";
505 }
506 else
507 tail_label = "N";
508
509 head_label = edge_label;
510 }
511
512 if (use_head_tail_labels){
513 out += fmt::format("\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"];\n",
514 parent_id,
515 kid_id,
516 head_label,
517 tail_label
518 );
519 }
520 else{
521 out += fmt::format("\"{}\" -> \"{}\" [label=\"{}\"];\n",
522 parent_id,
523 kid_id,
524 edge_label
525 );
526 }
527 kid = kid->next_sibling;
528 }
529
530 // adding the offset as the last child
531 if (Is<NodeType::OffsetSum>(parent->data.node_type)){
532 // drawing the edge
533 out += fmt::format("\"{}\" -> \"{}\" [label=\"\"];\n",
534 parent_id,
535 parent_id+"Offset"
536 );
537
538 // drawing the node
539 out += fmt::format("\"{}\" [label=\"{:.2f}\"];\n",
540 parent_id+"Offset",
541 parent->data.W
542 );
543 }
544
545 ++i;
546 }
547 out += "}\n";
548 return out;
549 }
550
553 vector<Node> linearize() const {
554 vector<Node> linear_program;
555 for (PostIter i = Tree.begin_post(); i != Tree.end_post(); ++i)
556 linear_program.push_back(i.node->data);
557 return linear_program;
558 }
559}; // Program
560} // Brush
561
563// weight optimization
565// #include "../variation.h"
566namespace Brush{
567
568template<ProgramType PType>
570{
571 // Updates the weights within a tree.
572 // make an optimizer
573 auto WO = WeightOptimizer();
574 // get new weights from optimization.
575 WO.update((*this), d);
576};
577
578
580// serialization
581// serialization for program
582template<ProgramType PType>
583void to_json(json &j, const Program<PType> &p)
584{
585 j = json{{"Tree",p.Tree}, {"is_fitted_", p.is_fitted_}};
586}
587
588template<ProgramType PType>
589void from_json(const json &j, Program<PType>& p)
590{
591 j.at("Tree").get_to(p.Tree);
592 j.at("is_fitted_").get_to(p.is_fitted_);
593}
594
595}//namespace Brush
596
597
598
599#endif
holds variable type data.
Definition data.h:51
#define HANDLE_ERROR_THROW(err)
Definition error.h:27
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
auto Isnt(DataType dt) -> bool
Definition node.h:43
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition types.h:39
auto IsWeighable() noexcept -> bool
Definition node.h:46
ProgramType PT
Definition program.h:40
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:25
auto Is(NodeType nt) -> bool
Definition node.h:272
tree< Node >::pre_order_iterator Iter
Definition program.h:37
tree< Node >::post_order_iterator PostIter
Definition program.h:38
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
Definition types.h:40
ProgramType
Definition types.h:70
An individual program, a.k.a. model.
Definition program.h:50
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Definition program.h:78
TreeType predict_proba(const Dataset &d)
Definition program.h:222
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
Definition program.h:569
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
Definition program.h:55
std::optional< std::reference_wrapper< SearchSpace > > SSref
Definition program.h:75
int linear_complexity() const
count the linear complexity of the program.
Definition program.h:101
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
Definition program.h:240
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
Definition program.h:312
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
Definition program.h:183
static constexpr PT program_type
Definition program.h:52
R predict_with_weights(const Dataset &d, const W **weights)
Definition program.h:160
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
Definition program.h:62
int complexity() const
count the (recursive) complexity of the program.
Definition program.h:93
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
Definition program.h:121
ArrayXf get_weights()
Get the weights of the tree as an array.
Definition program.h:289
Program< PType > & fit(const Dataset &d)
Definition program.h:150
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
Definition program.h:128
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
Definition program.h:168
int get_n_weights() const
returns the number of weights in the program.
Definition program.h:269
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
Definition program.h:138
void unlock_nodes(int start_depth=0)
Iterates over the program, unlocking the nodes until it reaches a certain depth. It does not protect ...
Definition program.h:372
vector< Node > linearize() const
turns program tree into a linear program.
Definition program.h:553
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
Definition program.h:209
void set_search_space(const std::reference_wrapper< SearchSpace > s)
Definition program.h:86
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
Definition program.h:196
string get_dot_model(string extras="") const
Get the model as a dot object.
Definition program.h:416
Program()=default
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
Definition program.h:400
void lock_nodes(int end_depth=0, bool skip_leaves=true)
Iterates over the program, locking the nodes until it reaches a certain depth.
Definition program.h:339
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
Definition program.h:255
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Definition program.h:146
Program< PType > copy()
Definition program.h:84
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
Definition program.h:231
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Definition program.h:110
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...