Feat C++ API
A feature engineering automation tool
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MulticlassLogisticRegression.cc
Go to the documentation of this file.
1 /* This program is free software; you can redistribute it and/or modify
2  * it under the terms of the GNU General Public License as published by
3  * the Free Software Foundation; either version 3 of the License, or
4  * (at your option) any later version.
5  *
6  * Written (W) 2012 Sergey Lisitsyn
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
11 
12 #include<iostream>
13 
14 namespace shogun
15 {
16 
18  {
    // Returns the class name as a string literal.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is the get_name() override - confirm.
19  return "MulticlassLogisticRegression";
20  }
21 
23  {
    // Stores z in m_z after asserting that it is strictly positive.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is set_z(float64_t z) - confirm.
24  ASSERT(z>0)
25  m_z = z;
26  }
27 
28  inline float64_t CMulticlassLogisticRegression::get_z() const { return m_z; }
29 
31  {
    // Stores epsilon in m_epsilon after asserting that it is strictly positive.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is set_epsilon(float64_t epsilon) - confirm.
32  ASSERT(epsilon>0)
33  m_epsilon = epsilon;
34  }
35 
36  inline float64_t CMulticlassLogisticRegression::get_epsilon() const { return m_epsilon; }
37 
39  {
    // Stores max_iter in m_max_iter after asserting that it is strictly positive.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is set_max_iter(int32_t max_iter) - confirm.
40  ASSERT(max_iter>0)
41  m_max_iter = max_iter;
42  }
43 
44  inline int32_t CMulticlassLogisticRegression::get_max_iter() const { return m_max_iter; }
45 
46  using namespace shogun;
47 
49  CLinearMulticlassMachine()
50  {
    // Default construction: the base class is default-constructed and the
    // hyper-parameters are set through init_defaults().
    // NOTE(review): the constructor's signature line is absent from this
    // listing, and no parameter registration call is visible here - confirm
    // that this is intended.
51  init_defaults();
52  }
53 
54  CMulticlassLogisticRegression::CMulticlassLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs) :
55  CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs)
56  {
57  init_defaults();
58  set_z(z);
59  }
60 
62  {
    // Applies the default hyper-parameters through the setters:
    // z = 0.1, epsilon = 1e-2, max_iter = 10000.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is init_defaults() - confirm.
63  set_z(0.1);
64  set_epsilon(1e-2);
65  set_max_iter(10000);
66  }
67 
69  {
    // Registers the three hyper-parameter members through SG_ADD, each with
    // a name and a description string. Only m_z is flagged MS_AVAILABLE; the
    // other two are flagged MS_NOT_AVAILABLE.
    // NOTE(review): the signature line for this body is absent from this
    // listing; presumably this is register_parameters(), and the MS_* flags
    // presumably control model-selection availability - confirm.
70  SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE);
71  SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
72  SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
73  }
74 
76  {
77 
78  for (int i = 0; i < m_machines->get_num_elements(); ++i)
79  {
80  auto ptr = m_machines->get_element(i);
81  while (ptr->ref_count() > 2)
82  {
83  SG_UNREF(ptr);
84  }
85  SG_UNREF(ptr);
86  }
87  m_machines->reset_array();
88 
90  if (m_features)
91  {
92  delete m_features;
93  m_features = NULL;
94  }
95  if (m_machines)
96  {
97  delete m_machines;
98  m_machines = NULL;
99  }
100  if (m_machine)
101  {
102  delete m_machine;
103  m_machine = NULL;
104  }
105  if (m_multiclass_strategy)
106  {
107  delete m_multiclass_strategy;
108  m_multiclass_strategy = NULL;
109  }
110  }
111 
112 
113  vector<SGVector<float64_t>> CMulticlassLogisticRegression::get_w()
114  {
115  vector<SGVector<float64_t>> weights_vector;
116 
117  int n_machines = get_num_machines();
118 
119  for (int32_t i=0; i<n_machines; i++)
120  {
121  CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
122  weights_vector.push_back(machine->get_w());
123  SG_UNREF(machine);
124  }
125 
126  return weights_vector;
127  }
128 
130  {
131  vector<float64_t> bias_vector;
132 
133  int n_machines = get_num_machines();
134 
135  for (int32_t i=0; i<n_machines; i++)
136  {
137  CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
138  bias_vector.push_back(machine->get_bias());
139  SG_UNREF(machine);
140  }
141 
142  return bias_vector;
143  }
144 
145  void CMulticlassLogisticRegression::set_w(vector<Eigen::VectorXd>& wnew)
146  {
147 
148  int n_machines = get_num_machines();
149 
150  for (int32_t i=0; i<n_machines; i++)
151  {
152  CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
153  machine->set_w(SGVector<float64_t>(wnew.at(i)));
154  SG_UNREF(machine);
155  }
156 
157  }
158 
160  {
161  if (data)
162  set_features((CDotFeatures*)data);
163 
164  m_machines->reset_array();
165 
166  REQUIRE(m_features, "%s::train_machine(): No features attached!\n");
167  REQUIRE(m_labels, "%s::train_machine(): No labels attached!\n");
168  REQUIRE(m_labels->get_label_type()==LT_MULTICLASS, "%s::train_machine(): "
169  "Attached labels are no multiclass labels\n");
170  REQUIRE(m_multiclass_strategy, "%s::train_machine(): No multiclass strategy"
171  " attached!\n");
172 
173  int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes();
174  int32_t n_feats = m_features->get_dim_feature_space();
175 
176  slep_options options = slep_options::default_options();
177  options.tolerance = m_epsilon;
178  options.max_iter = m_max_iter;
179  slep_result_t result = slep_mc_plain_lr(m_features,
180  (CMulticlassLabels*)m_labels,
181  m_z,
182  options);
183 
184  SGMatrix<float64_t> all_w = result.w;
185  SGVector<float64_t> all_c = result.c;
186 
187  for (int32_t i=0; i<n_classes; i++)
188  {
189  SGVector<float64_t> w(n_feats);
190  for (int32_t j=0; j<n_feats; j++)
191  w[j] = all_w(j,i);
192 
193  float64_t c = all_c[i];
194  CLinearMachine* machine = new CLinearMachine();
195  machine->set_w(w);
196  machine->set_bias(c);
197  m_machines->push_back(machine);
198  }
199  return true;
200  }
201 }
void set_w(vector< Eigen::VectorXd > &wnew)
virtual bool train_machine(CFeatures *data=NULL)
int i
Definition: params.cc:552