Feat C++ API
A feature engineering automation tool
shogun::CMyCARTree Class Reference

#include <MyCARTree.h>

Inheritance diagram for shogun::CMyCARTree:
Collaboration diagram for shogun::CMyCARTree:

Public Member Functions

 CMyCARTree ()
 This class implements the Classification And Regression Trees algorithm by Breiman et al for decision tree learning. A CART tree is a binary decision tree that is constructed by splitting a node into two child nodes repeatedly, beginning with the root node that contains the whole dataset.

TREE GROWING PROCESS :
During the tree growing process, we recursively split a node into a left child and a right child so that the resulting nodes are "purest". We do this until any of the stopping criteria is met. To find the best split, we scan through all possible splits in all predictive attributes. The best split is the one that maximises some splitting criterion. For classification tasks, i.e. when the dependent attribute is categorical, the Gini index is used. For regression tasks, i.e. when the dependent variable is continuous, least squares deviation is used. The algorithm uses two stopping criteria: the node becomes completely "pure", i.e. all its members have an identical dependent variable, or all its members have identical predictive attributes (independent variables).

. More...
 
 CMyCARTree (SGVector< bool > attribute_types, EProblemType prob_type=PT_MULTICLASS)
 
 CMyCARTree (SGVector< bool > attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune)
 
virtual ~CMyCARTree ()
 
virtual void set_labels (CLabels *lab)
 
virtual const char * get_name () const
 
virtual EProblemType get_machine_problem_type () const
 
void set_machine_problem_type (EProblemType mode)
 
virtual bool is_label_valid (CLabels *lab) const
 
virtual CBinaryLabels * apply_binary (CFeatures *data=NULL)
 
virtual CMulticlassLabels * apply_multiclass (CFeatures *data=NULL)
 
virtual CRegressionLabels * apply_regression (CFeatures *data=NULL)
 
