7 #include "../util/rnd.h"
8 #include "../util/logger.h"
20 vector<bool> protect):
21 X(X), y(y), Z(Z), classification(c) , protect(protect)
31 this->
cases.resize(0);
36 logger.
log(
"storing protected attributes...",2);
50 +
to_string(pl.second.size()) +
" values; ",3);
56 logger.
log(
"storing group intersections...",3);
60 for (
auto level : pl.second)
62 ArrayXb x = (
X.row(group).array() == level);
63 this->
cases.push_back(x);
82 batch_size = std::min(batch_size,
int(
y.size()));
84 WARN(
"batch_size is set to "
85 +
to_string(batch_size) +
" when getting batch");
87 vector<size_t> idx(
y.size());
88 std::iota(idx.begin(), idx.end(), 0);
90 db.
X.resize(
X.rows(),batch_size);
91 db.
y.resize(batch_size);
92 for (
const auto& val:
Z )
94 db.
Z[val.first].first.resize(batch_size);
95 db.
Z[val.first].second.resize(batch_size);
97 for (
unsigned i = 0;
i<batch_size; ++
i)
100 db.
X.col(
i) =
X.col(idx.at(
i));
101 db.
y(
i) =
y(idx.at(
i));
103 for (
const auto& val:
Z )
105 db.
Z.at(val.first).first.at(
i) = \
106 Z.at(val.first).first.at(idx.at(
i));
107 db.
Z.at(val.first).second.at(
i) = \
108 Z.at(val.first).second.at(idx.at(
i));
122 LongData& Z,
bool c, vector<bool> protect)
124 this->
init(X, y, Z, c, protect);
128 LongData& Z,
bool c, vector<bool> protect)
130 o =
new Data(X, y, Z, c, protect);
167 bool c, vector<bool> protect)
169 o =
new Data(X, y, Z, c, protect);
197 bool c, vector<bool> protect)
216 vector<bool> protect)
230 Eigen::PermutationMatrix<Dynamic,Dynamic> perm(
o->
X.cols());
233 perm.indices().data()+perm.indices().size());
241 o->
y = (
o->
y.transpose() * perm).transpose() ;
245 std::vector<int> zidx(
o->
y.size());
248 for (
unsigned i = 0;
i < perm.indices().size(); ++
i)
249 zidx.at(perm.indices()(
i)) =
i;
254 for(
auto &val :
o->
Z)
274 logger.
log(
"Stratify split called with initial data size as "
277 std::map<float, vector<int>> label_indices;
280 for(
int x = 0; x <
o->
y.size(); x++)
281 label_indices[
o->
y(x)].push_back(x);
290 std::map<float, vector<int>>::iterator it = label_indices.begin();
292 vector<int> t_indices;
293 vector<int> v_indices;
298 for(; it != label_indices.end(); it++)
300 t_size = ceil(it->second.size()*split);
302 for(x = 0; x < t_size; x++)
303 t_indices.push_back(it->second.at(x));
305 for(; x < it->second.size(); x++)
306 v_indices.push_back(it->second.at(x));
313 +
to_string((it->second.size() - t_size)), 3,
"\n");
317 X_t.resize(
o->
X.rows(), t_indices.size());
318 X_v.resize(
o->
X.rows(), v_indices.size());
319 y_t.resize(t_indices.size());
320 y_v.resize(v_indices.size());
322 sort(t_indices.begin(), t_indices.end());
324 for(
int x = 0; x < t_indices.size(); x++)
326 t->
X.col(x) =
o->
X.col(t_indices.at(x));
327 t->
y(x) =
o->
y(t_indices.at(x));
331 for(
auto const &val :
o->
Z)
333 t->
Z[val.first].first.push_back(
334 val.second.first[t_indices.at(x)]);
335 t->
Z[val.first].second.push_back(
336 val.second.second[t_indices.at(x)]);
341 sort(v_indices.begin(), v_indices.end());
343 for(
int x = 0; x < v_indices.size(); x++)
345 v->
X.col(x) =
o->
X.col(v_indices.at(x));
346 v->
y(x) =
o->
y(v_indices.at(x));
350 for(
auto const &val :
o->
Z)
352 v->
Z[val.first].first.push_back(
353 val.second.first[t_indices.at(x)]);
354 v->
Z[val.first].second.push_back(
355 val.second.second[t_indices.at(x)]);
377 int train_size = min(
int(
o->
X.cols()*split),
379 int val_size = max(
int(
o->
X.cols()*(1-split)), 1);
381 X_t.resize(
o->
X.rows(),train_size);
382 X_v.resize(
o->
X.rows(),val_size);
383 y_t.resize(train_size);
384 y_v.resize(val_size);
387 t->
X = MatrixXf::Map(
o->
X.data(),
t->
X.rows(),
389 v->
X = MatrixXf::Map(
o->
X.data()+
t->
X.rows()*
t->
X.cols(),
390 v->
X.rows(),
v->
X.cols());
392 t->
y = VectorXf::Map(
o->
y.data(),
t->
y.size());
393 v->
y = VectorXf::Map(
o->
y.data()+
t->
y.size(),
v->
y.size());
406 for (
const auto val: Z )
408 size = Z.at(val.first).first.size();
412 int testSize = int(size*split);
413 int validateSize = int(size*(1-split));
415 for (
const auto &val: Z )
417 vector<ArrayXf> _Z_t_v, _Z_t_t, _Z_v_v, _Z_v_t;
418 _Z_t_v.assign(Z[val.first].first.begin(),
419 Z[val.first].first.begin()+testSize);
420 _Z_t_t.assign(Z[val.first].second.begin(),
421 Z[val.first].second.begin()+testSize);
422 _Z_v_v.assign(Z[val.first].first.begin()+testSize,
423 Z[val.first].first.begin()+testSize+validateSize);
424 _Z_v_t.assign(Z[val.first].second.begin()+testSize,
425 Z[val.first].second.begin()+testSize+validateSize);
427 Z_t[val.first] = make_pair(_Z_t_v, _Z_t_t);
428 Z_v[val.first] = make_pair(_Z_v_v, _Z_v_t);
433 vector<int>
const &order )
435 for (
int s = 1, d; s < order.size(); ++ s ) {
436 for ( d = order.at(s); d < s; d = order.at(d) ) ;
438 while ( d = order.at(d), d != s )
439 swap(
v.at(s),
v.at(d));
void init(MatrixXf &X, VectorXf &y, LongData &Z, bool c=false, vector< bool > protect=vector< bool >())
void split_longitudinal(LongData &Z, LongData &Z_t, LongData &Z_v, float split)
void split_stratified(float split)
split classification data as stratas
void shuffle_data()
shuffles original data
void setOriginalData(MatrixXf &X, VectorXf &y, LongData &Z, bool c=false, vector< bool > protect=vector< bool >())
void setTrainingData(MatrixXf &X_t, VectorXf &y_t, LongData &Z_t, bool c=false, vector< bool > protect=vector< bool >())
void train_test_split(bool shuffle, float split)
splits data into training and validation folds.
void reorder_longitudinal(vector< ArrayXf > &vec1, const vector< int > &order)
reordering utility for shuffling longitudinal data.
void setValidationData(MatrixXf &X_v, VectorXf &y_v, LongData &Z_v, bool c=false, vector< bool > protect=vector< bool >())
data holding X, y, and Z data
vector< int > protected_groups
Data(MatrixXf &X, VectorXf &y, LongData &Z, bool c=false, vector< bool > protect=vector< bool >())
void get_batch(Data &db, int batch_size) const
select random subset of data for training weights.
map< int, vector< float > > protect_levels
void set_validation(bool v=true)
void set_protected_groups()
string log(string m, int v, string sep="\n") const
print message with verbosity control.
void shuffle(RandomAccessIterator first, RandomAccessIterator last)
Eigen::Array< bool, Eigen::Dynamic, 1 > ArrayXb
std::map< string, std::pair< vector< ArrayXf >, vector< ArrayXf > > > LongData
vector< T > unique(vector< T > w)
returns unique elements in vector
std::string to_string(const T &value)
template function to convert objects to string for logging