Feat C++ API
A feature engineering automation tool
MyCARTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 
32 #ifndef _MYCARTREE_H__
33 #define _MYCARTREE_H__
34 
35 //#include <shogun/lib/config.h>
36 
37 #include <shogun/multiclass/tree/TreeMachine.h>
38 #include <shogun/ensemble/MeanRule.h>
39 #include <shogun/features/DenseFeatures.h>
40 #include "MyCARTreeNodeData.h"
41 
42 using std::vector;
43 namespace shogun
44 {
45 
46 
     /**
      * @brief Classification And Regression Trees (CART), after Breiman et al.
      *
      * A CTreeMachine whose nodes carry MyCARTreeNodeData. The public interface
      * covers binary, multiclass and regression prediction (apply_binary,
      * apply_multiclass, apply_regression), optional weights, per-feature type
      * flags, and pruning either against a held-out dataset
      * (prune_using_test_dataset) or by cross-validation (set_cv_pruning /
      * set_num_folds).
      *
      * NOTE(review): this is a Doxygen source listing with the original doc
      * comments stripped. Besides those comment lines, the gaps at source lines
      * 465-476 and 489-494 hide content not shown here; the definition index
      * places m_min_node_size at MyCARTree.h:490, so at least that member is
      * declared in the omitted range.
      */
47  class CMyCARTree : public CTreeMachine<MyCARTreeNodeData>
48  {
84  public:
     /** Default constructor. */
87  CMyCARTree();
88 
     /** Constructor.
      * @param attribute_types one bool flag per feature (same shape as
      *        set_feature_types / m_nominal; presumably true = nominal --
      *        confirm in MyCARTree.cc)
      * @param prob_type problem type; PT_MULTICLASS by default
      */
93  CMyCARTree(SGVector<bool> attribute_types, EProblemType prob_type=PT_MULTICLASS);
94 
     /** Constructor that also configures cross-validation pruning.
      * @param attribute_types one bool flag per feature (presumably
      *        true = nominal -- confirm in MyCARTree.cc)
      * @param prob_type problem type
      * @param num_folds number of cross-validation folds (presumably the same
      *        setting as set_num_folds)
      * @param cv_prune whether to prune by cross-validation (presumably the
      *        same setting as set_cv_pruning)
      */
101  CMyCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune);
102 
     /** Destructor. */
104  virtual ~CMyCARTree();
105 
     /** Set the training labels.
      * @param lab labels (presumably validated with is_label_valid)
      */
109  virtual void set_labels(CLabels* lab);
110 
     /** @return the class name string (literal defined in MyCARTree.cc) */
114  virtual const char* get_name() const;
115 
     /** @return the problem type this machine is configured for
      *          (presumably m_mode)
      */
119  virtual EProblemType get_machine_problem_type() const;
120 
     /** Set the problem type.
      * @param mode new problem type (e.g. binary, multiclass or regression,
      *        matching the apply_* overloads; presumably stored in m_mode)
      */
124  void set_machine_problem_type(EProblemType mode);
125 
     /** Check whether labels can be used with this machine.
      * @param lab labels to check
      * @return true if the labels are acceptable (presumably checked against
      *         the configured problem type)
      */
130  virtual bool is_label_valid(CLabels* lab) const;
131 
     /** Predict binary labels.
      * @param data features to classify; NULL by default (presumably meaning
      *        the features already attached to the machine)
      * @return predicted binary labels
      */
136  virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
137 
     /** Predict multiclass labels.
      * @param data features to classify; NULL by default (presumably meaning
      *        the features already attached to the machine)
      * @return predicted multiclass labels
      */
142  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
143 
     /** Predict regression outputs.
      * @param data features to predict; NULL by default (presumably meaning
      *        the features already attached to the machine)
      * @return predicted regression labels
      */
148  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
149 
     /** Prune the trained tree using a separate dataset.
      * @param feats features of the pruning dataset
      * @param gnd_truth ground-truth labels for feats
      * @param weights optional weights (presumably one per vector of feats);
      *        empty vector by default
      */
156  void prune_using_test_dataset(CDenseFeatures<float64_t>* feats, CLabels* gnd_truth, SGVector<float64_t> weights=SGVector<float64_t>());
157 
     /** Set weights (presumably one per training vector; stored in m_weights).
      * @param w weight vector
      */
161  void set_weights(SGVector<float64_t> w);
162 
     /** @return the weights set via set_weights */
166  SGVector<float64_t> get_weights() const;
167 
     /** Clear previously set weights. */
169  void clear_weights();
170 
     /** Set per-feature type flags (presumably stored in m_nominal).
      * @param ft one bool per feature
      */
174  void set_feature_types(SGVector<bool> ft);
175 
     /** @return the per-feature type flags */
179  SGVector<bool> get_feature_types() const;
180 
     /** Clear previously set feature types. */
182  void clear_feature_types();
183 
     /** @return number of cross-validation folds (presumably m_folds) */
188  int32_t get_num_folds() const;
189 
     /** @param folds number of cross-validation folds */
194  void set_num_folds(int32_t folds);
195 
     /** @return maximum tree depth (presumably m_max_depth) */
200  int32_t get_max_depth() const;
201 
     /** @param depth maximum tree depth */
206  void set_max_depth(int32_t depth);
207 
     /** @return minimum node size */
212  int32_t get_min_node_size() const;
213 
     /** @param nsize minimum node size (presumably the fewest vectors a node
      *        needs to be split further -- confirm in MyCARTree.cc)
      */
218  void set_min_node_size(int32_t nsize);
219 
     /** Enable or disable pruning by cross-validation.
      * @param cv_pruning true to enable
      */
224  void set_cv_pruning(bool cv_pruning);
225 
     /** @return the label epsilon
      * NOTE(review): unlike the other getters this is not const; making it
      * const requires changing the definition at MyCARTree.cc:50 as well.
      */
230  float64_t get_label_epsilon();
231 
     /** @param epsilon label epsilon (presumably a tolerance used when
      *        comparing label values -- confirm in MyCARTree.cc)
      */
236  void set_label_epsilon(float64_t epsilon);
237 
     /** Pre-sort feature values.
      * @param data features to sort
      * @param sorted_feats [out, presumably] sorted feature values
      * @param sorted_indices [out, presumably] vector indices in the same
      *        sorted order
      */
238  void pre_sort_features(CFeatures* data, SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
239 
     /** Supply precomputed sorted features (e.g. from pre_sort_features),
      * presumably stored in m_sorted_features / m_sorted_indices.
      * NOTE(review): both parameters are non-const references although they
      * look like inputs; consider const& if the definition does not modify them.
      */
240  void set_sorted_features(SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
241 
     /** Compute feature importances of the trained tree.
      * @return one value per feature (presumably accumulated by get_importance)
      */
244  std::vector<double> feature_importances();
245 
     /** @return certainty values (m_certainty; presumably filled when
      *          apply_from_current_node runs with set_certainty=true)
      */
249  SGVector<float64_t> get_certainty_vector() const;
250 
     /** NOTE(review): purpose is not evident from the declaration; see the
      * definition at MyCARTree.cc:1615.
      * @param labels labels
      * @param data features; NULL by default
      */
253  void set_probabilities(CLabels* labels, CFeatures* data=NULL);
254  protected:
     /** Train the tree.
      * @param data training features; NULL by default (presumably meaning the
      *        features already attached to the machine)
      * @return whether training succeeded (presumably)
      */
259  virtual bool train_machine(CFeatures* data=NULL);
260 
     /** Build the (sub)tree for the given data; the level parameter suggests
      * one recursion level per call -- confirm in MyCARTree.cc.
      * @param data training features
      * @param weights weights (presumably one per vector)
      * @param labels training labels
      * @param level current depth in the tree
      * @return root node of the subtree that was built
      */
269  virtual CBinaryTreeMachineNode<MyCARTreeNodeData>* CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
270 
     /** Extract distinct label values.
      * @param labels_vec label values
      * @param n_ulabels [out, presumably] number of distinct labels
      * @return the distinct label values
      */
277  SGVector<float64_t> get_unique_labels(SGVector<float64_t> labels_vec, int32_t &n_ulabels);
278 
     /** Choose the attribute to split on.
      * @param mat feature matrix
      * @param weights weights (presumably one per vector)
      * @param labels labels
      * @param left [out, presumably] split values sent to the left child
      * @param right [out, presumably] split values sent to the right child
      * @param is_left_final [out, presumably] per-vector flag: goes left
      * @param num_missing [out, presumably] vectors missing the chosen attribute
      * @param count_left [out, presumably] number of vectors sent left
      * @param count_right [out, presumably] number of vectors sent right
      * @param IG [out, presumably] gain of the chosen split
      * @param subset_size 0 by default; presumably the number of attributes
      *        sampled as candidates (0 = all) -- confirm
      * @param active_indices empty by default; presumably restricts the
      *        candidate attributes -- confirm
      * @return index of the chosen attribute
      * NOTE(review): default arguments on a virtual function are bound by the
      * static type at the call site, so overrides must repeat them exactly.
      * The active_indices default is an SGVector<index_t> bound to an
      * SGVector<int32_t> parameter, presumably valid because index_t is
      * int32_t in this build.
      */
293  virtual int32_t compute_best_attribute(const SGMatrix<float64_t>& mat, const SGVector<float64_t>& weights, CLabels* labels,
294  SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing,
295  int32_t &count_left, int32_t &count_right, float64_t& IG, int32_t subset_size=0, const SGVector<int32_t>& active_indices=SGVector<index_t>());
296 
297 
     /** Compute a left/right assignment from a surrogate split for attribute
      * attr (presumably used for vectors whose split value is missing).
      * @param data feature matrix
      * @param weights weights (presumably one per vector)
      * @param nm_left presumably left/right flags of the non-missing vectors
      * @param attr attribute of the primary split
      * @return per-vector left/right flags
      */
306  SGVector<bool> surrogate_split(SGMatrix<float64_t> data, SGVector<float64_t> weights, SGVector<bool> nm_left, int32_t attr);
307 
308 
     /** Surrogate-split helper for a continuous candidate attribute; same
      * parameter list as handle_missing_vecs_for_nominal_surrogate.
      * missing_vecs, association_index and intersect_vecs are dynamic arrays
      * the helper presumably reads and updates in place -- confirm.
      */
321  void handle_missing_vecs_for_continuous_surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
322  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
323  SGVector<float64_t> weights, float64_t p, int32_t attr);
324 
     /** Surrogate-split helper for a nominal candidate attribute; same
      * parameter list as handle_missing_vecs_for_continuous_surrogate.
      */
337  void handle_missing_vecs_for_nominal_surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
338  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
339  SGVector<float64_t> weights, float64_t p, int32_t attr);
340 
     /** Split gain from left/right/total weight vectors plus label values
      * (the overload taking labels; presumably used for regression).
      */
349  float64_t gain(SGVector<float64_t> wleft, SGVector<float64_t> wright, SGVector<float64_t> wtotal, SGVector<float64_t> labels);
350 
     /** Split gain from left/right/total weighted class counts
      * (presumably the classification overload).
      */
358  float64_t gain(const SGVector<float64_t>& wleft, const SGVector<float64_t>& wright, const SGVector<float64_t>& wtotal);
359 
     /** Gini impurity index.
      * @param weighted_lab_classes weighted count per class
      * @param total_weight [out, presumably] total weight
      * @return the impurity value
      */
366  float64_t gini_impurity_index(const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight);
367 
     /** Least-squares deviation of labels.
      * @param labels label values
      * @param weights weights (presumably one per label)
      * @param total_weight [out, presumably] total weight
      * @return the deviation
      */
375  float64_t least_squares_deviation(const SGVector<float64_t>& labels, const SGVector<float64_t>& weights, float64_t &total_weight);
376 
     /** Predict with the subtree rooted at current.
      * @param feats features to predict
      * @param current root of the subtree to use
      * @param set_certainty false by default; presumably when true m_certainty
      *        is filled as a side effect
      * @return predicted labels (concrete type presumably depends on m_mode)
      */
383  CLabels* apply_from_current_node(CDenseFeatures<float64_t>* feats, bnode_t* current,
384  bool set_certainty=false);
385 
     /** Prune the tree by cross-validation.
      * @param data training features
      * @param folds number of folds
      */
391  void prune_by_cross_validation(CDenseFeatures<float64_t>* data, int32_t folds);
392 
     /** Weighted prediction error.
      * @param labels predicted labels
      * @param reference reference (ground-truth) labels
      * @param weights weights (presumably one per label)
      * @return the error value
      */
402  float64_t compute_error(CLabels* labels, CLabels* reference, SGVector<float64_t> weights);
403 
     /** Presumably produces the sequence of pruned trees derived from tree.
      * NOTE(review): ownership of the returned array is not visible here --
      * confirm at MyCARTree.cc:1420.
      * @param tree tree to prune
      * @return array of trees
      */
409  CDynamicObjectArray* prune_tree(CTreeMachine<MyCARTreeNodeData>* tree);
410 
     /** Find the weakest-link alpha in the subtree rooted at node. The names
      * here (weakest link, alpha, T1) follow Breiman's cost-complexity
      * pruning; presumably that procedure -- confirm in MyCARTree.cc.
      */
416  float64_t find_weakest_alpha(bnode_t* node);
417 
     /** Cut the weakest link(s) for the given alpha in the subtree rooted at
      * node (presumably the cost-complexity pruning step).
      */
423  void cut_weakest_link(bnode_t* node, float64_t alpha);
424 
     /** Form the initial tree T1 from node (cost-complexity pruning naming;
      * confirm semantics in MyCARTree.cc).
      */
429  void form_t1(bnode_t* node);
430 
     /** Initialise member variables (presumably shared by the constructors). */
432  void init();
433 
     /** Accumulate feature importances over the subtree rooted at node.
      * @param importances [in/out, presumably] per-feature importances
      * NOTE(review): the unqualified `vector` relies on the global-scope
      * `using std::vector;` declared before this namespace, which leaks into
      * every file that includes this header; prefer std::vector here and
      * dropping that using-declaration.
      */
436  void get_importance(bnode_t* node, vector<double>& importances);
437 
438 
439  public:
     /** Presumably the sentinel marking a missing feature value; the value is
      * defined in MyCARTree.cc.
      */
441  static const float64_t MISSING;
442 
     /** Presumably the minimum gain a split must reach to be accepted. */
444  static const float64_t MIN_SPLIT_GAIN;
445 
     /** Presumably the tolerance for floating-point equality comparisons. */
447  static const float64_t EQ_DELTA;
448 
449  protected:
     /** label epsilon (see set_label_epsilon) */
451  float64_t m_label_epsilon;
452 
     /** per-feature type flags (presumably true = nominal) */
454  SGVector<bool> m_nominal;
455 
     /** weights (see set_weights) */
457  SGVector<float64_t> m_weights;
458 
     /** sorted feature values (see set_sorted_features) */
460  SGMatrix<float64_t> m_sorted_features;
461 
     /** vector indices matching m_sorted_features */
463  SGMatrix<index_t> m_sorted_indices;
464 
467 
470 
473 
476 
     /** number of cross-validation folds */
478  int32_t m_folds;
479 
     /** problem type */
481  EProblemType m_mode;
482 
     /** alpha values (presumably from cost-complexity pruning).
      * NOTE(review): raw pointer; allocation and release are defined in
      * MyCARTree.cc -- confirm ownership there.
      */
484  CDynamicArray<float64_t>* m_alphas;
485 
     /** maximum tree depth */
487  int32_t m_max_depth;
488 
491 
     /** certainty values (see get_certainty_vector) */
495  SGVector<float64_t> m_certainty;
496  };
497 }
498 
499 #endif
int32_t get_min_node_size() const
Definition: MyCARTree.cc:264
CDynamicArray< float64_t > * m_alphas
Definition: MyCARTree.h:484
void set_max_depth(int32_t depth)
Definition: MyCARTree.cc:258
float64_t m_label_epsilon
Definition: MyCARTree.h:451
void set_weights(SGVector< float64_t > w)
Definition: MyCARTree.cc:208
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: MyCARTree.cc:956
std::vector< double > feature_importances()
Definition: MyCARTree.cc:1580
int32_t get_max_depth() const
Definition: MyCARTree.cc:253
SGVector< float64_t > get_weights() const
Definition: MyCARTree.cc:214
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, float64_t &IG, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
Definition: MyCARTree.cc:570
static const float64_t MISSING
Definition: MyCARTree.h:441
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: MyCARTree.cc:136
CMyCARTree()
This class implements the Classification And Regression Trees algorithm by Breiman et al. for decision...
Definition: MyCARTree.cc:56
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: MyCARTree.cc:546
void set_machine_problem_type(EProblemType mode)
Definition: MyCARTree.cc:100
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
Definition: MyCARTree.cc:115
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: MyCARTree.cc:1093
EProblemType m_mode
Definition: MyCARTree.h:481
static const float64_t MIN_SPLIT_GAIN
Definition: MyCARTree.h:444
virtual EProblemType get_machine_problem_type() const
Definition: MyCARTree.cc:42
void form_t1(bnode_t *node)
Definition: MyCARTree.cc:1521
float64_t get_label_epsilon()
Definition: MyCARTree.cc:50
int32_t m_max_depth
Definition: MyCARTree.h:487
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: MyCARTree.cc:1491
SGVector< float64_t > m_weights
Definition: MyCARTree.h:457
int32_t m_min_node_size
Definition: MyCARTree.h:490
void set_min_node_size(int32_t nsize)
Definition: MyCARTree.cc:269
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: MyCARTree.cc:150
int32_t get_num_folds() const
Definition: MyCARTree.cc:242
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
Definition: MyCARTree.cc:1119
virtual void set_labels(CLabels *lab)
Definition: MyCARTree.cc:86
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: MyCARTree.cc:1243
SGVector< bool > get_feature_types() const
Definition: MyCARTree.cc:231
virtual CBinaryTreeMachineNode< MyCARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: MyCARTree.cc:351
static const float64_t EQ_DELTA
Definition: MyCARTree.h:447
virtual bool is_label_valid(CLabels *lab) const
Definition: MyCARTree.cc:105
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
Definition: MyCARTree.cc:1129
SGVector< float64_t > m_certainty
Definition: MyCARTree.h:495
CDynamicObjectArray * prune_tree(CTreeMachine< MyCARTreeNodeData > *tree)
Definition: MyCARTree.cc:1420
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: MyCARTree.cc:1013
void set_cv_pruning(bool cv_pruning)
Definition: MyCARTree.cc:45
virtual bool train_machine(CFeatures *data=NULL)
Definition: MyCARTree.cc:281
SGMatrix< float64_t > m_sorted_features
Definition: MyCARTree.h:460
void set_probabilities(CLabels *labels, CFeatures *data=NULL)
Definition: MyCARTree.cc:1615
SGMatrix< index_t > m_sorted_indices
Definition: MyCARTree.h:463
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: MyCARTree.cc:1382
void get_importance(bnode_t *node, vector< double > &importances)
Definition: MyCARTree.cc:1597
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: MyCARTree.cc:881
float64_t find_weakest_alpha(bnode_t *node)
Definition: MyCARTree.cc:1470
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current, bool set_certainty=false)
Definition: MyCARTree.cc:1144
void set_feature_types(SGVector< bool > ft)
Definition: MyCARTree.cc:225
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: MyCARTree.cc:162
SGVector< float64_t > get_certainty_vector() const
Definition: MyCARTree.cc:1238
void clear_feature_types()
Definition: MyCARTree.cc:236
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: MyCARTree.cc:331
virtual ~CMyCARTree()
Definition: MyCARTree.cc:81
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: MyCARTree.cc:324
void set_label_epsilon(float64_t epsilon)
Definition: MyCARTree.cc:275
SGVector< bool > m_nominal
Definition: MyCARTree.h:454
virtual const char * get_name() const
Definition: MyCARTree.cc:40
void set_num_folds(int32_t folds)
Definition: MyCARTree.cc:247