Feat C++ API
A feature engineering automation tool
state.cc
Go to the documentation of this file.
1 /* FEAT
2 copyright 2017 William La Cava
3 license: GNU/GPL v3
4 */
5 
6 #include "state.h"
7 #include <iostream>
8 
9 #ifdef USE_CUDA
10  #include "../pop/op/node.h"
11 #endif
12 
13 namespace FT
14 {
15  namespace Dat{
16 
17  #ifndef USE_CUDA
18 
19  bool State::check(std::map<char, unsigned int> &arity)
20  {
21  if(arity.find('z') == arity.end())
22  return (f.size() >= arity.at('f') &&
23  b.size() >= arity.at('b') &&
24  c.size() >= arity.at('c'));
25  else
26  return (f.size() >= arity.at('f') &&
27  b.size() >= arity.at('b') &&
28  c.size() >= arity.at('c') &&
29  z.size() >= arity.at('z'));
30  }
31 
33  // various string State
34  bool State::check_s(std::map<char, unsigned int> &arity)
35  {
36  if(arity.find('z') == arity.end())
37  return (fs.size() >= arity.at('f') &&
38  bs.size() >= arity.at('b') &&
39  cs.size() >= arity.at('c'));
40  else
41  return (fs.size() >= arity.at('f') &&
42  bs.size() >= arity.at('b') &&
43  cs.size() >= arity.at('c') &&
44  zs.size() >= arity.at('z'));
45  }
46 
47  void Trace::copy_to_trace(State& state, std::map<char,
48  unsigned int> &arity)
49  {
50  for (int i = 0; i < arity.at('f'); i++) {
51  /* cout << "push back float arg for " << program.at(i)->name << "\n"; */
52  f.push_back(state.f.at(state.f.size() - (arity.at('f') - i)));
53  }
54 
55  for (int i = 0; i < arity.at('c'); i++) {
56  /* cout << "push back float arg for " << program.at(i)->name << "\n"; */
57  c.push_back(state.c.at(state.c.size() - (arity.at('c') - i)));
58  }
59 
60  for (int i = 0; i < arity.at('b'); i++) {
61  /* cout << "push back bool arg for " << program.at(i)->name << "\n"; */
62  b.push_back(state.b.at(state.b.size() - (arity.at('b') - i)));
63  }
64  }
65 
66  #else
67  using namespace Pop::Op;
68 
69  State::State()
70  {
71  idx['f'] = 0;
72  idx['c'] = 0;
73  idx['b'] = 0;
74  }
75 
76  void State::update_idx(char otype, std::map<char, unsigned>& arity)
77  {
78  ++idx.at(otype);
79  for (const auto& a : arity)
80  idx.at(a.first) -= a.second;
81  }
82 
83  bool State::check(std::map<char, unsigned int> &arity)
84  {
85  if(arity.find('z') == arity.end())
86  return (f.rows() >= arity.at('f') &&
87  c.rows() >= arity.at('c') &&
88  b.rows() >= arity.at('b'));
89  else
90  return (f.rows() >= arity.at('f') &&
91  c.rows() >= arity.at('c') &&
92  b.rows() >= arity.at('b') &&
93  z.size() >= arity.at('z'));
94  }
95 
96  bool State::check_s(std::map<char, unsigned int> &arity)
97  {
98  if(arity.find('z') == arity.end())
99  return (fs.size() >= arity.at('f') &&
100  cs.size() >= arity.at('c') &&
101  bs.size() >= arity.at('b'));
102  else
103  return (fs.size() >= arity.at('f') &&
104  cs.size() >= arity.at('c') &&
105  bs.size() >= arity.at('b') &&
106  zs.size() >= arity.at('z'));
107  }
108 
109  void State::allocate(const std::map<char, size_t>& stack_size, size_t N)
110  {
111  //std::cout << "before dev_allocate, dev_f is " << dev_f << "\n";
112  dev_allocate(dev_f, N*stack_size.at('f'),
113  dev_c, N*stack_size.at('c'),
114  dev_b, N*stack_size.at('b'));
115  //std::cout << "after dev_allocate, dev_f is " << dev_f << "\n";
116 
117  //printf("Allocated Stack Sizes\n");
118  //printf("\tFloating stack N=%zu and stack size as %zu\n",N, stack_size.at('f'));
119 
120  this->N = N;
121 
122  f.resize(stack_size.at('f'),N);
123  c.resize(stack_size.at('c'),N);
124  b.resize(stack_size.at('b'),N);
125  }
126 
127  void State::limit()
128  {
129  // clean floating point stack.
130  for (unsigned r = 0 ; r < f.rows(); ++r)
131  {
132  f.row(r) = (f.row(r) < MIN_FLT).select(MIN_FLT,f.row(r));
133  f.row(r) = (f.row(r) > MAX_FLT).select(MAX_FLT,f.row(r));
134  f.row(r) = (isnan(f.row(r))).select(0,f.row(r));
135  }
136 
137  for (unsigned r = 0 ; r < c.rows(); ++r)
138  {
139  c.row(r) = (c.row(r) < MIN_FLT).select(MIN_FLT,c.row(r));
140  c.row(r) = (c.row(r) > MAX_FLT).select(MAX_FLT,c.row(r));
141  c.row(r) = (isnan(c.row(r))).select(0, c.row(r));
142  }
143  }
144 
146  void State::trim()
147  {
148  //std::cout << "resizing f to " << idx['f'] << "x" << f.cols() << "\n";
149  //f.resize(idx['f'],f.cols());
150  //b.resize(idx['b'],b.cols());
151  //std::cout << "new f size: " << f.size() << "," << f.rows() << "x" << f.cols() << "\n";
152  //unsigned frows = f.rows()-1;
153  //for (unsigned r = idx['f']; r < f.rows(); ++r)
154  // f.block(r,0,frows-r,f.cols()) = f.block(r+1,0,frows-r,f.cols());
155  // f.conservativeResize(frows,f.cols());
156  f.conservativeResize(idx.at('f'), f.cols());
157  c.conservativeResize(idx.at('c'), c.cols());
158  b.conservativeResize(idx.at('b'), b.cols());
159  }
160 
161  void State::copy_to_host()
162  {
163  /* std::cout << "size of f before copy_from_device: " << f.size() */
164  /* << ", stack size: " << N*stack_size.at('f') << "\n"; */
165  /* std::cout << "size of b before copy_from_device: " << b.size() */
166  /* << ", stack size: " << N*stack_size.at('b') << "\n"; */
167 
168  copy_from_device(dev_f, f.data(), N*idx.at('f'),
169  dev_c, c.data(), N*idx.at('c'),
170  dev_b, b.data(), N*idx.at('b'));
171  //copy_from_device(dev_f, f.data(), dev_b, b.data(), N*stack_size.at('f'), N*stack_size.at('b'));
172 
173  trim();
174  limit();
175  }
176 
177  void State::copy_to_host(float* host_f, int increment)
178  {
179  copy_from_device((dev_f+increment), host_f, N);
180  }
181 
182  void State::copy_to_host(int* host_i, int increment)
183  {
184  copy_from_device((dev_c+increment), host_i, N);
185  }
186 
187  State::~State()
188  {
189  free_device(dev_f, dev_c, dev_b);
190  }
191 
192  void Trace::copy_to_trace(State& state, std::map<char, unsigned int> &arity)
193  {
194  int increment;
195 
196  for (int i = 0; i < arity.at('f'); i++)
197  {
198  ArrayXf tmp(state.N);
199 
200  increment = (state.idx.at('f') - (arity.at('f') - i))*state.N;
201 
202  /*cout << "State index " << state.idx.at('f')
203  << " i = "<<i
204  << " arity.at('f') = " << arity.at('f')
205  << " increment = " << increment<<endl;*/
206 
207  copy_from_device((state.dev_f+increment), tmp.data(), state.N);
208 
209  f.push_back(tmp.cast<float>());
210  }
211 
212  for (int i = 0; i < arity.at('c'); i++)
213  {
214  ArrayXi tmp(state.N);
215 
216  increment = (state.idx.at('c') - (arity.at('c') - i))*state.N;
217  copy_from_device((state.dev_c+increment), tmp.data(), state.N);
218 
219  c.push_back(tmp);
220  }
221 
222  for (int i = 0; i < arity.at('b'); i++)
223  {
224  ArrayXb tmp(state.N);
225 
226  increment = (state.idx.at('b') - (arity.at('b') - i))*state.N;
227  copy_from_device((state.dev_b+increment), tmp.data(), state.N);
228 
229  b.push_back(tmp);
230  }
231  }
232 
233  #endif
234  }
235 }
236 
unsigned int size()
returns the number of elements in the stack
Definition: state.h:66
type & at(int i)
Definition: state.h:72
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition: data.h:21
void free_device(float *dev_f, int *dev_c, bool *dev_b)
void copy_from_device(float *dev_f, float *host_f, size_t Sizef)
void dev_allocate(float *&f, size_t Sizef, int *&c, size_t Sizec, bool *&b, size_t Sizeb)
ArrayXb isnan(const ArrayXf &x)
returns true for elements of x that are NaN
Definition: utils.cc:226
std::string trim(std::string str, const std::string &chars)
Definition: utils.cc:43
static Rnd & r
Definition: rnd.h:135
main Feat namespace
Definition: data.cc:13
int i
Definition: params.cc:552
static float MAX_FLT
Definition: init.h:47
static float MIN_FLT
Definition: init.h:48
contains various types of State actually used by feat
Definition: state.h:102
Stack< ArrayXb > b
boolean node stack
Definition: state.h:104
Stack< string > fs
floating node string stack
Definition: state.h:107
Stack< string > zs
longitudinal node string stack
Definition: state.h:110
bool check_s(std::map< char, unsigned int > &arity)
Definition: state.cc:34
Stack< string > bs
boolean node string stack
Definition: state.h:108
Stack< std::pair< vector< ArrayXf >, vector< ArrayXf > > > z
longitudinal node stack
Definition: state.h:106
bool check(std::map< char, unsigned int > &arity)
checks if the value stacks (f, b, c and, when 'z' is in arity, z) hold enough elements to satisfy the arity of the node provided
Definition: state.cc:19
Stack< string > cs
categorical node string stack
Definition: state.h:109
Stack< ArrayXi > c
categorical stack
Definition: state.h:105
Stack< ArrayXf > f
floating node stack
Definition: state.h:103
vector< ArrayXf > f
Definition: state.h:233
vector< ArrayXb > b
Definition: state.h:235
void copy_to_trace(State &state, std::map< char, unsigned int > &arity)
Definition: state.cc:47
vector< ArrayXi > c
Definition: state.h:234