void prune_using_test_dataset (CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
 
void set_weights (SGVector< float64_t > w)
 
SGVector< float64_t > get_weights () const
 
void clear_weights ()
 
void set_feature_types (SGVector< bool > ft)
 
SGVector< bool > get_feature_types () const
 
void clear_feature_types ()
 
int32_t get_num_folds () const
 
void set_num_folds (int32_t folds)
 
int32_t get_max_depth () const
 
void set_max_depth (int32_t depth)
 
int32_t get_min_node_size () const
 
void set_min_node_size (int32_t nsize)
 
void set_cv_pruning (bool cv_pruning)
 
float64_t get_label_epsilon ()
 
void set_label_epsilon (float64_t epsilon)
 
void pre_sort_features (CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
 
void set_sorted_features (SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
 
std::vector< double > feature_importances ()
 
SGVector< float64_t > get_certainty_vector () const
 
void set_probabilities (CLabels *labels, CFeatures *data=NULL)
 

Static Public Attributes

static const float64_t MISSING = CMath::MAX_REAL_NUMBER
 
static const float64_t MIN_SPLIT_GAIN = 1e-7
 
static const float64_t EQ_DELTA = 1e-7
 

Protected Member Functions

virtual bool train_machine (CFeatures *data=NULL)
 
virtual CBinaryTreeMachineNode< MyCARTreeNodeData > * CARTtrain (CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
 
SGVector< float64_t > get_unique_labels (SGVector< float64_t > labels_vec, int32_t &n_ulabels)
 
virtual int32_t compute_best_attribute (const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, float64_t &IG, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
 
SGVector< bool > surrogate_split (SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
 
void handle_missing_vecs_for_continuous_surrogate (SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
 
void handle_missing_vecs_for_nominal_surrogate (SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
 
float64_t gain (SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
 
float64_t gain (const SGVector< float64_t > &wleft, const SGVector< float64_t > &wright, const SGVector< float64_t > &wtotal)
 
float64_t gini_impurity_index (const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
 
float64_t least_squares_deviation (const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
 
CLabels * apply_from_current_node (CDenseFeatures< float64_t > *feats, bnode_t *current, bool set_certainty=false)
 
void prune_by_cross_validation (CDenseFeatures< float64_t > *data, int32_t folds)
 
float64_t compute_error (CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
 
CDynamicObjectArray * prune_tree (CTreeMachine< MyCARTreeNodeData > *tree)
 
float64_t find_weakest_alpha (bnode_t *node)
 
void cut_weakest_link (bnode_t *node, float64_t alpha)
 
void form_t1 (bnode_t *node)
 
void init ()
 
void get_importance (bnode_t *node, vector< double > &importances)
 

Protected Attributes

float64_t m_label_epsilon
 
SGVector< bool > m_nominal
 
SGVector< float64_t > m_weights
 
SGMatrix< float64_t > m_sorted_features
 
SGMatrix< index_t > m_sorted_indices
 
bool m_pre_sort
 
bool m_types_set
 
bool m_weights_set
 
bool m_apply_cv_pruning
 
int32_t m_folds
 
EProblemType m_mode
 
CDynamicArray< float64_t > * m_alphas
 
int32_t m_max_depth
 
int32_t m_min_node_size
 
SGVector< float64_t > m_certainty
 

Detailed Description

Definition at line 47 of file MyCARTree.h.

Constructor & Destructor Documentation

◆ CMyCARTree() [1/3]

CMyCARTree::CMyCARTree ( )

This class implements the Classification And Regression Trees algorithm by Breiman et al for decision tree learning. A CART tree is a binary decision tree that is constructed by splitting a node into two child nodes repeatedly, beginning with the root node that contains the whole dataset.

TREE GROWING PROCESS :
During the tree growing process, we recursively split a node into a left child and a right child so that the resulting nodes are "purest". We do this until any of the stopping criteria is met. To find the best split, we scan through all possible splits in all predictive attributes. The best split is the one that maximises some splitting criterion. For classification tasks, i.e. when the dependent attribute is categorical, the Gini index is used. For regression tasks, i.e. when the dependent variable is continuous, least squares deviation is used. The algorithm uses two stopping criteria: the node becomes completely "pure", i.e. all its members have an identical dependent variable, or all its members have identical predictive attributes (independent variables).

.

COST-COMPLEXITY PRUNING :
The maximal tree, \(T_{max}\), grown during the tree growing process is bound to overfit. Hence pruning becomes necessary. Cost-complexity pruning yields a list of subtrees of varying depths using the complexity-normalized resubstitution error, \(R_\alpha(T)\). The resubstitution error \(R(T)\) is a measure of how well a decision tree fits the training data. This measure favours larger trees over smaller ones. However, the complexity-normalized resubstitution error adds a penalty for increased complexity and hence counters overfitting.
\(R_\alpha(T)=R(T)+\alpha \times (numleaves)\)
The best subtree among the list of subtrees can be chosen using cross validation or using best-fit in the test dataset.
cf. https://onlinecourses.science.psu.edu/stat557/node/93

HANDLING MISSING VALUES :
While choosing the best split at a node, missing attribute values are left out. But data vectors with missing values of the best attribute chosen are sent to the left child or the right child using a surrogate split. A surrogate split is one that imitates the best split as closely as possible. While choosing a surrogate split, all splits alternative to the best split are scanned and the degree of closeness between the two is measured using a metric called the predictive measure of association, \(\lambda_{i,j}\).
\(\lambda_{i,j} = \frac{\min(P_L,P_R)-(1-P_{L_iL_j}-P_{R_iR_j})}{\min(P_L,P_R)}\)
where \(P_L\) and \(P_R\) are the node probabilities for the optimal split of node i into left and right nodes respectively, \(P_{L_iL_j}\) ( \(P_{R_iR_j}\) resp.) is the probability that both (optimal) node i and (surrogate) node j send an observation to the Left (Right resp.).
We use best surrogate split, 2nd best surrogate split and so on until all data points with missing attributes in a node have been sent to left/right child. If all possible surrogate splits are used up but some data points are still to be assigned left/right child, majority rule is used, ie. the data points are assigned the child where majority of data points have gone from the node.
cf. http://pic.dhe.ibm.com/infocenter/spssstat/v20r0m0/index.jsp?topic=%2Fcom.ibm.spss.statistics.help%2Falg_tree-cart.htm

default constructor

Definition at line 56 of file MyCARTree.cc.

◆ CMyCARTree() [2/3]

CMyCARTree::CMyCARTree ( SGVector< bool >  attribute_types,
EProblemType  prob_type = PT_MULTICLASS 
)

constructor

Parameters
attribute_typestype of each predictive attribute (true for nominal, false for ordinal/continuous)
prob_typemachine problem type - PT_MULTICLASS or PT_REGRESSION

Definition at line 62 of file MyCARTree.cc.

◆ CMyCARTree() [3/3]

CMyCARTree::CMyCARTree ( SGVector< bool >  attribute_types,
EProblemType  prob_type,
int32_t  num_folds,
bool  cv_prune 
)

constructor - to be used while using cross-validation pruning

Parameters
attribute_typestype of each predictive attribute (true for nominal, false for ordinal/continuous)
prob_typemachine problem type - PT_MULTICLASS or PT_REGRESSION
num_foldsnumber of subsets used in cross-validation
cv_prune- whether to use cross-validation pruning

Definition at line 70 of file MyCARTree.cc.

◆ ~CMyCARTree()

CMyCARTree::~CMyCARTree ( )
virtual

destructor

Definition at line 81 of file MyCARTree.cc.

Member Function Documentation

◆ apply_binary()

CBinaryLabels * CMyCARTree::apply_binary ( CFeatures *  data = NULL)
virtual

WGL: classify data using Classification Tree

Parameters
datadata to be classified
Returns
BinaryLabels corresponding to labels of various test vectors

Definition at line 115 of file MyCARTree.cc.

◆ apply_from_current_node()

CLabels * CMyCARTree::apply_from_current_node ( CDenseFeatures< float64_t > *  feats,
bnode_t *  current,
bool  set_certainty = false 
)
protected

uses current subtree to classify/regress data

Parameters
featsdata to be classified/regressed
currentroot of current subtree
Returns
classification/regression labels of input data

Definition at line 1144 of file MyCARTree.cc.

◆ apply_multiclass()

CMulticlassLabels * CMyCARTree::apply_multiclass ( CFeatures *  data = NULL)
virtual

classify data using Classification Tree

Parameters
datadata to be classified
Returns
MulticlassLabels corresponding to labels of various test vectors

Definition at line 136 of file MyCARTree.cc.

◆ apply_regression()

CRegressionLabels * CMyCARTree::apply_regression ( CFeatures *  data = NULL)
virtual

Get regression labels using Regression Tree

Parameters
datadata whose regression output is needed
Returns
Regression output for various test vectors

Definition at line 150 of file MyCARTree.cc.

◆ CARTtrain()

CBinaryTreeMachineNode< MyCARTreeNodeData > * CMyCARTree::CARTtrain ( CFeatures *  data,
SGVector< float64_t >  weights,
CLabels *  labels,
int32_t  level 
)
protectedvirtual

CARTtrain - recursive CART training method

Parameters
datatraining data
weightsvector of weights of data points
labelslabels of data points
levelcurrent tree depth
Returns
pointer to the root of the CART subtree

Definition at line 351 of file MyCARTree.cc.

◆ clear_feature_types()

void CMyCARTree::clear_feature_types ( )

clear feature types of various features

Definition at line 236 of file MyCARTree.cc.

◆ clear_weights()

void CMyCARTree::clear_weights ( )

clear weights of data points

Definition at line 219 of file MyCARTree.cc.

◆ compute_best_attribute()

int32_t CMyCARTree::compute_best_attribute ( const SGMatrix< float64_t > &  mat,
const SGVector< float64_t > &  weights,
CLabels *  labels,
SGVector< float64_t > &  left,
SGVector< float64_t > &  right,
SGVector< bool > &  is_left_final,
int32_t &  num_missing,
int32_t &  count_left,
int32_t &  count_right,
float64_t &  IG,
int32_t  subset_size = 0,
const SGVector< int32_t > &  active_indices = SGVector<index_t>() 
)
protectedvirtual

computes best attribute for CARTtrain

Parameters
matdata matrix
weightsdata weights
labelsdata labels
leftstores feature values for left transition
rightstores feature values for right transition
is_left_finalstores which feature vectors go to the left child
num_missingnumber of missing attributes
count_leftstores number of feature values for left transition
count_rightstores number of feature values for right transition
IG(WGL): store impurity gain at node
Returns
index to the best attribute

Definition at line 570 of file MyCARTree.cc.

◆ compute_error()

float64_t CMyCARTree::compute_error ( CLabels *  labels,
CLabels *  reference,
SGVector< float64_t >  weights 
)
protected

computes error in classification/regression: for classification it evaluates weight_misclassified/total_weight; for regression it evaluates weighted sum of squared error/total_weight

Parameters
labelsthe labels whose error needs to be calculated
referenceactual labels against which test labels are compared
weightsweights associated with the labels
Returns
error evaluated

Definition at line 1382 of file MyCARTree.cc.

◆ cut_weakest_link()

void CMyCARTree::cut_weakest_link ( bnode_t *  node,
float64_t  alpha 
)
protected

recursively cuts weakest link(s) in a tree

Parameters
nodethe root of subtree whose weakest link it cuts
alphaalpha value corresponding to weakest link

Definition at line 1491 of file MyCARTree.cc.

◆ feature_importances()

vector< double > CMyCARTree::feature_importances ( )

WGL: return Gini importance scores for features

Importance is defined as the sum across all splits in the tree of the information criterion brought about by each feature.

Definition at line 1580 of file MyCARTree.cc.

◆ find_weakest_alpha()

float64_t CMyCARTree::find_weakest_alpha ( bnode_t *  node)
protected

recursively finds alpha corresponding to weakest link(s)

Parameters
nodethe root of subtree whose weakest link it finds
Returns
alpha value corresponding to the weakest link in subtree

Definition at line 1470 of file MyCARTree.cc.

◆ form_t1()

void CMyCARTree::form_t1 ( bnode_t *  node)
protected

recursively forms base case \(T_1\) tree from \(T_{max}\) during pruning

Parameters
nodethe root of current subtree

Definition at line 1521 of file MyCARTree.cc.

◆ gain() [1/2]

float64_t CMyCARTree::gain ( const SGVector< float64_t > &  wleft,
const SGVector< float64_t > &  wright,
const SGVector< float64_t > &  wtotal 
)
protected

returns gain in Gini impurity measure

Parameters
wleftleft child label distribution
wrightright child label distribution
wtotallabel distribution in current node
Returns
Gini gain achieved after splitting the node

Definition at line 1107 of file MyCARTree.cc.

◆ gain() [2/2]

float64_t CMyCARTree::gain ( SGVector< float64_t >  wleft,
SGVector< float64_t >  wright,
SGVector< float64_t >  wtotal,
SGVector< float64_t >  labels 
)
protected

returns gain in regression case

Parameters
wleftleft child weight distribution
wrightright child weights distribution
wtotalweight distribution in current node
labelsregression labels
Returns
least squares deviation gain achieved after splitting the node

Definition at line 1093 of file MyCARTree.cc.

◆ get_certainty_vector()

SGVector< float64_t > CMyCARTree::get_certainty_vector ( ) const

WGL: gets the probability estimate for each sample TODO: make this set_probabilities

Definition at line 1238 of file MyCARTree.cc.

◆ get_feature_types()

SGVector< bool > CMyCARTree::get_feature_types ( ) const

get feature types of various features

Returns
bool vector - true for nominal feature false for continuous feature type

Definition at line 231 of file MyCARTree.cc.

◆ get_importance()

void CMyCARTree::get_importance ( bnode_t *  node,
vector< double > &  importances 
)
protected

WGL: recursive function for getting node importance

Definition at line 1597 of file MyCARTree.cc.

◆ get_label_epsilon()

float64_t CMyCARTree::get_label_epsilon ( )

get label epsilon

Returns
equality range for regression labels

Definition at line 50 of file MyCARTree.cc.

◆ get_machine_problem_type()

EProblemType CMyCARTree::get_machine_problem_type ( ) const
virtual

get problem type - multiclass classification or regression

Returns
PT_MULTICLASS or PT_REGRESSION

Definition at line 42 of file MyCARTree.cc.

◆ get_max_depth()

int32_t CMyCARTree::get_max_depth ( ) const

get max allowed tree depth

Returns
max allowed tree depth

Definition at line 253 of file MyCARTree.cc.

◆ get_min_node_size()

int32_t CMyCARTree::get_min_node_size ( ) const

get min allowed node size

Returns
min allowed node size

Definition at line 264 of file MyCARTree.cc.

◆ get_name()

const char * CMyCARTree::get_name ( ) const
virtual

get name

Returns
class name CARTree

Reimplemented in shogun::CMyRandomCARTree.

Definition at line 40 of file MyCARTree.cc.

◆ get_num_folds()

int32_t CMyCARTree::get_num_folds ( ) const

get number of subsets used for cross validation

Returns
number of folds used in cross validation

Definition at line 242 of file MyCARTree.cc.

◆ get_unique_labels()

SGVector< float64_t > CMyCARTree::get_unique_labels ( SGVector< float64_t >  labels_vec,
int32_t &  n_ulabels 
)
protected

modify labels for compute_best_attribute

Parameters
labels_veclabels vector
n_ulabelsstores number of unique labels
Returns
unique labels

Definition at line 546 of file MyCARTree.cc.

◆ get_weights()

SGVector< float64_t > CMyCARTree::get_weights ( ) const

get weights of data points

Returns
vector of weights

Definition at line 214 of file MyCARTree.cc.

◆ gini_impurity_index()

float64_t CMyCARTree::gini_impurity_index ( const SGVector< float64_t > &  weighted_lab_classes,
float64_t &  total_weight 
)
protected

returns Gini impurity of a node

Parameters
weighted_lab_classesvector of weights associated with various labels
total_weightstores the total weight of all classes
Returns
Gini index of the node

Definition at line 1119 of file MyCARTree.cc.

◆ handle_missing_vecs_for_continuous_surrogate()

void CMyCARTree::handle_missing_vecs_for_continuous_surrogate ( SGMatrix< float64_t >  m,
CDynamicArray< int32_t > *  missing_vecs,
CDynamicArray< float64_t > *  association_index,
CDynamicArray< int32_t > *  intersect_vecs,
SGVector< bool >  is_left,
SGVector< float64_t >  weights,
float64_t  p,
int32_t  attr 
)
protected

handles missing values for a chosen continuous surrogate attribute

Parameters
mtraining data matrix
missing_vecscolumn indices of vectors with missing attribute in data matrix
association_indexstores the final lambda values used to address members of missing_vecs
intersect_vecscolumn indices of vectors with known values for the best attribute as well as the chosen surrogate
is_leftwhether a vector goes into left child
weightsweights of training data vectors
pmin(p_l,p_r) in the lambda formula
attrsurrogate attribute chosen for split
Returns
vector denoting whether a data point goes to left child for all data points including ones with missing attributes

Definition at line 956 of file MyCARTree.cc.

◆ handle_missing_vecs_for_nominal_surrogate()

void CMyCARTree::handle_missing_vecs_for_nominal_surrogate ( SGMatrix< float64_t >  m,
CDynamicArray< int32_t > *  missing_vecs,
CDynamicArray< float64_t > *  association_index,
CDynamicArray< int32_t > *  intersect_vecs,
SGVector< bool >  is_left,
SGVector< float64_t >  weights,
float64_t  p,
int32_t  attr 
)
protected

handles missing values for a chosen nominal surrogate attribute

Parameters
mtraining data matrix
missing_vecscolumn indices of vectors with missing attribute in data matrix
association_indexstores the final lambda values used to address members of missing_vecs
intersect_vecscolumn indices of vectors with known values for the best attribute as well as the chosen surrogate
is_leftwhether a vector goes into left child
weightsweights of training data vectors
pmin(p_l,p_r) in the lambda formula
attrsurrogate attribute chosen for split
Returns
vector denoting whether a data point goes to left child for all data points including ones with missing attributes

Definition at line 1013 of file MyCARTree.cc.

◆ init()

void CMyCARTree::init ( )
protected

initializes members of class

Definition at line 1547 of file MyCARTree.cc.

◆ is_label_valid()

bool CMyCARTree::is_label_valid ( CLabels *  lab) const
virtual

whether labels supplied are valid for current problem type

Parameters
lablabels supplied
Returns
true for valid labels, false for invalid labels

Definition at line 105 of file MyCARTree.cc.

◆ least_squares_deviation()

float64_t CMyCARTree::least_squares_deviation ( const SGVector< float64_t > &  labels,
const SGVector< float64_t > &  weights,
float64_t &  total_weight 
)
protected

returns least squares deviation

Parameters
labelsregression labels
weightsweights of regression data points
total_weightstores sum of weights in weights vector
Returns
least squares deviation of the data

Definition at line 1129 of file MyCARTree.cc.

◆ pre_sort_features()

void CMyCARTree::pre_sort_features ( CFeatures *  data,
SGMatrix< float64_t > &  sorted_feats,
SGMatrix< index_t > &  sorted_indices 
)

Definition at line 331 of file MyCARTree.cc.

◆ prune_by_cross_validation()

void CMyCARTree::prune_by_cross_validation ( CDenseFeatures< float64_t > *  data,
int32_t  folds 
)
protected

prune by cross validation

Parameters
datatraining data
foldsthe integer V for V-fold cross validation

Definition at line 1243 of file MyCARTree.cc.

◆ prune_tree()

CDynamicObjectArray * CMyCARTree::prune_tree ( CTreeMachine< MyCARTreeNodeData > *  tree)
protected

cost-complexity pruning

Parameters
treethe tree to be pruned
Returns
CDynamicObjectArray of pruned trees

Definition at line 1420 of file MyCARTree.cc.

◆ prune_using_test_dataset()

void CMyCARTree::prune_using_test_dataset ( CDenseFeatures< float64_t > *  feats,
CLabels *  gnd_truth,
SGVector< float64_t >  weights = SGVector<float64_t>() 
)

uses test dataset to choose best pruned subtree

Parameters
featstest data to be used
gnd_truthtest labels
weightsweights of data points

Definition at line 162 of file MyCARTree.cc.

◆ set_cv_pruning()

void CMyCARTree::set_cv_pruning ( bool  cv_pruning)

Set cross validation pruning parameter

Parameters
cv_pruningallow CV pruning

Definition at line 45 of file MyCARTree.cc.

◆ set_feature_types()

void CMyCARTree::set_feature_types ( SGVector< bool >  ft)

set feature types of various features

Parameters
ftbool vector true for nominal feature false for continuous feature type

Definition at line 225 of file MyCARTree.cc.

◆ set_label_epsilon()

void CMyCARTree::set_label_epsilon ( float64_t  epsilon)

set label epsilon

Parameters
epsilonequality range for regression labels

Definition at line 275 of file MyCARTree.cc.

◆ set_labels()

void CMyCARTree::set_labels ( CLabels *  lab)
virtual

set labels - automagically switch machine problem type based on type of labels supplied

Parameters
lablabels

Definition at line 86 of file MyCARTree.cc.

◆ set_machine_problem_type()

void CMyCARTree::set_machine_problem_type ( EProblemType  mode)

set problem type - multiclass classification or regression

Parameters
modeEProblemType PT_MULTICLASS or PT_REGRESSION

Definition at line 100 of file MyCARTree.cc.

◆ set_max_depth()

void CMyCARTree::set_max_depth ( int32_t  depth)

set max allowed tree depth

Parameters
depthmax allowed tree depth

Definition at line 258 of file MyCARTree.cc.

◆ set_min_node_size()

void CMyCARTree::set_min_node_size ( int32_t  nsize)

set min allowed node size

Parameters
nsizemin allowed node size

Definition at line 269 of file MyCARTree.cc.

◆ set_num_folds()

void CMyCARTree::set_num_folds ( int32_t  folds)

set number of subsets for cross validation

Parameters
foldsnumber of folds used in cross validation

Definition at line 247 of file MyCARTree.cc.

◆ set_probabilities()

void CMyCARTree::set_probabilities ( CLabels *  labels,
CFeatures *  data = NULL 
)

WGL: sets the probabilities on each label according to m_certainty

Definition at line 1615 of file MyCARTree.cc.

◆ set_sorted_features()

void CMyCARTree::set_sorted_features ( SGMatrix< float64_t > &  sorted_feats,
SGMatrix< index_t > &  sorted_indices 
)

Definition at line 324 of file MyCARTree.cc.

◆ set_weights()

void CMyCARTree::set_weights ( SGVector< float64_t >  w)

set weights of data points

Parameters
wvector of weights

Definition at line 208 of file MyCARTree.cc.

◆ surrogate_split()

SGVector< bool > CMyCARTree::surrogate_split ( SGMatrix< float64_t >  data,
SGVector< float64_t >  weights,
SGVector< bool >  nm_left,
int32_t  attr 
)
protected

handles missing values through surrogate splits

Parameters
datatraining data matrix
weightsvector of weights of data points
nm_leftwhether a data point is put into left child (available for only data points with non-missing attribute attr)
attrbest attribute chosen for split
Returns
vector denoting whether a data point goes to left child for all data points including ones with missing attributes

Definition at line 881 of file MyCARTree.cc.

◆ train_machine()

bool CMyCARTree::train_machine ( CFeatures *  data = NULL)
protectedvirtual

train machine - build CART from training data

Parameters
datatraining data
Returns
true

Definition at line 281 of file MyCARTree.cc.

Member Data Documentation

◆ EQ_DELTA

const float64_t CMyCARTree::EQ_DELTA = 1e-7
static

equality epsilon

Definition at line 447 of file MyCARTree.h.

◆ m_alphas

CDynamicArray<float64_t>* shogun::CMyCARTree::m_alphas
protected

stores \(\alpha_k\) values evaluated in cost-complexity pruning

Definition at line 484 of file MyCARTree.h.

◆ m_apply_cv_pruning

bool shogun::CMyCARTree::m_apply_cv_pruning
protected

flag indicating whether cross validation pruning has to be applied or not - false by default

Definition at line 475 of file MyCARTree.h.

◆ m_certainty

SGVector<float64_t> shogun::CMyCARTree::m_certainty
protected

percentage of certainty of labels predicted by decision tree ie. weight of elements belonging to predicted class in a node/ total weight in a node

Definition at line 495 of file MyCARTree.h.

◆ m_folds

int32_t shogun::CMyCARTree::m_folds
protected

V in V-fold cross validation - 5 by default

Definition at line 478 of file MyCARTree.h.

◆ m_label_epsilon

float64_t shogun::CMyCARTree::m_label_epsilon
protected

equality range for regression labels

Definition at line 451 of file MyCARTree.h.

◆ m_max_depth

int32_t shogun::CMyCARTree::m_max_depth
protected

max allowed depth of tree

Definition at line 487 of file MyCARTree.h.

◆ m_min_node_size

int32_t shogun::CMyCARTree::m_min_node_size
protected

minimum number of feature vectors required in a node

Definition at line 490 of file MyCARTree.h.

◆ m_mode

EProblemType shogun::CMyCARTree::m_mode
protected

Problem type : PT_MULTICLASS or PT_REGRESSION

Definition at line 481 of file MyCARTree.h.

◆ m_nominal

SGVector<bool> shogun::CMyCARTree::m_nominal
protected

vector depicting whether various feature dimensions are nominal or not

Definition at line 454 of file MyCARTree.h.

◆ m_pre_sort

bool shogun::CMyCARTree::m_pre_sort
protected

If pre sorted features are used in train

Definition at line 466 of file MyCARTree.h.

◆ m_sorted_features

SGMatrix<float64_t> shogun::CMyCARTree::m_sorted_features
protected

sorted transposed features

Definition at line 460 of file MyCARTree.h.

◆ m_sorted_indices

SGMatrix<index_t> shogun::CMyCARTree::m_sorted_indices
protected

sorted indices

Definition at line 463 of file MyCARTree.h.

◆ m_types_set

bool shogun::CMyCARTree::m_types_set
protected

flag storing whether the type of various feature dimensions are specified using is_nominal_feature

Definition at line 469 of file MyCARTree.h.

◆ m_weights

SGVector<float64_t> shogun::CMyCARTree::m_weights
protected

weights of samples in training set

Definition at line 457 of file MyCARTree.h.

◆ m_weights_set

bool shogun::CMyCARTree::m_weights_set
protected

flag storing whether weights of samples are specified using weights vector

Definition at line 472 of file MyCARTree.h.

◆ MIN_SPLIT_GAIN

const float64_t CMyCARTree::MIN_SPLIT_GAIN = 1e-7
static

min gain for splitting to be allowed

Definition at line 444 of file MyCARTree.h.

◆ MISSING

const float64_t CMyCARTree::MISSING = CMath::MAX_REAL_NUMBER
static

denotes that a feature in a vector is missing; MISSING = CMath::MAX_REAL_NUMBER

Definition at line 441 of file MyCARTree.h.


The documentation for this class was generated from the following files: