Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
tree_node.h
Go to the documentation of this file.
1#ifndef TREE_NODE_H
2#define TREE_NODE_H
#include <tuple>
#include <unordered_map>
#include <utility>
5
6#include "../init.h"
7#include "../data/data.h"
8#include "node.h"
9#include "functions.h"
10#include "nodetype.h"
11#include "../../thirdparty/tree.hh"
12
13using std::string;
15using Brush::Node;
16
21template<>
22class tree_node_<Node> { // size: 5*4=20 bytes (on 32 bit arch), can be reduced by 8.
23 public:
25 : parent(0), first_child(0), last_child(0), prev_sibling(0), next_sibling(0)
26 {}
27
29 : parent(0), first_child(0), last_child(0), prev_sibling(0), next_sibling(0), data(val)
30 {}
31
33 : parent(0), first_child(0), last_child(0), prev_sibling(0), next_sibling(0), data(val)
34 {}
35
40
41 template<typename T>
42 auto fit(const Dataset& d);
43
44 template<typename T>
45 auto predict(const Dataset& d, const float** weights=nullptr);
46
47 template<typename T, typename W>
48 auto predict(const Dataset& d, const W** weights);
49
50 string get_model(bool pretty=false) const;
51 string get_tree_model(bool pretty=false, string offset="") const;
52
53 int get_complexity() const;
54 int get_size(bool include_weight=true) const;
55};
56using TreeNode = class tree_node_<Node>;
57
59// fit, eval, predict
60
61#include "dispatch_table.h"
62
63template<typename T>
64auto TreeNode::fit(const Dataset& d)
65{
66 auto F = dtable_fit.template Get<T>(data.node_type, data.sig_hash);
67 return F(d, (*this));
68};
69
70template<typename T>
71auto TreeNode::predict(const Dataset& d, const float** weights)
72{
73 auto F = dtable_predict.template Get<T>(data.node_type, data.sig_hash);
74 return F(d, (*this), weights);
75};
76
77template<typename T, typename W>
78auto TreeNode::predict(const Dataset& d, const W** weights)
79{
80 auto F = dtable_predict.template Get<T>(data.node_type, data.sig_dual_hash);
81 return F(d, (*this), weights);
82};
83
84// serialization functions
85void to_json(json &j, const tree<Node> &t);
86void from_json(const json &j, tree<Node> &t);
87
88// namespace node{
89
90// template<NodeType NT=0>
91// string get_model(const Node& data, const vector<string>& children)
92// {
93// string args = "";
94// for (int i = 0; i < children.size(); ++i){
95// args += children.at(i);
96// if (i < children.size()-1)
97// args += ",";
98// }
99
100// return fmt::format("{}({})", data.get_name(), args);
101
102// }
103
104// template<>
105// string get_model<NodeType::SplitBest>(const Node& data, const vector<string>& children)
106// {
107// return fmt::format("IF-THEN-ELSE({}>{:.3f},{},{})",
108// data.get_feature(),
109// data.W,
110// children.at(0),
111// children.at(1)
112// );
113
114// }
115
116// template<>
117// string get_model<NodeType::SplitOn>(const Node& data, const vector<string>& children)
118// {
119// return fmt::format("IF-THEN-ELSE({}>{:.3f},{},{})",
120// children.at(0),
121// data.W,
122// children.at(1),
123// children.at(2)
124// );
125
126// }
127// }
128#endif
void bind_engine(py::module &m, string name)
Holds variable-type data.
Definition data.h:51
auto fit(const Dataset &d)
string get_model(bool pretty=false) const
auto predict(const Dataset &d, const float **weights=nullptr)
tree_node_(const Node &val)
Definition tree_node.h:28
string get_tree_model(bool pretty=false, string offset="") const
auto predict(const Dataset &d, const W **weights)
tree_node_< Node > * parent
Definition tree_node.h:36
tree_node_(Node &&val)
Definition tree_node.h:32
int get_size(bool include_weight=true) const
int get_complexity() const
tree_node_< Node > * first_child
Definition tree_node.h:37
tree_node_< Node > * next_sibling
Definition tree_node.h:38
class tree_node_< Node > TreeNode
DispatchTable< false > dtable_predict
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:24
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
DispatchTable< true > dtable_fit
Class holding the data for a node in a tree.
Definition node.h:84