36typedef tree<Node>::pre_order_iterator
Iter;
37typedef tree<Node>::post_order_iterator
PostIter;
75 std::optional<std::reference_wrapper<SearchSpace>>
SSref;
78 Program(
const std::reference_wrapper<SearchSpace> s,
const tree<Node> t)
81 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
88 SSref = std::optional<std::reference_wrapper<SearchSpace>>{s};
94 auto head =
Tree.begin();
96 return head.node->get_complexity();
102 auto head =
Tree.begin();
104 return head.node->get_linear_complexity();
110 int size(
bool include_weight=
true)
const{
111 auto head =
Tree.begin();
113 return head.node->get_size(include_weight);
123 return top.node->get_size(include_weight);
131 return 1+
Tree.max_depth();
139 return 1+
Tree.max_depth(top);
147 return 1+
Tree.depth(top);
153 this->is_fitted_ =
true;
167 this->Tree = new_program.
Tree;
168 this->is_fitted_ =
false;
171 if (new_program.
SSref.has_value())
172 this->SSref = new_program.
SSref;
190 template <
typename R,
typename W>
196 return Tree.begin().node->predict<R>(d, weights);
201 float const * wptr = weights.data();
213 template <
typename R = RetType>
226 template <
typename R = RetType>
239 template <
typename R = RetType>
251 template <PT P = PType>
284 template <PT P = PType>
306 const auto& node = i.node->data;
312 if (node.weight_is_fixed)
319 || (node.get_is_weighted() &&
IsWeighable(node.ret_type)) )
336 const auto& node = t.node->data;
339 if (node.weight_is_fixed)
343 || (node.get_is_weighted() &&
IsWeighable(node.ret_type)) )
368 auto& node = i.node->data;
371 if (node.weight_is_fixed)
375 || (node.get_is_weighted() &&
IsWeighable(node.node_type)) )
393 void lock_nodes(
int end_depth=0,
bool keep_leaves_unlocked=
true,
bool keep_current_weights=
false)
410 auto tree_iter =
Tree.begin();
412 std::for_each(
Tree.begin(),
Tree.end(),
414 auto d = Tree.depth(tree_iter);
415 std::advance(tree_iter, 1);
418 if (n.get_is_weighted() && d<end_depth)
419 n.weight_is_fixed = keep_current_weights;
421 n.weight_is_fixed = false;
425 if (
IsLeaf(n.node_type)) {
426 if (d<end_depth && !keep_leaves_unlocked)
427 n.node_is_fixed = true;
431 n.node_is_fixed = false;
439 n.node_is_fixed =
true;
441 n.node_is_fixed =
false;
448 if (keep_leaves_unlocked)
450 n.set_keep_split_feature(false);
455 n.set_keep_split_feature(d+1<end_depth);
474 string get_model(
string fmt=
"compact",
bool pretty=
false)
const
476 auto head =
Tree.begin();
478 return head.node->get_tree_model(pretty);
481 return head.node->get_model(pretty);
494 string out =
"digraph G {\n";
495 if (! extras.empty())
496 out += fmt::format(
"{}\n", extras);
498 auto get_id = [](
const auto& n){
500 return n->data.get_name(
false);
502 return fmt::format(
"{}",fmt::ptr(n)).substr(2);
505 std::map<string, unsigned int> node_count;
507 for (
Iter iter =
Tree.begin(); iter!=
Tree.end(); iter++)
509 const auto& parent = iter.node;
512 string parent_id = get_id(parent);
523 if (i==0 && parent->data.get_is_weighted())
526 string font_color =
"";
527 if (parent->data.weight_is_fixed) {
528 font_color =
", fontcolor=lightcoral";
531 out +=
"y [shape=box];\n";
532 out += fmt::format(
"y -> \"{}\" [label=\"{:.2f}\"{}];\n",
542 string node_label = parent->data.get_name(is_constant);
545 std::string feature = parent->data.get_feature();
546 std::string threshold = fmt::format(
"{:.2f}", parent->data.W);
549 if (parent->data.keep_split_feature)
551 if (parent->data.weight_is_fixed)
554 node_label = fmt::format(
"{} >= {}?", feature, threshold);
557 node_label = fmt::format(
"Add");
560 string node_style = parent->data.get_prob_change() >0.0 ?
"" :
", style=filled, fillcolor=lightcoral";
561 out += fmt::format(
"\"{}\" [label=\"{}\"{}];\n", parent_id, node_label, node_style);
564 auto kid = iter.node->first_child;
565 for (
int j = 0; j < iter.number_of_children(); ++j)
567 string edge_label=
"";
568 string head_label=
"";
569 string tail_label=
"";
570 bool use_head_tail_labels =
false;
572 string kid_id = get_id(kid);
576 if (kid->data.get_is_weighted()
580 edge_label = fmt::format(
"{:.2f}",kid->data.W);
584 use_head_tail_labels=
true;
586 tail_label = fmt::format(
">= {:.2f}",parent->data.W);
592 head_label=edge_label;
595 use_head_tail_labels=
true;
602 head_label = edge_label;
606 string font_color =
"";
607 if (kid->data.weight_is_fixed) {
608 font_color =
", fontcolor=lightcoral";
611 if (use_head_tail_labels){
612 out += fmt::format(
"\"{}\" -> \"{}\" [headlabel=\"{}\",taillabel=\"{}\"{}];\n",
621 out += fmt::format(
"\"{}\" -> \"{}\" [label=\"{}\"{}];\n",
628 kid = kid->next_sibling;
634 out += fmt::format(
"\"{}\" -> \"{}\" [label=\"\"];\n",
640 out += fmt::format(
"\"{}\" [label=\"{:.2f}\"{}];\n",
650 out +=
"label=\"^ split feature fixed, * split threshold fixed\";\nlabelloc=bottom;\nfontsize=10;";
659 vector<Node> linear_program;
661 linear_program.push_back(i.node->data);
662 return linear_program;
673template<ProgramType PType>
680 WO.update((*
this), d);
687template<ProgramType PType>
693template<ProgramType PType>
696 j.at(
"Tree").get_to(p.
Tree);
holds variable type data.
#define HANDLE_ERROR_THROW(err)
< nsga2 selection operator for getting the front
auto Isnt(DataType dt) -> bool
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
auto IsWeighable() noexcept -> bool
void from_json(const json &j, Fitness &f)
auto IsLeaf(NodeType nt) noexcept -> bool
auto Is(NodeType nt) -> bool
tree< Node >::pre_order_iterator Iter
tree< Node >::post_order_iterator PostIter
void to_json(json &j, const Fitness &f)
Eigen::Array< int, Eigen::Dynamic, 1 > ArrayXi
An individual program, a.k.a. model.
Program(const std::reference_wrapper< SearchSpace > s, const tree< Node > t)
Program< PType > & replace_program(const json &j)
Replace the current program from a JSON representation, invalidating fitness.
TreeType predict_proba(const Dataset &d)
void lock_nodes(int end_depth=0, bool keep_leaves_unlocked=true, bool keep_current_weights=false)
Iterates over the program, locking the nodes until it reaches a certain depth. If the node is a Split...
void update_weights(const Dataset &d)
Updates the program's weights using non-linear least squares.
typename std::conditional_t< PType==PT::Regressor, ArrayXf, std::conditional_t< PType==PT::BinaryClassifier, ArrayXb, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXi, std::conditional_t< PType==PT::Representer, ArrayXXf, ArrayXf > > > > RetType
the return type of the tree when calling :func:predict.
std::optional< std::reference_wrapper< SearchSpace > > SSref
int linear_complexity() const
count the linear complexity of the program.
RetType predict(const Ref< const ArrayXXf > &X)
Convenience function to call predict directly from X data.
void set_weights(const ArrayXf &weights)
Set the weights in the tree from an array of weights.
TreeType predict(const Dataset &d)
the standard predict function. Returns the output of the Tree directly.
static constexpr PT program_type
R predict_with_weights(const Dataset &d, const W **weights)
std::conditional_t< PType==PT::BinaryClassifier, ArrayXf, std::conditional_t< PType==PT::MulticlassClassifier, ArrayXXf, RetType > > TreeType
the type of output from the tree object
int complexity() const
count the (recursive) complexity of the program.
Program< PType > & replace_program(const Program< PType > &new_program)
Replace the current program with a new program, invalidating fitness.
int size_at(Iter &top, bool include_weight=true) const
count the size of a given subtree, optionally including the weights in weighted nodes....
ArrayXf get_weights()
Get the weights of the tree as an array.
Program< PType > & fit(const Dataset &d)
int depth() const
count the tree depth of the program. The depth is not influenced by weighted nodes.
auto predict_with_weights(const Dataset &d, const ArrayXf &weights)
int get_n_weights() const
returns the number of weights in the program.
int depth_at(Iter &top) const
count the depth of a given subtree. The depth is not influenced by weighted nodes....
vector< Node > linearize() const
turns program tree into a linear program.
ArrayXi predict(const Dataset &d)
Specialized predict function for multiclass classification.
void set_search_space(const std::reference_wrapper< SearchSpace > s)
ArrayXb predict(const Dataset &d)
Specialized predict function for binary classification.
string get_dot_model(string extras="") const
string get_model(string fmt="compact", bool pretty=false) const
Get the model as a string.
TreeType predict_proba(const Ref< const ArrayXXf > &X)
Predict probabilities from X.
int depth_to_reach(Iter &top) const
count the depth until reaching the given subtree. The depth is not influenced by weighted nodes....
Program< PType > & fit(const Ref< const ArrayXXf > &X, const Ref< const ArrayXf > &y)
Convenience function to call fit directly from X,y data.
int size(bool include_weight=true) const
count the tree size of the program, including the weights in weighted nodes.
Holds a search space, consisting of operations and terminals and functions, and methods to sample tha...