16 CLinearMulticlassMachine()
27 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs)
46 return "MulticlassLibLinear";
139 SG_ERROR(
"Please enable save_train_state option and train machine.\n")
141 ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS)
143 int32_t num_vectors = m_features->get_num_vectors();
144 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
146 v_array<int32_t> nz_idxs;
147 nz_idxs.reserve(num_vectors);
149 for (int32_t
i=0;
i<num_vectors;
i++)
151 for (int32_t y=0; y<num_classes; y++)
160 int32_t num_nz = nz_idxs.index();
161 nz_idxs.reserve(num_nz);
162 return SGVector<int32_t>(nz_idxs.begin,num_nz);
170 vector<SGVector<float64_t>> weights_vector;
171 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
173 for (int32_t
i=0;
i<num_classes;
i++)
175 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(
i);
176 weights_vector.push_back(machine->get_w().clone());
179 return weights_vector;
186 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
188 for (int32_t
i=0;
i<num_classes;
i++)
190 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(
i);
191 machine->set_w(SGVector<float64_t>(wnew.at(
i)));
199 set_features((CDotFeatures*)data);
202 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS)
203 ASSERT(m_multiclass_strategy)
206 int32_t num_vectors = m_features->get_num_vectors();
207 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
210 liblinear_problem mc_problem;
211 mc_problem.l = num_vectors;
212 mc_problem.n = m_features->get_dim_feature_space() + bias_n;
213 mc_problem.y = SG_MALLOC(float64_t, mc_problem.l);
214 for (int32_t
i=0;
i<num_vectors;
i++)
215 mc_problem.y[
i] = ((CMulticlassLabels*) m_labels)->get_int_label(
i);
217 mc_problem.x = m_features;
225 float64_t* C = SG_MALLOC(float64_t, num_vectors);
226 for (int32_t
i=0;
i<num_vectors;
i++)
229 Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,
m_epsilon,
233 m_machines->reset_array();
234 for (int32_t
i=0;
i<num_classes;
i++)
236 CLinearMachine* machine =
new CLinearMachine();
237 SGVector<float64_t> cw(mc_problem.n-bias_n);
239 for (int32_t j=0; j<mc_problem.n-bias_n; j++)
245 machine->set_bias(
m_train_state->w[(mc_problem.n-bias_n)*num_classes+
i]);
247 m_machines->push_back(machine);
254 SG_FREE(mc_problem.y);
263 return SGMatrix<float64_t>();
278 SG_ADD(&
m_C,
"m_C",
"regularization constant",MS_AVAILABLE);
279 SG_ADD(&
m_epsilon,
"m_epsilon",
"tolerance epsilon",MS_NOT_AVAILABLE);
280 SG_ADD(&
m_max_iter,
"m_max_iter",
"max number of iterations",MS_NOT_AVAILABLE);
281 SG_ADD(&
m_use_bias,
"m_use_bias",
"indicates whether bias should be used",MS_NOT_AVAILABLE);
// Register the m_save_train_state flag with a description of its own
// (the text used for m_use_bias does not apply to this parameter).
282 SG_ADD(&
m_save_train_state,
"m_save_train_state",
"indicates whether train state should be saved",MS_NOT_AVAILABLE);
vector< SGVector< float64_t > > get_w() const
void set_epsilon(float64_t epsilon)
virtual const char * get_name() const
void set_use_bias(bool use_bias)
~CMyMulticlassLibLinear()
void set_max_iter(int32_t max_iter)
mcsvm_state * m_train_state
void set_w(vector< Eigen::VectorXd > wnew)
bool train_machine(CFeatures *data)
void register_parameters()
SGVector< int32_t > get_support_vectors() const
float64_t get_epsilon() const
SGMatrix< float64_t > obtain_regularizer_matrix() const
int32_t get_max_iter() const
bool get_use_bias() const
bool get_save_train_state() const
void set_save_train_state(bool save_train_state)