Feat C++ API
A feature engineering automation tool
data.h
Go to the documentation of this file.
1 /* FEAT
2 copyright 2017 William La Cava
3 license: GNU/GPL v3
4 */
5 
6 #ifndef DATA_H
7 #define DATA_H
8 
9 #include <string>
10 #include <Eigen/Dense>
11 #include <vector>
12 #include <map>
13 
// Names pulled in from std and Eigen for use throughout this header.
14 using std::vector;
15 using Eigen::MatrixXf;
16 using Eigen::VectorXf;
17 using Eigen::ArrayXf;
18 using Eigen::VectorXi;
19 using Eigen::Dynamic;
20 using Eigen::Map;
// ArrayXb: dynamically sized, single-column Eigen array of bool.
21 typedef Eigen::Array<bool,Eigen::Dynamic,1> ArrayXb;
// NOTE(review): a using-directive in a header leaks into every file that
// includes it. Declarations below use unqualified `string` and `map`, so it
// cannot simply be removed without qualifying those names first.
22 using namespace std;
// LongData: maps a string key to a pair of vector<ArrayXf>.
// Presumably the key is a longitudinal feature name and the pair holds
// per-sample values and their time stamps -- TODO confirm against callers.
23 typedef std::map<string, std::pair<vector<ArrayXf>, vector<ArrayXf>>> LongData;
24 // internal includes
25 //#include "params.h"
26 #include "../util/utils.h"
27 //#include "node/node.h"
28 //external includes
29 
30 namespace FT
31 {
36  namespace Dat{
// Data: bundles a feature matrix X, a label vector y and longitudinal data Z.
// X, y and Z are reference members, so a Data object does not store its own
// copies; the referenced objects must outlive it.
// NOTE(review): this listing skips data.h lines 48 and 63. The member index
// at the end of this page lists `bool classification` (data.h:48) and
// `int group_intersections` (data.h:63) -- confirm against the real header.
41  class Data
42  {
43  public:
44 
45  MatrixXf& X; // n_features x n_samples matrix of features
46  VectorXf& y; // n_samples labels
47  LongData& Z; // longitudinal features
49  bool validation; // presumably true when this object is a validation split -- TODO confirm
50  vector<bool> protect; // protected subgroups of features
51 
// Takes X, y and Z by non-const reference. `c` is presumably the
// classification flag -- TODO confirm. `protect` defaults to an empty vector.
52  Data(MatrixXf& X, VectorXf& y, LongData& Z, bool c = false,
53  vector<bool> protect = vector<bool>());
54 
// v defaults to true; presumably assigns the `validation` member -- TODO confirm.
55  void set_validation(bool v=true);
// Presumably fills protect_levels / protected_groups / cases from X and
// protect -- TODO confirm in data.cc.
56  void set_protected_groups();
57 
// Const method; the destination `db` is taken by non-const reference.
// Presumably writes a batch of batch_size samples into db -- TODO confirm.
59  void get_batch(Data &db, int batch_size) const;
60  // protect_levels stores the levels of protected factors in X.
// Keyed by int; presumably the row index of the feature in X -- TODO confirm.
61  map<int,vector<float>> protect_levels;
62  vector<int> protected_groups; // presumably indices of the protected features -- TODO confirm
64  vector<ArrayXb> cases; // used to pre-process cases if there
65  // aren't that many group intersections
66  };
67 
68  /*!
69  * @class DataRef
70  * @brief Holds training and validation splits of data,
71  * with pointers to each.
72  *
72  * The split matrices and vectors (X_t, X_v, y_t, y_v) are stored by value
72  * in this object; o, t and v are raw pointers to Data objects.
72  * */
73  class DataRef
74  {
75  private:
// Presumably record whether o / t / v were allocated by this object and
// therefore need to be released by ~DataRef() -- TODO confirm in data.cc.
76  bool oCreated;
77  bool tCreated;
78  bool vCreated;
79  // training and validation data
80  MatrixXf X_t;
81  MatrixXf X_v;
82  VectorXf y_t;
83  VectorXf y_v;
// NOTE(review): this listing skips data.h lines 84, 85 and 87. The member
// index at the end of this page lists `LongData Z_t` (data.h:84),
// `LongData Z_v` (data.h:85) and `bool classification` (data.h:87).
86 
88  /* vector<bool> protect; // indicator of protected subgroups */
89 
90  public:
// NOTE(review): raw pointers plus a user-declared destructor, with no copy
// constructor or copy assignment declared here. If the destructor deletes
// these, copying a DataRef would lead to a double delete -- confirm that
// DataRef is never copied, or that ownership is handled in data.cc.
91  Data *o = NULL; ///< pointer to original data
92  Data *v = NULL; ///< pointer to validation data
93  Data *t = NULL; ///< pointer to training data
94 
95  DataRef();
96 
97  ~DataRef();
98 
99 
// Same parameter list as init() and the five-argument setOriginalData().
// `c` is presumably the classification flag -- TODO confirm.
100  DataRef(MatrixXf& X, VectorXf& y, LongData& Z, bool c=false,
101  vector<bool> protect = vector<bool>());
102 
103  void init(MatrixXf& X, VectorXf& y, LongData& Z,
104  bool c=false, vector<bool> protect = vector<bool>());
105 
// Each set*Data() has two overloads: one taking the raw X / y / Z objects by
// reference and one taking an existing Data pointer.
106  void setOriginalData(MatrixXf& X, VectorXf& y,
107  LongData& Z, bool c=false,
108  vector<bool> protect = vector<bool>());
109 
110  void setOriginalData(Data *d);
111 
112  void setTrainingData(MatrixXf& X_t, VectorXf& y_t,
113  LongData& Z_t,
114  bool c = false,
115  vector<bool> protect = vector<bool>());
116 
// toDelete defaults to false; presumably requests that the current training
// Data be deleted before d is stored -- TODO confirm.
117  void setTrainingData(Data *d, bool toDelete = false);
118 
119  void setValidationData(MatrixXf& X_v, VectorXf& y_v,
120  LongData& Z_v,
121  bool c = false,
122  vector<bool> protect = vector<bool>());
123 
124  void setValidationData(Data *d);
125 
// Presumably permutes the samples of the original data -- TODO confirm.
127  void shuffle_data();
128 
// `split` is presumably the fraction of samples assigned to the training
// set, here and in the two functions below -- TODO confirm.
130  void split_stratified(float split);
131 
133  void train_test_split(bool shuffle, float split);
134 
// All three LongData arguments are non-const references; presumably Z is the
// input and Z_t / Z_v receive the training / validation parts -- TODO confirm.
135  void split_longitudinal(
136  LongData&Z,
137  LongData&Z_t,
138  LongData&Z_v,
139  float split);
140 
// vec1 is taken by non-const reference and `order` by const reference;
// presumably vec1 is rearranged in place according to order -- TODO confirm.
142  void reorder_longitudinal(vector<ArrayXf> &vec1,
143  const vector<int>& order);
144 
145  };
146  }
147 }
148 
149 #endif
VectorXf y_t
Definition: data.h:82
bool classification
Definition: data.h:87
MatrixXf X_v
Definition: data.h:81
MatrixXf X_t
Definition: data.h:80
VectorXf y_v
Definition: data.h:83
LongData Z_t
Definition: data.h:84
bool oCreated
Definition: data.h:76
bool tCreated
Definition: data.h:77
bool vCreated
Definition: data.h:78
LongData Z_v
Definition: data.h:85
Container class holding the X, y, and Z data
Definition: data.h:42
vector< int > protected_groups
Definition: data.h:62
vector< ArrayXb > cases
Definition: data.h:64
int group_intersections
Definition: data.h:63
VectorXf & y
Definition: data.h:46
bool classification
Definition: data.h:48
vector< bool > protect
Definition: data.h:50
map< int, vector< float > > protect_levels
Definition: data.h:61
LongData & Z
Definition: data.h:47
bool validation
Definition: data.h:49
MatrixXf & X
Definition: data.h:45
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
Definition: data.h:21
std::map< string, std::pair< vector< ArrayXf >, vector< ArrayXf > > > LongData
Definition: data.h:23
main Feat namespace
Definition: data.cc:13