#include <shogun/base/some.h>
using namespace Eigen;
const char* CMyCARTree::get_name() const { return "MyCARTree"; }
EProblemType CMyCARTree::get_machine_problem_type() const { return m_mode; }
void CMyCARTree::set_cv_pruning(bool cv_pruning)
{
    m_apply_cv_pruning = cv_pruning;
}
float64_t CMyCARTree::get_label_epsilon() { return m_label_epsilon; }
const float64_t CMyCARTree::MISSING = CMath::MAX_REAL_NUMBER;
const float64_t CMyCARTree::EQ_DELTA = 1e-7;
const float64_t CMyCARTree::MIN_SPLIT_GAIN = 1e-7;
CMyCARTree::CMyCARTree()
if (lab->get_label_type() == LT_MULTICLASS)
{
    // ...
}
else if (lab->get_label_type() == LT_REGRESSION)
{
    // ...
}
else
    SG_ERROR("label type supplied is not supported\n")
if (m_mode == PT_MULTICLASS && lab->get_label_type() == LT_MULTICLASS)
{
    // ...
}
else if (m_mode == PT_REGRESSION && lab->get_label_type() == LT_REGRESSION)
{
    // ...
}
REQUIRE(data, "Data required for classification in apply_binary\n")

bnode_t* current = dynamic_cast<bnode_t*>(get_root());

REQUIRE(current, "Tree machine not yet trained.\n");

SGVector<double> tmp = dynamic_cast<CDenseLabels*>(ret)->get_labels();

CBinaryLabels* retbc = new CBinaryLabels(tmp, 0.5);
REQUIRE(data, "Data required for classification in apply_multiclass\n")

bnode_t* current = dynamic_cast<bnode_t*>(get_root());

REQUIRE(current, "Tree machine not yet trained.\n");

return dynamic_cast<CMulticlassLabels*>(ret);
REQUIRE(data, "Data required for regression in apply_regression\n")

bnode_t* current = dynamic_cast<bnode_t*>(get_root());

return dynamic_cast<CRegressionLabels*>(ret);
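// prune_using_test_dataset: test-set pruning. Build the full cost-complexity pruning
// sequence with prune_tree(), evaluate every candidate subtree on the held-out data
// (weights default to 1 when none are supplied), and keep the subtree with the smallest
// weighted error as the new root.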
weights = SGVector<float64_t>(feats->get_num_vectors());
weights.fill_vector(weights.vector, weights.vlen, 1);

CDynamicObjectArray* pruned_trees = prune_tree(this);

float64_t min_error = CMath::MAX_REAL_NUMBER;
for (int32_t i = 0; i < m_alphas->get_num_elements(); i++)
{
    CSGObject* element = pruned_trees->get_element(i);
    // ...
    if (element != NULL)
        root = dynamic_cast<bnode_t*>(element);
    else
        SG_ERROR("%d element is NULL\n", i);
    // ... evaluate this candidate subtree on the test data and keep the index with minimum error
}

CSGObject* element = pruned_trees->get_element(min_index);
// ...
if (element != NULL)
    root = dynamic_cast<bnode_t*>(element);
else
    SG_ERROR("%d element is NULL\n", min_index);

this->set_root(root);
// ...
SG_UNREF(pruned_trees);
REQUIRE(folds > 1, "Number of folds is expected to be greater than 1. Supplied value is %d\n", folds)

REQUIRE(depth > 0, "Max allowed tree depth should be greater than 0. Supplied value is %d\n", depth)

REQUIRE(nsize > 0, "Min allowed node size should be greater than 0. Supplied value is %d\n", nsize)

REQUIRE(ep >= 0, "Input epsilon value is expected to be greater than or equal to 0\n")
REQUIRE(data, "Data required for training\n")
REQUIRE(data->get_feature_class() == C_DENSE, "Dense data required for training\n")

int32_t num_features = (dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_features();
int32_t num_vectors = (dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_vectors();

REQUIRE(m_weights.vlen == num_vectors, "Length of weights vector (currently %d) should be same as"
        " number of vectors in data (presently %d)", m_weights.vlen, num_vectors)

m_weights = SGVector<float64_t>(num_vectors);

REQUIRE(m_nominal.vlen == num_features, "Length of m_nominal vector (currently %d) should "
        "be same as number of features in data (presently %d)", m_nominal.vlen, num_features)

SG_WARNING("Feature types are not specified. All features are considered as continuous in training")

CDenseFeatures<float64_t>* feats = dynamic_cast<CDenseFeatures<float64_t>*>(data);
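// pre_sort_features: sort every feature column once, up front. The feature matrix is
// transposed (via Eigen maps) so each column of sorted_feats holds one feature, and
// qsort_index records the sorting permutation in sorted_indices so that CARTtrain can
// reuse the sorted order at every node instead of re-sorting.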
SGMatrix<float64_t> mat = (dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
sorted_feats = SGMatrix<float64_t>(mat.num_cols, mat.num_rows);
sorted_indices = SGMatrix<index_t>(mat.num_cols, mat.num_rows);
for (int32_t i = 0; i < sorted_indices.num_cols; i++)
    for (int32_t j = 0; j < sorted_indices.num_rows; j++)
        sorted_indices(j, i) = j;

Map<MatrixXd> map_sorted_feats(sorted_feats.matrix, mat.num_cols, mat.num_rows);
Map<MatrixXd> map_data(mat.matrix, mat.num_rows, mat.num_cols);

map_sorted_feats = map_data.transpose();

#pragma omp parallel for
for (int32_t i = 0; i < sorted_feats.num_cols; i++)
    CMath::qsort_index(sorted_feats.get_column_vector(i), sorted_indices.get_column_vector(i), sorted_feats.num_rows);
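// CARTtrain: grows the tree recursively. Each call labels the current node (weighted mean
// of the labels for regression, weighted mode for classification), stops early when the
// node is pure, too small or too deep, otherwise asks compute_best_attribute for the best
// split, routes vectors with MISSING values via surrogate splits, and recurses on the
// left/right subsets.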
CBinaryTreeMachineNode<MyCARTreeNodeData>* CMyCARTree::CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level)

REQUIRE(labels, "labels have to be supplied\n");
REQUIRE(data, "data matrix has to be supplied\n");

bnode_t* node = new bnode_t();
SGVector<float64_t> labels_vec = (dynamic_cast<CDenseLabels*>(labels))->get_labels();
SGMatrix<float64_t> mat = (dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
int32_t num_feats = mat.num_rows;
int32_t num_vecs = mat.num_cols;

// PT_REGRESSION: the node label is the weighted mean of the labels
for (int32_t i = 0; i < labels_vec.vlen; i++)
    sum += labels_vec[i]*weights[i];
// ...
node->data.node_label = sum/tot;
node->data.total_weight = tot;
// PT_MULTICLASS: the node label is the weighted mode of the labels
SGVector<float64_t> lab = labels_vec.clone();
// ...
int32_t max = weights[0];
// ...
int32_t c = weights[0];
for (int32_t i = 1; i < lab.vlen; i++)
{
    if (lab[i] == lab[i-1])
    {
        // ...
    }
    // ...
}
// ...
node->data.node_label = lab[maxi];
// ...
node->data.total_weight = weights.sum(weights);
node->data.weight_minus_node = node->data.total_weight-max;
// ...
SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");

// stopping rules (maximum depth reached / node too small): make this node a leaf
node->data.num_leaves = 1;
node->data.weight_minus_branch = node->data.weight_minus_node;
// ...
node->data.num_leaves = 1;
node->data.weight_minus_branch = node->data.weight_minus_node;
SGVector<float64_t> left(num_feats);
SGVector<float64_t> right(num_feats);
SGVector<bool> left_final(num_vecs);
int32_t num_missing_final = 0;
int32_t best_attribute;

SGVector<index_t> indices(num_vecs);

// with pre-sorted features, the split search works on the active subset indices
CSubsetStack* subset_stack = data->get_subset_stack();
if (subset_stack->has_subsets())
    indices = (subset_stack->get_last_subset())->get_subset_idx();
else
    indices.range_fill();
SG_UNREF(subset_stack);
best_attribute = compute_best_attribute(m_sorted_features, weights, labels, left, right, left_final,
        num_missing_final, c_left, c_right, IG, 0, indices);
// ... otherwise the raw feature matrix is searched directly
best_attribute = compute_best_attribute(mat, weights, labels, left, right, left_final,
        num_missing_final, c_left, c_right, IG);

if (best_attribute == -1)
{
    // no usable split: make this node a leaf
    node->data.num_leaves = 1;
    node->data.weight_minus_branch = node->data.weight_minus_node;
    // ...
}
SGVector<float64_t> left_transit(c_left);
SGVector<float64_t> right_transit(c_right);
sg_memcpy(left_transit.vector, left.vector, c_left*sizeof(float64_t));
sg_memcpy(right_transit.vector, right.vector, c_right*sizeof(float64_t));

if (num_missing_final > 0)
{
    SGVector<bool> is_left_final(num_vecs-num_missing_final);
    // ...
    for (int32_t i = 0; i < num_vecs; i++)
    {
        // ...
        is_left_final[ilf++] = left_final[i];
    }
    // ... vectors with missing values get their direction from surrogate_split()
}

int32_t count_left = 0;
for (int32_t c = 0; c < num_vecs; c++)
    count_left = (left_final[c]) ? count_left+1 : count_left;

SGVector<index_t> subsetl(count_left);
SGVector<float64_t> weightsl(count_left);
SGVector<index_t> subsetr(num_vecs-count_left);
SGVector<float64_t> weightsr(num_vecs-count_left);

// split the vectors and their weights into the left and right partitions
for (int32_t c = 0; c < num_vecs; c++)
{
    // ...
    weightsl[l++] = weights[c];
    // ...
    weightsr[r++] = weights[c];
}
data->add_subset(subsetl);
labels->add_subset(subsetl);
bnode_t* left_child = CARTtrain(data, weightsl, labels, level+1);
data->remove_subset();
labels->remove_subset();

data->add_subset(subsetr);
labels->add_subset(subsetr);
bnode_t* right_child = CARTtrain(data, weightsr, labels, level+1);
data->remove_subset();
labels->remove_subset();

node->data.attribute_id = best_attribute;
// ...
node->left(left_child);
node->right(right_child);
left_child->data.transit_into_values = left_transit;
right_child->data.transit_into_values = right_transit;
node->data.num_leaves = left_child->data.num_leaves+right_child->data.num_leaves;
node->data.weight_minus_branch = left_child->data.weight_minus_branch+right_child->data.weight_minus_branch;
// get_unique_labels: collapse labels that differ by at most delta
// (delta = m_label_epsilon for regression, 0 otherwise)
if (m_mode == PT_REGRESSION)
{
    // ...
}

SGVector<float64_t> ulabels(labels_vec.vlen);
// ...
ulabels[0] = labels_vec[sidx[0]];
// ...
for (int32_t i = 1; i < sidx.vlen; i++)
{
    if (labels_vec[sidx[i]] <= labels_vec[sidx[start]]+delta)
    {
        // ...
    }
    // ...
    ulabels[n_ulabels] = labels_vec[sidx[i]];
    // ...
}
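// compute_best_attribute: exhaustive split search. Labels are first mapped to a compact
// set of unique values; for every candidate feature the values are taken in sorted order,
// nominal features enumerate all 2^c left/right assignments of their unique values while
// continuous features sweep thresholds between consecutive values, and each candidate is
// scored with gain() (Gini impurity for classification, least-squares deviation for
// regression). Returns the best attribute, or -1 when no usable split exists.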
int32_t CMyCARTree::compute_best_attribute(const SGMatrix<float64_t>& mat, const SGVector<float64_t>& weights, CLabels* labels,
        SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t& num_missing_final, int32_t& count_left,
        int32_t& count_right, float64_t& IG, int32_t subset_size, const SGVector<index_t>& active_indices)

SGVector<float64_t> labels_vec = (dynamic_cast<CDenseLabels*>(labels))->get_labels();
int32_t num_vecs = labels->get_num_labels();
// number of features: one per column of the pre-sorted matrix, one per row of the raw matrix
num_feats = mat.num_cols;
// ...
num_feats = mat.num_rows;

// ...
if (m_mode == PT_REGRESSION)
{
    // ...
}

SGVector<float64_t> total_wclasses(n_ulabels);
total_wclasses.zero();

// map every label to the index of its unique value and accumulate the per-class weights
SGVector<int32_t> simple_labels(num_vecs);
for (int32_t i = 0; i < num_vecs; i++)
{
    for (int32_t j = 0; j < n_ulabels; j++)
    {
        if (CMath::abs(labels_vec[i]-ulabels[j]) <= delta)
        {
            // ...
            total_wclasses[j] += weights[i];
            // ...
        }
    }
}
SGVector<index_t> idx(num_feats);
// ...
// when a random feature subset is requested (subset_size > 0), only that many attributes are searched
num_feats = subset_size;
// ...

int32_t best_attribute = -1;
float64_t best_threshold = 0;

SGVector<int64_t> indices_mask;
SGVector<int32_t> count_indices(mat.num_rows);
count_indices.zero();
SGVector<int32_t> dupes(num_vecs);
// ...
// map global vector indices to positions in the active subset and mark duplicates
indices_mask = SGVector<int64_t>(mat.num_rows);
indices_mask.set_const(-1);
for (int32_t j = 0; j < active_indices.size(); j++)
{
    if (indices_mask[active_indices[j]] >= 0)
        dupes[indices_mask[active_indices[j]]] = j;
    // ...
    indices_mask[active_indices[j]] = j;
    count_indices[active_indices[j]]++;
}
for (int32_t i = 0; i < num_feats; i++)
{
    SGVector<float64_t> feats(num_vecs);
    SGVector<index_t> sorted_args(num_vecs);
    SGVector<int32_t> temp_count_indices(count_indices.size());
    sg_memcpy(temp_count_indices.vector, count_indices.vector, sizeof(int32_t)*count_indices.size());

    // pre-sorted path: gather this feature's values in sorted order via m_sorted_indices
    SGVector<float64_t> temp_col(mat.get_column_vector(idx[i]), mat.num_rows, false);
    SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(idx[i]), mat.num_rows, false);

    for (int32_t j = 0; j < mat.num_rows; j++)
    {
        if (indices_mask[sorted_indices[j]] >= 0)
        {
            int32_t count_index = count_indices[sorted_indices[j]];
            // ...
            feats[count] = temp_col[j];
            sorted_args[count] = indices_mask[sorted_indices[j]];
            // ...
        }
    }

    // non-pre-sorted path: copy the feature values and sort them here
    for (int32_t j = 0; j < num_vecs; j++)
        feats[j] = mat(idx[i], j);
    // ...
    sorted_args.range_fill();
    CMath::qsort_index(feats.vector, sorted_args.vector, feats.size());

    int32_t n_nm_vecs = feats.vlen;
    // drop trailing MISSING values (they sort to the end) and their class weights
    while (feats[n_nm_vecs-1] == MISSING)
    {
        total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]] -= weights[sorted_args[n_nm_vecs-1]];
        // ...
    }

    // skip features that are (nearly) constant on this node
    if (feats[n_nm_vecs-1] <= feats[0]+EQ_DELTA)
    {
        // ...
    }
    // nominal feature: enumerate all 2^c assignments of its unique values to left/right
    SGVector<int32_t> simple_feats(num_vecs);
    simple_feats.fill_vector(simple_feats.vector, simple_feats.vlen, -1);

    // map the sorted feature values to consecutive integer ids
    for (int32_t j = 1; j < n_nm_vecs; j++)
    {
        if (feats[j] == feats[j-1])
        {
            // ...
        }
        else
            simple_feats[j] = (++c);
    }

    SGVector<float64_t> ufeats(c+1);
    // ...
    for (int32_t j = 1; j < n_nm_vecs; j++)
    {
        if (feats[j] == feats[j-1])
        {
            // ...
        }
        else
            ufeats[++u] = feats[j];
    }

    int32_t num_cases = CMath::pow(2, c);
    for (int32_t k = 1; k < num_cases; k++)
    {
        SGVector<float64_t> wleft(n_ulabels);
        SGVector<float64_t> wright(n_ulabels);
        // ...
        SGVector<bool> is_left(num_vecs);
        is_left.fill_vector(is_left.vector, is_left.vlen, false);
        // ...
        SGVector<bool> feats_left(c+1);

        // decode the k-th left/right assignment of the unique values
        for (int32_t p = 0; p < c+1; p++)
            feats_left[p] = ((k/CMath::pow(2,p))%(CMath::pow(2,p+1)) == 1);

        // route every vector and accumulate the per-class weights on each side
        for (int32_t j = 0; j < n_nm_vecs; j++)
        {
            is_left[sorted_args[j]] = feats_left[simple_feats[j]];
            if (is_left[sorted_args[j]])
                wleft[simple_labels[sorted_args[j]]] += weights[sorted_args[j]];
            else
                wright[simple_labels[sorted_args[j]]] += weights[sorted_args[j]];
        }

        // duplicated vectors (repeated active indices) follow their originals
        for (int32_t j = n_nm_vecs-1; j >= 0; j--)
        {
            // ...
            is_left[j] = is_left[dupes[j]];
        }

        // ...
        if (m_mode == PT_MULTICLASS)
            g = gain(wleft, wright, total_wclasses);
        else if (m_mode == PT_REGRESSION)
            g = gain(wleft, wright, total_wclasses, ulabels);
        else
            SG_ERROR("Undefined problem statement\n");

        // if this assignment beats the best gain so far, remember the attribute and partition
        best_attribute = idx[i];
        // ...
        sg_memcpy(is_left_final.vector, is_left.vector, is_left.vlen*sizeof(bool));
        num_missing_final = num_vecs-n_nm_vecs;
        // ...
        for (int32_t l = 0; l < c+1; l++)
            count_left = (feats_left[l]) ? count_left+1 : count_left;
        // ...
        count_right = c+1-count_left;

        // record which unique values transit into the left child and which into the right
        for (int32_t w = 0; w < c+1; w++)
        {
            // ...
            right[r++] = ufeats[w];
        }
    }
    // continuous feature: sweep candidate thresholds over the sorted values
    SGVector<float64_t> right_wclasses = total_wclasses.clone();
    SGVector<float64_t> left_wclasses(n_ulabels);
    left_wclasses.zero();

    // ...
    float64_t z = feats[0];
    right_wclasses[simple_labels[sorted_args[0]]] -= weights[sorted_args[0]];
    left_wclasses[simple_labels[sorted_args[0]]] += weights[sorted_args[0]];
    for (int32_t j = 1; j < n_nm_vecs; j++)
    {
        // values within EQ_DELTA of the current threshold simply move to the left side
        // ...
        right_wclasses[simple_labels[sorted_args[j]]] -= weights[sorted_args[j]];
        left_wclasses[simple_labels[sorted_args[j]]] += weights[sorted_args[j]];
        // ...
        if (m_mode == PT_MULTICLASS)
            g = gain(left_wclasses, right_wclasses, total_wclasses);
        else if (m_mode == PT_REGRESSION)
            g = gain(left_wclasses, right_wclasses, total_wclasses, ulabels);
        else
            SG_ERROR("Undefined problem statement\n");

        // if this threshold gives the best gain so far, remember it
        best_attribute = idx[i];
        // ...
        num_missing_final = num_vecs-n_nm_vecs;

        // advance the threshold and move the current vector to the left side
        // ...
        right_wclasses[simple_labels[sorted_args[j]]] -= weights[sorted_args[j]];
        left_wclasses[simple_labels[sorted_args[j]]] += weights[sorted_args[j]];
    }

    // restore the class weights of MISSING vectors before moving to the next feature
    while (n_nm_vecs < feats.vlen)
    {
        total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]] += weights[sorted_args[n_nm_vecs-1]];
        // ...
    }
if (best_attribute == -1)
{
    // ... no usable split was found
}

// for a continuous split both children record the chosen threshold as their transit value
left[0] = best_threshold;
right[0] = best_threshold;
// ...

// mark which vectors go left, using the pre-sorted column when available
SGVector<float64_t> temp_vec(mat.get_column_vector(best_attribute), mat.num_rows, false);
SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(best_attribute), mat.num_rows, false);
// ...
for (int32_t i = 0; i < mat.num_rows; i++)
{
    // ...
    if (indices_mask[sorted_indices[i]] >= 0)
    {
        is_left_final[indices_mask[sorted_indices[i]]] = (temp_vec[i] <= best_threshold);
    }
}
// ...
for (int32_t i = num_vecs-1; i >= 0; i--)
{
    // ...
    is_left_final[i] = is_left_final[dupes[i]];
}
// ...
// without pre-sorting, compare each vector's value to the threshold directly
for (int32_t i = 0; i < num_vecs; i++)
    is_left_final[i] = (mat(best_attribute, i) <= best_threshold);

return best_attribute;
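// surrogate_split: decides the direction of vectors whose best_attribute value is MISSING.
// Every other attribute is scored by how well a split on it reproduces the primary split
// (its "association"); each missing vector follows the most associated surrogate that is
// observed for it, and falls back to the majority direction (p_l >= p_r) otherwise.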
SGVector<bool> ret(m.num_cols);
// ...
CDynamicArray<int32_t>* missing_vecs = new CDynamicArray<int32_t>();
// ...
CDynamicArray<float64_t>* association_index = new CDynamicArray<float64_t>();
// separate vectors with an observed split attribute from those where it is MISSING
for (int32_t i = 0; i < m.num_cols; i++)
{
    if (!CMath::fequals(m(attr, i), MISSING, 0))
    {
        // ...
    }
    else
    {
        missing_vecs->push_back(i);
        association_index->push_back(0.);
    }
}
// ...
float64_t p_r = (total-p_l)/total;
// ...
float64_t p = CMath::min(p_r, p_l);

// score every other attribute as a surrogate for the primary split
for (int32_t i = 0; i < m.num_rows; i++)
{
    // ... (the split attribute itself is skipped)
    CDynamicArray<int32_t>* intersect_vecs = new CDynamicArray<int32_t>();
    for (int32_t j = 0; j < m.num_cols; j++)
    {
        if (!(CMath::fequals(m(i,j), MISSING, 0) || CMath::fequals(m(attr,j), MISSING, 0)))
            intersect_vecs->push_back(j);
    }

    if (intersect_vecs->get_num_elements() == 0)
    {
        SG_UNREF(intersect_vecs);
        // ...
    }
    // ... the nominal/continuous surrogate handlers update association_index and ret
    SG_UNREF(intersect_vecs);
}

// vectors with no useful surrogate follow the majority direction
for (int32_t i = 0; i < association_index->get_num_elements(); i++)
{
    if (association_index->get_element(i) == 0.)
        ret[missing_vecs->get_element(i)] = (p_l >= p_r);
}

SG_UNREF(missing_vecs);
SG_UNREF(association_index);
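// handle_missing_vecs_for_continuous_surrogate: for a continuous surrogate attribute,
// every threshold z between its unique values defines a candidate surrogate split.
// Agreement with the primary split is measured on the vectors where both attributes are
// observed, lambda = (p - (1 - numer/denom)) / p rates the surrogate (or its complement)
// against the majority rule, and a missing vector is re-routed whenever a surrogate beats
// its current best lambda.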
void CMyCARTree::handle_missing_vecs_for_continuous_surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
        CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
        SGVector<float64_t> weights, float64_t p, int32_t attr)

SGVector<float64_t> feats(intersect_vecs->get_num_elements());
for (int32_t j = 0; j < intersect_vecs->get_num_elements(); j++)
{
    feats[j] = m(attr, intersect_vecs->get_element(j));
    denom += weights[intersect_vecs->get_element(j)];
}

// unique values of the surrogate attribute give the candidate thresholds
int32_t num_unique = feats.unique(feats.vector, feats.vlen);
// ...
for (int32_t j = 0; j < num_unique-1; j++)
{
    float64_t z = feats[j];

    // weight agreeing with the primary split (numer) and with its complement (numerc)
    for (int32_t k = 0; k < intersect_vecs->get_num_elements(); k++)
    {
        if ((m(attr, intersect_vecs->get_element(k)) <= z) && is_left[intersect_vecs->get_element(k)])
            numer += weights[intersect_vecs->get_element(k)];
        else if ((m(attr, intersect_vecs->get_element(k)) > z) && !is_left[intersect_vecs->get_element(k)])
            numer += weights[intersect_vecs->get_element(k)];
        else if ((m(attr, intersect_vecs->get_element(k)) <= z) && !is_left[intersect_vecs->get_element(k)])
            numerc += weights[intersect_vecs->get_element(k)];
        else if ((m(attr, intersect_vecs->get_element(k)) > z) && is_left[intersect_vecs->get_element(k)])
            numerc += weights[intersect_vecs->get_element(k)];
    }

    lambda = (p-(1-numer/denom))/p;
    // ...
    lambda = (p-(1-numerc/denom))/p;
    for (int32_t k = 0; k < missing_vecs->get_num_elements(); k++)
    {
        if ((lambda > association_index->get_element(k)) &&
                (!CMath::fequals(m(attr, missing_vecs->get_element(k)), MISSING, 0)))
        {
            association_index->set_element(lambda, k);
            // ...
            is_left[missing_vecs->get_element(k)] = (m(attr, missing_vecs->get_element(k)) <= z);
            // ...
            is_left[missing_vecs->get_element(k)] = (m(attr, missing_vecs->get_element(k)) > z);
        }
    }
}
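// handle_missing_vecs_for_nominal_surrogate: same idea for a nominal surrogate attribute,
// except the candidate surrogate splits are all 2^(num_unique-1) assignments of the
// attribute's unique values to left/right; the assignment (or its complement) with the
// best lambda decides where each missing vector goes.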
void CMyCARTree::handle_missing_vecs_for_nominal_surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
        CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
        SGVector<float64_t> weights, float64_t p, int32_t attr)

SGVector<float64_t> feats(intersect_vecs->get_num_elements());
for (int32_t j = 0; j < intersect_vecs->get_num_elements(); j++)
{
    feats[j] = m(attr, intersect_vecs->get_element(j));
    denom += weights[intersect_vecs->get_element(j)];
}

// unique values of the surrogate attribute; every subset assignment is a candidate split
int32_t num_unique = feats.unique(feats.vector, feats.vlen);
// ...
int32_t num_cases = CMath::pow(2, (num_unique-1));
for (int32_t j = 1; j < num_cases; j++)
{
    SGVector<bool> feats_left(num_unique);
    for (int32_t k = 0; k < num_unique; k++)
        feats_left[k] = ((j/CMath::pow(2,k))%(CMath::pow(2,k+1)) == 1);

    SGVector<bool> intersect_vecs_left(intersect_vecs->get_num_elements());
    for (int32_t k = 0; k < intersect_vecs->get_num_elements(); k++)
    {
        for (int32_t q = 0; q < num_unique; q++)
        {
            if (feats[q] == m(attr, intersect_vecs->get_element(k)))
            {
                intersect_vecs_left[k] = feats_left[q];
                // ...
            }
        }
    }

    // weight agreeing with the primary split (numer) and with its complement (numerc)
    float64_t numerc = 0.;
    for (int32_t k = 0; k < intersect_vecs->get_num_elements(); k++)
    {
        // ...
        if (intersect_vecs_left[k] == is_left[intersect_vecs->get_element(k)])
            numer += weights[intersect_vecs->get_element(k)];
        else
            numerc += weights[intersect_vecs->get_element(k)];
    }

    float64_t lambda = 0.;
    // ...
    lambda = (p-(1-numer/denom))/p;
    // ...
    lambda = (p-(1-numerc/denom))/p;
    // ...
    for (int32_t k = 0; k < missing_vecs->get_num_elements(); k++)
    {
        if ((lambda > association_index->get_element(k)) &&
                (!CMath::fequals(m(attr, missing_vecs->get_element(k)), MISSING, 0)))
        {
            association_index->set_element(lambda, k);
            // ...
            for (int32_t q = 0; q < num_unique; q++)
            {
                if (feats[q] == m(attr, missing_vecs->get_element(k)))
                {
                    // ...
                    is_left[missing_vecs->get_element(k)] = feats_left[q];
                    // ...
                    is_left[missing_vecs->get_element(k)] = !feats_left[q];
                }
            }
        }
    }
}
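// gain (regression): impurity decrease measured by weighted least-squares deviation,
//   gain = LSD(total) - (W_l/W) * LSD(left) - (W_r/W) * LSD(right)
// where W_l, W_r and W are the left, right and total weights.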
float64_t CMyCARTree::gain(SGVector<float64_t> wleft, SGVector<float64_t> wright, SGVector<float64_t> wtotal,
        SGVector<float64_t> feats)

float64_t total_lweight = 0;
float64_t total_rweight = 0;
float64_t total_weight = 0;
// ...
return lsd_n-(lsd_l*(total_lweight/total_weight))-(lsd_r*(total_rweight/total_weight));
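// gain (classification): impurity decrease measured by the Gini index,
//   gain = Gini(total) - (W_l/W) * Gini(left) - (W_r/W) * Gini(right)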
float64_t CMyCARTree::gain(const SGVector<float64_t>& wleft, const SGVector<float64_t>& wright, const SGVector<float64_t>& wtotal)

float64_t total_lweight = 0;
float64_t total_rweight = 0;
float64_t total_weight = 0;
// ...
return gini_n-(gini_l*(total_lweight/total_weight))-(gini_r*(total_rweight/total_weight));
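// gini_impurity_index: Gini = 1 - sum_k (w_k / W)^2, where w_k is the total weight of
// class k and W the total weight of the node (computed here with Eigen dot products).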
Map<VectorXd> map_weighted_lab_classes(weighted_lab_classes.vector, weighted_lab_classes.size());
total_weight = map_weighted_lab_classes.sum();
float64_t gini = map_weighted_lab_classes.dot(map_weighted_lab_classes);
// ...
gini = 1.0-(gini/(total_weight*total_weight));
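// least_squares_deviation: weighted variance of the labels,
//   LSD = (1/W) * sum_i w_i * (y_i - mean)^2   with   mean = (1/W) * sum_i w_i * y_i.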
Map<VectorXd> map_weights(weights.vector, weights.size());
Map<VectorXd> map_feats(feats.vector, weights.size());
float64_t mean = map_weights.dot(map_feats);
total_weight = map_weights.sum();
// ...
for (int32_t i = 0; i < weights.vlen; i++)
    dev += weights[i]*(feats[i]-mean)*(feats[i]-mean);

return dev/total_weight;
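// apply_from_current_node: prediction. Each sample walks from the given node to a leaf;
// at nominal splits it goes left when its value appears in the left child's
// transit_into_values, at continuous splits when the value is <= the stored threshold.
// The leaf's node_label is the prediction and, optionally, a per-sample certainty
// (fraction of the leaf's training weight that agrees with its label) is stored.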
int32_t num_vecs = feats->get_num_vectors();
REQUIRE(num_vecs > 0, "No data provided in apply\n");
// ...
SGVector<float64_t> labels(num_vecs);
for (int32_t i = 0; i < num_vecs; i++)
{
    SGVector<float64_t> sample = feats->get_feature_vector(i);
    bnode_t* node = current;
    // ...
    // walk down the tree until a leaf is reached
    while (node->data.num_leaves != 1)
    {
        bnode_t* leftchild = node->left();
        // ...
        // nominal split: go left if the sample's value is among the left child's transit values
        SGVector<float64_t> comp = leftchild->data.transit_into_values;
        // ...
        for (int32_t k = 0; k < comp.vlen; k++)
        {
            if (comp[k] == sample[node->data.attribute_id])
            {
                // ...
            }
        }
        // ...
        // continuous split: go left if the value is <= the stored threshold
        if (sample[node->data.attribute_id] <= leftchild->data.transit_into_values[0])
        {
            // ...
        }
        // ...
        SG_UNREF(leftchild);
    }

    labels[i] = node->data.node_label;
    // ...
    m_certainty[i] = ((node->data.total_weight-node->data.weight_minus_node)/
            node->data.total_weight);
}
// ...
CMulticlassLabels* mlabels = new CMulticlassLabels(labels);
// ...
CRegressionLabels* rlabels = new CRegressionLabels(labels);
// ...
SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
int32_t num_vecs = data->get_num_vectors();

// randomly assign each vector to one of the folds
SGVector<int32_t> subid(num_vecs);
subid.random_vector(subid.vector, subid.vlen, 0, folds-1);
// ...
CDynamicArray<float64_t>* r_cv = new CDynamicArray<float64_t>();
CDynamicArray<float64_t>* alphak = new CDynamicArray<float64_t>();
SGVector<int32_t> num_alphak(folds);
for (int32_t i = 0; i < folds; i++)
{
    // ...
    CDynamicArray<int32_t>* test_indices = new CDynamicArray<int32_t>();
    CDynamicArray<int32_t>* train_indices = new CDynamicArray<int32_t>();
    for (int32_t j = 0; j < num_vecs; j++)
    {
        if (subid[j] == i)
            test_indices->push_back(j);
        else
            train_indices->push_back(j);
    }

    if (test_indices->get_num_elements() == 0 || train_indices->get_num_elements() == 0)
    {
        SG_ERROR("Unfortunately you have reached the very low probability event where at least one of "
                "the subsets in cross-validation is not represented at all. Please re-run.")
    }

    // train a full tree on the training fold and build its pruning sequence
    SGVector<int32_t> subset(train_indices->get_array(), train_indices->get_num_elements(), false);
    data->add_subset(subset);
    m_labels->add_subset(subset);
    SGVector<float64_t> subset_weights(train_indices->get_num_elements());
    for (int32_t j = 0; j < train_indices->get_num_elements(); j++)
        subset_weights[j] = m_weights[train_indices->get_element(j)];

    bnode_t* root = CARTtrain(data, subset_weights, m_labels, 0);
    // ...
    CTreeMachine<MyCARTreeNodeData>* tmax = new CTreeMachine<MyCARTreeNodeData>();
    tmax->set_root(root);
    CDynamicObjectArray* pruned_trees = prune_tree(tmax);

    // evaluate every pruned subtree on the held-out fold
    data->remove_subset();
    m_labels->remove_subset();
    subset = SGVector<int32_t>(test_indices->get_array(), test_indices->get_num_elements(), false);
    data->add_subset(subset);
    m_labels->add_subset(subset);
    subset_weights = SGVector<float64_t>(test_indices->get_num_elements());
    for (int32_t j = 0; j < test_indices->get_num_elements(); j++)
        subset_weights[j] = m_weights[test_indices->get_element(j)];
    // ...
    num_alphak[i] = m_alphas->get_num_elements();
    for (int32_t j = 0; j < m_alphas->get_num_elements(); j++)
    {
        alphak->push_back(m_alphas->get_element(j));
        CSGObject* jth_element = pruned_trees->get_element(j);
        bnode_t* current_root = NULL;
        if (jth_element != NULL)
            current_root = dynamic_cast<bnode_t*>(jth_element);
        else
            SG_ERROR("%d element is NULL which should not be", j);
        // ... predictions of the j-th pruned subtree on the held-out fold
        float64_t error = compute_error(labels, m_labels, subset_weights);
        r_cv->push_back(error);
        // ...
        SG_UNREF(jth_element);
    }

    data->remove_subset();
    m_labels->remove_subset();
    SG_UNREF(train_indices);
    SG_UNREF(test_indices);
    // ...
    SG_UNREF(pruned_trees);
}
// build the pruning sequence of the tree trained on all the data
CDynamicObjectArray* pruned_trees = prune_tree(this);
// ...
// pick the subtree whose alpha has the smallest cross-validated error
int32_t min_index = -1;
float64_t min_r_cv = CMath::MAX_REAL_NUMBER;
for (int32_t i = 0; i < m_alphas->get_num_elements(); i++)
{
    // ...
    for (int32_t j = 0; j < folds; j++)
    {
        // ...
        for (int32_t k = base; k < num_alphak[j]+base-1; k++)
        {
            if (alphak->get_element(k) <= alpha && alphak->get_element(k+1) > alpha)
                rv += r_cv->get_element(k);
        }
        // ...
        rv += r_cv->get_element(num_alphak[j]+base-1);
        // ...
        base += num_alphak[j];
    }
    // ...
}

CSGObject* element = pruned_trees->get_element(min_index);
bnode_t* best_tree_root = NULL;
if (element != NULL)
    best_tree_root = dynamic_cast<bnode_t*>(element);
else
    SG_ERROR("%d element is NULL which should not be", min_index);

this->set_root(best_tree_root);
// ...
SG_UNREF(pruned_trees);
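// compute_error: weighted error of predicted labels against a reference. For
// classification this is the weighted fraction of misclassified vectors, for regression
// the weighted mean squared error; both are normalised by the total weight.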
REQUIRE(labels, "input labels cannot be NULL");
REQUIRE(reference, "reference labels cannot be NULL")

CDenseLabels* gnd_truth = dynamic_cast<CDenseLabels*>(reference);
CDenseLabels* result = dynamic_cast<CDenseLabels*>(labels);

float64_t denom = weights.sum(weights);
// ...
// PT_MULTICLASS: weighted count of misclassified vectors
for (int32_t i = 0; i < weights.vlen; i++)
{
    if (gnd_truth->get_label(i) != result->get_label(i))
    {
        // ...
    }
}
// ...
// PT_REGRESSION: weighted squared error
for (int32_t i = 0; i < weights.vlen; i++)
    numer += weights[i]*CMath::pow((gnd_truth->get_label(i)-result->get_label(i)), 2);
// ...
SG_ERROR("Case not possible\n");
REQUIRE(tree, "Tree not provided for pruning.\n");

CDynamicObjectArray* trees = new CDynamicObjectArray();
// ...
m_alphas = new CDynamicArray<float64_t>();
// ...
// start from a clone of the full tree
CTreeMachine<MyCARTreeNodeData>* t1 = tree->clone_tree();
// ...
node_t* t1root = t1->get_root();
bnode_t* t1_root = NULL;
if (t1root != NULL)
    t1_root = dynamic_cast<bnode_t*>(t1root);
else
    SG_ERROR("t1_root is NULL. This is not expected\n")
// ...
trees->push_back(t1_root);
while (t1_root->data.num_leaves > 1)
{
    CTreeMachine<MyCARTreeNodeData>* t2 = t1->clone_tree();
    // ...
    node_t* t2root = t2->get_root();
    bnode_t* t2_root = NULL;
    if (t2root != NULL)
        t2_root = dynamic_cast<bnode_t*>(t2root);
    else
        SG_ERROR("t2_root is NULL. This is not expected\n")

    // ... find the weakest alpha, cut that link in t2, store t2 and continue from it
    trees->push_back(t2_root);
}
if (node->data.num_leaves != 1)
{
    bnode_t* left = node->left();
    bnode_t* right = node->right();
    // ...
    SGVector<float64_t> weak_links(3);
    // ...
    // alpha of this node: increase in resubstitution error per removed leaf
    weak_links[2] = (node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
    weak_links[2] /= (node->data.num_leaves-1.0);
    // ...
    return CMath::min(weak_links.vector, weak_links.vlen);
}

return CMath::MAX_REAL_NUMBER;
if (node->data.num_leaves == 1)
{
    // ...
}

float64_t g = (node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
g /= (node->data.num_leaves-1.0);
// when this node's alpha matches the weakest alpha, collapse it into a leaf
node->data.num_leaves = 1;
node->data.weight_minus_branch = node->data.weight_minus_node;
CDynamicObjectArray* children = new CDynamicObjectArray();
node->set_children(children);
// ...
// otherwise recurse into the children and refresh the counts on the way back up
bnode_t* left = node->left();
bnode_t* right = node->right();
// ...
node->data.num_leaves = left->data.num_leaves+right->data.num_leaves;
node->data.weight_minus_branch = left->data.weight_minus_branch+right->data.weight_minus_branch;
if (node->data.num_leaves != 1)
{
    bnode_t* left = node->left();
    bnode_t* right = node->right();
    // ... (both children are processed first)
    node->data.num_leaves = left->data.num_leaves+right->data.num_leaves;
    node->data.weight_minus_branch = left->data.weight_minus_branch+right->data.weight_minus_branch;
    // a split that does not reduce the error is collapsed immediately
    if (node->data.weight_minus_node == node->data.weight_minus_branch)
    {
        node->data.num_leaves = 1;
        CDynamicObjectArray* children = new CDynamicObjectArray();
        node->set_children(children);
    }
}
m_alphas = new CDynamicArray<float64_t>();
// ...
// parameter registration
SG_ADD(&m_pre_sort, "m_pre_sort", "presort", MS_NOT_AVAILABLE);
SG_ADD(&m_sorted_features, "m_sorted_features", "sorted feats", MS_NOT_AVAILABLE);
SG_ADD(&m_sorted_indices, "m_sorted_indices", "sorted indices", MS_NOT_AVAILABLE);
SG_ADD(&m_nominal, "m_nominal", "feature types", MS_NOT_AVAILABLE);
SG_ADD(&m_weights, "m_weights", "weights", MS_NOT_AVAILABLE);
SG_ADD(&m_weights_set, "m_weights_set", "weights set", MS_NOT_AVAILABLE);
SG_ADD(&m_types_set, "m_types_set", "feature types set", MS_NOT_AVAILABLE);
SG_ADD(&m_apply_cv_pruning, "m_apply_cv_pruning", "apply cross validation pruning", MS_NOT_AVAILABLE);
SG_ADD(&m_folds, "m_folds", "number of subsets for cross validation", MS_NOT_AVAILABLE);
SG_ADD(&m_max_depth, "m_max_depth", "max allowed tree depth", MS_NOT_AVAILABLE)
SG_ADD(&m_min_node_size, "m_min_node_size", "min allowed node size", MS_NOT_AVAILABLE)
SG_ADD(&m_label_epsilon, "m_label_epsilon", "epsilon for labels", MS_NOT_AVAILABLE)
SG_ADD((machine_int_t*)&m_mode, "m_mode", "problem type (multiclass or regression)", MS_NOT_AVAILABLE)
vector<double> importances(dt.size(), 0.0);
// ...
bnode_t* node = dynamic_cast<bnode_t*>(m_root);
// ...

if (node->data.num_leaves != 1)
{
    bnode_t* left = node->left();
    bnode_t* right = node->right();
    // ...
    importances[node->data.attribute_id] += node->data.IG;
    // ...
}
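// set_probabilities: attaches the per-sample certainty recorded by apply_from_current_node
// to the predicted labels as probability-like scores; the number of labels must match the
// size of m_certainty.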
int size = labels->get_num_labels();
// ...
if (size != m_certainty.size())
    std::cout << "ERROR: mismatch in size between m_certainty and labels\n";
// ...
for (int i = 0; i < size; ++i)
{
    if (labels->get_value(i) > 0)
    {
        // ...
    }
    // ...
}
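// Minimal usage sketch (not part of this file; assumes dense Shogun features and labels
// have already been constructed elsewhere):
//
//   auto* tree = new CMyCARTree();
//   tree->set_machine_problem_type(PT_MULTICLASS);
//   tree->set_feature_types(ft);            // SGVector<bool>, true marks nominal features
//   tree->set_labels(train_labels);
//   tree->train(train_feats);
//   CMulticlassLabels* pred = tree->apply_multiclass(test_feats);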