19 return "MulticlassLogisticRegression";
49 CLinearMulticlassMachine()
55 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs)
70 SG_ADD(&
m_z,
"m_z",
"regularization constant",MS_AVAILABLE);
71 SG_ADD(&
m_epsilon,
"m_epsilon",
"tolerance epsilon",MS_NOT_AVAILABLE);
72 SG_ADD(&
m_max_iter,
"m_max_iter",
"max number of iterations",MS_NOT_AVAILABLE);
78 for (
int i = 0;
i < m_machines->get_num_elements(); ++
i)
80 auto ptr = m_machines->get_element(
i);
81 while (ptr->ref_count() > 2)
87 m_machines->reset_array();
105 if (m_multiclass_strategy)
107 delete m_multiclass_strategy;
108 m_multiclass_strategy = NULL;
115 vector<SGVector<float64_t>> weights_vector;
117 int n_machines = get_num_machines();
119 for (int32_t
i=0;
i<n_machines;
i++)
121 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(
i);
122 weights_vector.push_back(machine->get_w());
126 return weights_vector;
131 vector<float64_t> bias_vector;
133 int n_machines = get_num_machines();
135 for (int32_t
i=0;
i<n_machines;
i++)
137 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(
i);
138 bias_vector.push_back(machine->get_bias());
148 int n_machines = get_num_machines();
150 for (int32_t
i=0;
i<n_machines;
i++)
152 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(
i);
153 machine->set_w(SGVector<float64_t>(wnew.at(
i)));
162 set_features((CDotFeatures*)data);
164 m_machines->reset_array();
166 REQUIRE(m_features,
"%s::train_machine(): No features attached!\n");
167 REQUIRE(m_labels,
"%s::train_machine(): No labels attached!\n");
168 REQUIRE(m_labels->get_label_type()==LT_MULTICLASS,
"%s::train_machine(): "
169 "Attached labels are no multiclass labels\n");
170 REQUIRE(m_multiclass_strategy,
"%s::train_machine(): No multiclass strategy"
173 int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes();
174 int32_t n_feats = m_features->get_dim_feature_space();
176 slep_options options = slep_options::default_options();
179 slep_result_t result = slep_mc_plain_lr(m_features,
180 (CMulticlassLabels*)m_labels,
184 SGMatrix<float64_t> all_w = result.w;
185 SGVector<float64_t> all_c = result.c;
187 for (int32_t
i=0;
i<n_classes;
i++)
189 SGVector<float64_t> w(n_feats);
190 for (int32_t j=0; j<n_feats; j++)
193 float64_t c = all_c[
i];
194 CLinearMachine* machine =
new CLinearMachine();
196 machine->set_bias(c);
197 m_machines->push_back(machine);
virtual const char * get_name() const
void set_w(vector< Eigen::VectorXd > &wnew)
void set_epsilon(float64_t epsilon)
int32_t get_max_iter() const
~CMulticlassLogisticRegression()
float64_t get_epsilon() const
vector< float64_t > get_bias()
void register_parameters()
vector< SGVector< float64_t > > get_w()
void set_max_iter(int32_t max_iter)
virtual bool train_machine(CFeatures *data=NULL)
CMulticlassLogisticRegression()