Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
tree_node.h
Go to the documentation of this file.
1#ifndef TREE_NODE_H
2#define TREE_NODE_H
3#include <tuple>
4#include <unordered_map>
5
6#include "../init.h"
7#include "../data/data.h"
8#include "node.h"
9#include "functions.h"
10#include "nodetype.h"
11#include "../../thirdparty/tree.hh"
12
13using std::string;
15using Brush::Node;
16
// Explicit specialization of the tree_node_ class template for Node
// (Brush::Node, brought into scope by the using-declaration above).
// The alias TreeNode declared right after this class refers to it.
//
// NOTE(review): this listing omits several original lines (the embedded
// numbering jumps 23->27, 28->30, 31->34, 35->40 and 51->53->55). The
// constructor initializer lists, the header of the second constructor and
// the data members are therefore not visible here. The reference index at
// the end of this page lists tree_node_(Node&&), the pointers parent,
// first_child, last_child, prev_sibling and next_sibling, and
// get_linear_complexity() -- confirm all of them against the real header.
21template<>
22class tree_node_<Node> { // size: 5*4=20 bytes (on 32 bit arch), can be reduced by 8.
23 public:
27
// Constructs a tree node from a Node value (initializer list not shown in this listing).
28 tree_node_(const Node& val)
30 {}
31
// Body of a second constructor whose header is not shown in this listing
// (presumably the Node&& overload named in the index -- TODO confirm).
34 {}
35
40
// Defined after the class: fetches a callable from dtable_fit keyed by
// data.node_type and data.sig_hash and invokes it with (d, *this).
41 template<typename T>
42 auto fit(const Dataset& d);
43
// Defined after the class: fetches a callable from dtable_predict keyed by
// data.node_type and data.sig_hash and invokes it with (d, *this, weights).
// weights defaults to nullptr.
44 template<typename T>
45 auto predict(const Dataset& d, const float** weights=nullptr);
46
// Defined after the class: same dispatch as above, but keyed by
// data.sig_dual_hash and taking a weights pointer of element type W.
47 template<typename T, typename W>
48 auto predict(const Dataset& d, const W** weights);
49
// The remaining members are declared only; their definitions are not part
// of this listing.
50 string get_model(bool pretty=false) const;
51 string get_tree_model(bool pretty=false, string offset="") const;
52
53 int get_complexity() const;
55 int get_size(bool include_weight=true) const;
56};
57using TreeNode = class tree_node_<Node>;
58
60// fit and predict: template member definitions (dispatched through dispatch_table.h)
61
62#include "dispatch_table.h"
63
64template<typename T>
65auto TreeNode::fit(const Dataset& d)
66{
67 auto F = dtable_fit.template Get<T>(data.node_type, data.sig_hash);
68 return F(d, (*this));
69};
70
71template<typename T>
72auto TreeNode::predict(const Dataset& d, const float** weights)
73{
74 auto F = dtable_predict.template Get<T>(data.node_type, data.sig_hash);
75 return F(d, (*this), weights);
76};
77
78template<typename T, typename W>
79auto TreeNode::predict(const Dataset& d, const W** weights)
80{
81 auto F = dtable_predict.template Get<T>(data.node_type, data.sig_dual_hash);
82 return F(d, (*this), weights);
83};
84
85// serialization functions
86void to_json(json &j, const tree<Node> &t);
87void from_json(const json &j, tree<Node> &t);
88
89// namespace node{
90
91// template<NodeType NT=0>
92// string get_model(const Node& data, const vector<string>& children)
93// {
94// string args = "";
95// for (int i = 0; i < children.size(); ++i){
96// args += children.at(i);
97// if (i < children.size()-1)
98// args += ",";
99// }
100
101// return fmt::format("{}({})", data.get_name(), args);
102
103// }
104
105// template<>
106// string get_model<NodeType::SplitBest>(const Node& data, const vector<string>& children)
107// {
108// return fmt::format("IF-THEN-ELSE({}>{:.3f},{},{})",
109// data.get_feature(),
110// data.W,
111// children.at(0),
112// children.at(1)
113// );
114
115// }
116
117// template<>
118// string get_model<NodeType::SplitOn>(const Node& data, const vector<string>& children)
119// {
120// return fmt::format("IF-THEN-ELSE({}>{:.3f},{},{})",
121// children.at(0),
122// data.W,
123// children.at(1),
124// children.at(2)
125// );
126
127// }
128// }
129#endif
holds variable type data.
Definition data.h:51
auto fit(const Dataset &d)
int get_linear_complexity() const
string get_model(bool pretty=false) const
auto predict(const Dataset &d, const float **weights=nullptr)
tree_node_(const Node &val)
Definition tree_node.h:28
string get_tree_model(bool pretty=false, string offset="") const
auto predict(const Dataset &d, const W **weights)
tree_node_< Node > * prev_sibling
Definition tree_node.h:38
tree_node_< Node > * parent
Definition tree_node.h:36
tree_node_(Node &&val)
Definition tree_node.h:32
int get_size(bool include_weight=true) const
int get_complexity() const
tree_node_< Node > * last_child
Definition tree_node.h:37
tree_node_< Node > * first_child
Definition tree_node.h:37
tree_node_< Node > * next_sibling
Definition tree_node.h:38
class tree_node_< Node > TreeNode
DispatchTable< false > dtable_predict
void from_json(const json &j, Fitness &f)
Definition fitness.cpp:25
void to_json(json &j, const Fitness &f)
Definition fitness.cpp:6
DispatchTable< true > dtable_fit
class holding the data for a node in a tree.
Definition node.h:84