Feat C++ API
A feature engineering automation tool
n_divide.cc
Go to the documentation of this file.
1 /* FEAT
2 copyright 2017 William La Cava
3 license: GNU/GPL v3
4 */
5 
6 #include "n_divide.h"
7 
8 namespace FT{
9 
10  namespace Pop{
11  namespace Op{
12  NodeDivide::NodeDivide(vector<float> W0)
13  {
14  name = "/";
15  otype = 'f';
16  arity['f'] = 2;
17  arity['b'] = 0;
18  complexity = 2;
19 
20  if (W0.empty())
21  {
22  for (int i = 0; i < arity['f']; i++) {
23  W.push_back(r.rnd_dbl());
24  }
25  }
26  else
27  W = W0;
28  }
29 
30  #ifndef USE_CUDA
32  void NodeDivide::evaluate(const Data& data, State& state)
33  {
34  ArrayXf x1 = state.pop<float>();
35  ArrayXf x2 = state.pop<float>();
36  // safe division returns x1/x2 if x2 != 0, and MAX_FLT otherwise
37  ArrayXf ret = (this->W[0] * x1) / (this->W[1] * x2);
38  clean(ret);
39  state.push<float>(ret);
40  }
41  #else
42  void NodeDivide::evaluate(const Data& data, State& state)
43  {
44  GPU_Divide(state.dev_f, state.idx[otype], state.N, W[0], W[1]);
45  }
46  #endif
47 
50  {
51  state.push<float>("(" + to_string(W[0], 4) + "*" + state.popStr<float>() + "/"
52  + to_string(W[1], 4) + "*" + state.popStr<float>() + ")");
53  }
54 
55  // Might want to check derivative orderings for other 2 arg nodes
56  ArrayXf NodeDivide::getDerivative(Trace& state, int loc)
57  {
58  ArrayXf& x1 = state.get<float>()[state.size<float>()-1];
59  ArrayXf& x2 = state.get<float>()[state.size<float>()-2];
60 
61  switch (loc) {
62  case 3: // d/dW[1]
63  return limited(-this->W[0] * x1/(x2 * pow(this->W[1], 2)));
64  case 2: // d/dW[0]
65  return limited(x1/(this->W[1] * x2));
66  case 1: // d/dx2
67  {
68  /* std::cout << "x1: " << x1.transpose() << "\n"; */
69  /* ArrayXf num = -this->W[0] * x1; */
70  /* ArrayXf denom = limited(this->W[1] * pow(x2, 2)); */
71  /* ArrayXf val = num/denom; */
72  return limited((-this->W[0] * x1)/(this->W[1] * pow(x2, 2)));
73  }
74  case 0: // d/dx1
75  default:
76  return limited(this->W[0]/(this->W[1] * x2));
77  // return limited(this->W[1]/(this->W[0] * x2));
78  }
79  }
80 
81  NodeDivide* NodeDivide::clone_impl() const { return new NodeDivide(*this); }
82 
83  NodeDivide* NodeDivide::rnd_clone_impl() const { return new NodeDivide(); }
84  }
85  }
86 }
data holding X, y, and Z data
Definition: data.h:42
NodeDivide * clone_impl() const override
Definition: n_divide.cc:81
NodeDivide(vector< float > W0=vector< float >())
Definition: n_divide.cc:12
ArrayXf getDerivative(Trace &state, int loc)
Definition: n_divide.cc:56
void evaluate(const Data &data, State &state)
Evaluates the node and updates the state.
Definition: n_divide.cc:32
void eval_eqn(State &state)
Evaluates the node symbolically.
Definition: n_divide.cc:49
NodeDivide * rnd_clone_impl() const override
Definition: n_divide.cc:83
std::vector< float > W
Definition: n_Dx.h:16
string name
name of the node (for NodeDivide, the operator symbol "/")
Definition: node.h:56
std::map< char, unsigned int > arity
arity of the operator
Definition: node.h:59
ArrayXf limited(ArrayXf x)
limits node output to be between MIN_FLT and MAX_FLT
Definition: node.cc:37
char otype
output type
Definition: node.h:58
int complexity
complexity of node
Definition: node.h:60
float rnd_dbl(float min=0.0, float max=1.0)
Definition: rnd.cc:89
void GPU_Divide(float *x, size_t idx, size_t N, float W0, float W1)
static Rnd & r
Definition: rnd.h:135
std::string to_string(const T &value)
template function to convert objects to string for logging
Definition: utils.h:422
void clean(ArrayXf &x)
limits node output to be between MIN_FLT and MAX_FLT
Definition: utils.cc:18
main Feat namespace
Definition: data.cc:13
int i
Definition: params.cc:552
contains various types of State actually used by feat
Definition: state.h:102
string popStr()
Definition: state.h:143
Eigen::Array< T, Eigen::Dynamic, 1 > pop()
Definition: state.h:128
void push(Eigen::Array< T, Eigen::Dynamic, 1 > value)
Definition: state.h:123
used for tracing stack outputs for backprop algorithm.
Definition: state.h:232
unsigned int size()
Definition: state.h:242
vector< Eigen::Array< T, Eigen::Dynamic, 1 > > & get()
Definition: state.h:237