32 #ifndef _MYCARTREE_H__
33 #define _MYCARTREE_H__
37 #include <shogun/multiclass/tree/TreeMachine.h>
38 #include <shogun/ensemble/MeanRule.h>
39 #include <shogun/features/DenseFeatures.h>
47 class CMyCARTree :
public CTreeMachine<MyCARTreeNodeData>
93 CMyCARTree(SGVector<bool> attribute_types, EProblemType prob_type=PT_MULTICLASS);
101 CMyCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds,
bool cv_prune);
114 virtual const char*
get_name()
const;
136 virtual CBinaryLabels*
apply_binary(CFeatures* data=NULL);
156 void prune_using_test_dataset(CDenseFeatures<float64_t>* feats, CLabels* gnd_truth, SGVector<float64_t> weights=SGVector<float64_t>());
238 void pre_sort_features(CFeatures* data, SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
240 void set_sorted_features(SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
269 virtual CBinaryTreeMachineNode<MyCARTreeNodeData>*
CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
277 SGVector<float64_t>
get_unique_labels(SGVector<float64_t> labels_vec, int32_t &n_ulabels);
293 virtual int32_t
compute_best_attribute(
const SGMatrix<float64_t>& mat,
const SGVector<float64_t>& weights, CLabels* labels,
294 SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing,
295 int32_t &count_left, int32_t &count_right, float64_t& IG, int32_t subset_size=0,
const SGVector<int32_t>& active_indices=SGVector<index_t>());
306 SGVector<bool>
surrogate_split(SGMatrix<float64_t> data, SGVector<float64_t> weights, SGVector<bool> nm_left, int32_t attr);
322 CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
323 SGVector<float64_t> weights, float64_t p, int32_t attr);
338 CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
339 SGVector<float64_t> weights, float64_t p, int32_t attr);
349 float64_t
gain(SGVector<float64_t> wleft, SGVector<float64_t> wright, SGVector<float64_t> wtotal, SGVector<float64_t> labels);
358 float64_t
gain(
const SGVector<float64_t>& wleft,
const SGVector<float64_t>& wright,
const SGVector<float64_t>& wtotal);
366 float64_t
gini_impurity_index(
const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight);
375 float64_t
least_squares_deviation(
const SGVector<float64_t>& labels,
const SGVector<float64_t>& weights, float64_t &total_weight);
384 bool set_certainty=
false);
402 float64_t
compute_error(CLabels* labels, CLabels* reference, SGVector<float64_t> weights);
409 CDynamicObjectArray*
prune_tree(CTreeMachine<MyCARTreeNodeData>* tree);
int32_t get_min_node_size() const
CDynamicArray< float64_t > * m_alphas
void set_max_depth(int32_t depth)
float64_t m_label_epsilon
void set_weights(SGVector< float64_t > w)
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
std::vector< double > feature_importances()
int32_t get_max_depth() const
SGVector< float64_t > get_weights() const
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, float64_t &IG, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
static const float64_t MISSING
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
CMyCARTree()
This class implements the Classification And Regression Trees algorithm by Breiman et al for decision...
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
void set_machine_problem_type(EProblemType mode)
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
static const float64_t MIN_SPLIT_GAIN
virtual EProblemType get_machine_problem_type() const
void form_t1(bnode_t *node)
float64_t get_label_epsilon()
void cut_weakest_link(bnode_t *node, float64_t alpha)
SGVector< float64_t > m_weights
void set_min_node_size(int32_t nsize)
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
int32_t get_num_folds() const
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
virtual void set_labels(CLabels *lab)
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
SGVector< bool > get_feature_types() const
virtual CBinaryTreeMachineNode< MyCARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
static const float64_t EQ_DELTA
virtual bool is_label_valid(CLabels *lab) const
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
SGVector< float64_t > m_certainty
CDynamicObjectArray * prune_tree(CTreeMachine< MyCARTreeNodeData > *tree)
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
void set_cv_pruning(bool cv_pruning)
virtual bool train_machine(CFeatures *data=NULL)
SGMatrix< float64_t > m_sorted_features
void set_probabilities(CLabels *labels, CFeatures *data=NULL)
SGMatrix< index_t > m_sorted_indices
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
void get_importance(bnode_t *node, vector< double > &importances)
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
float64_t find_weakest_alpha(bnode_t *node)
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current, bool set_certainty=false)
void set_feature_types(SGVector< bool > ft)
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
SGVector< float64_t > get_certainty_vector() const
void clear_feature_types()
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
void set_label_epsilon(float64_t epsilon)
SGVector< bool > m_nominal
virtual const char * get_name() const
void set_num_folds(int32_t folds)