Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
weight_optimizer.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Code below heavily inspired by heal-research/operon, Copyright 2019-2022 Heal Research
6*/
7#ifndef WEIGHT_OPTIMIZER_H
8#define WEIGHT_OPTIMIZER_H
9
10#include <ceres/dynamic_autodiff_cost_function.h>
11#include <ceres/dynamic_numeric_diff_cost_function.h>
12#include <ceres/loss_function.h>
13#include <ceres/solver.h>
14#include <ceres/tiny_solver.h>
15
16#include "tiny_cost_function.h"
17
18namespace Brush
19{
20
29
30template<typename PT>
32 typedef float Scalar;
33 ResidualEvaluator(PT& program, Dataset const& dataset)
34 : program_(program)
35 , dataset_(dataset)
36 , numParameters_(program.get_weights().size())
37 , y_true_(dataset.y)
38 {}
39
40 template<typename T>
41 auto operator()(Eigen::DenseBase<T>& parameters, Eigen::DenseBase<T>& residuals) const noexcept -> void
42 {
43 return (*this)(parameters.data(), residuals.data());
44 }
45
46 template <typename T>
47 auto operator()(T const* parameters, T* residuals) const -> bool
48 {
49 using ArrayType = Array<T, Dynamic, 1>; // ColMajor?
50 const T ** new_weights = &parameters;
51
52 ArrayType y_pred = GetProgram().template predict_with_weights<ArrayType>(
53 GetDataset(),
54 new_weights
55 );
56
57 auto residualMap = ArrayType::Map(residuals, GetDataset().get_n_samples());
58
59 // how we calculate the residuals
60 if (GetDataset().classification) // classification
61 {
62 // tolerance to avoid numeric errors.
63
64 // Using an eps with 7 significant digits to avoid weird behavior.
65 // From some experiments, I noticed that using very small values (previous
66 // was pow(10,-10) caused the optimizer to evaluate to nans), probably because
67 // we were calculating log(0) due to cpp ignoring more than 7 significant digits
68 // for floats. using -6. As a reference, tensorflow uses -7.
69 // We need to remember that we dont want to have eps too big, otherwise
70 // we are biasing the predictions away from 0 and 1
71 float eps = 1e-6f;
72
73 auto y = GetTarget();
74
75 // cout << T(1.0) << ", " << T(eps) << endl;
76
77 // clamp values and avoid log(0)
78 y_pred = y_pred.min(T(1.0) - T(eps)).max(T(eps));
79
80 // log loss
81 // residualMap = -(y*log(y_pred.array()) + (T(1.0)-y)*log(T(1.0)-y_pred.array()));
82 residualMap = -(y*log(y_pred) + (T(1.0)-y)*log(T(1.0)-y_pred));
83 }
84 else { // This is MSE, default behavior
85 residualMap = (y_pred - GetTarget());
86 }
87
88 return true;
89 }
90
91 [[nodiscard]] auto NumParameters() const -> size_t { return numParameters_; }
92 [[nodiscard]] auto NumResiduals() const -> size_t { return y_true_.get().size(); }
93 inline auto GetProgram() const { return program_.get();};
94 inline auto GetDataset() const { return dataset_.get();};
95 inline auto GetTarget() const { return y_true_.get();};
96
97private:
98 std::reference_wrapper<PT> program_;
99 std::reference_wrapper<Dataset const> dataset_;
100 std::reference_wrapper<ArrayXf const> y_true_;
101 size_t numParameters_; // cache the number of parameters in the tree
102};
103
104// TODO: see this struct and try to understand how to make non-templated classes
106{
/**
 * @brief Refine a program's weights with a small non-linear least-squares solve.
 *
 * Builds a ResidualEvaluator over (program, dataset), wraps it in a cost
 * function, and runs ceres::TinySolver for at most 10 iterations starting
 * from the program's current weights. The optimized weights are written back
 * only if the solver reduced the cost.
 *
 * NOTE(review): the CFType alias used below is not visible in this listing.
 *
 * @tparam PT program type
 * @param program program to update in place (weights only)
 * @param dataset data the residuals are computed on
 */
111 template<typename PT>
112 void update(PT& program, const Dataset& dataset)
113 {
// Nothing to optimize.
114 if (program.get_n_weights() == 0)
115 return;
116
117 // fmt::print("number of weights: {}\n",program.get_n_weights());
// NOTE(review): init_weights is only referenced by the commented-out debug
// print below; otherwise it is an unused copy of the weights.
118 auto init_weights = program.get_weights();
119
121 ResidualEvaluator<PT> evaluator(program, dataset);
122 CFType cost_function(evaluator);
123 ceres::TinySolver<CFType> solver;
// Keep the local search cheap.
124 solver.options.max_num_iterations = 10;
125
// Start from the program's current weights; the solver updates them in place.
126 typename decltype(solver)::Parameters parameters = program.get_weights();
127 solver.Solve(cost_function, &parameters);
128
129 // fmt::print("Summary:\nInitial cost: {}\nFinal Cost: {}\nIterations: {}\n",
130 // solver.summary.initial_cost,
131 // solver.summary.final_cost,
132 // solver.summary.iterations
133 // );
134 // fmt::print("Initial weights: {}\nFinal weights: {}\n",
135 // init_weights,
136 // parameters
137 // );
// Only accept the new weights if they strictly improved the cost.
138 if (solver.summary.final_cost < solver.summary.initial_cost)
139 {
140 program.set_weights(parameters);
141 }
142
143 }
144};
145
146} // namespace Brush
147#endif
holds variable type data.
Definition data.h:51
< nsga2 selection operator for getting the front
Definition bandit.cpp:4
ProgramType PT
Definition program.h:40
auto NumResiduals() const -> size_t
std::reference_wrapper< Dataset const > dataset_
ResidualEvaluator(PT &program, Dataset const &dataset)
auto operator()(Eigen::DenseBase< T > &parameters, Eigen::DenseBase< T > &residuals) const noexcept -> void
auto NumParameters() const -> size_t
std::reference_wrapper< PT > program_
auto operator()(T const *parameters, T *residuals) const -> bool
std::reference_wrapper< ArrayXf const > y_true_
void update(PT &program, const Dataset &dataset)
Update program weights using non-linear least squares.