Brush C++ API
A flexible interpretable machine learning framework
Loading...
Searching...
No Matches
weight_optimizer.h
Go to the documentation of this file.
1/* Brush
2copyright 2020 William La Cava
3license: GNU/GPL v3
4
5Code below heavily inspired by heal-research/operon, Copyright 2019-2022 Heal Research
6*/
7#ifndef WEIGHT_OPTIMIZER_H
8#define WEIGHT_OPTIMIZER_H
9
10#include <ceres/dynamic_autodiff_cost_function.h>
11#include <ceres/dynamic_numeric_diff_cost_function.h>
12#include <ceres/loss_function.h>
13#include <ceres/solver.h>
14#include <ceres/tiny_solver.h>
15
16#include "tiny_cost_function.h"
17
18namespace Brush
19{
20
29
30template<typename PT>
32 typedef float Scalar;
33 ResidualEvaluator(PT& program, Dataset const& dataset)
34 : program_(program)
36 , numParameters_(program.get_weights().size())
37 , y_true_(dataset.y)
38 {}
39
40 template<typename T>
41 auto operator()(Eigen::DenseBase<T>& parameters, Eigen::DenseBase<T>& residuals) const noexcept -> void
42 {
43 return (*this)(parameters.data(), residuals.data());
44 }
45
46 template <typename T>
47 auto operator()(T const* parameters, T* residuals) const -> bool
48 {
49 using ArrayType = Array<T, Dynamic, 1>; // ColMajor?
50 const T ** new_weights = &parameters;
51
53 GetDataset(),
55 );
56
57 auto residualMap = ArrayType::Map(residuals, GetDataset().get_n_samples());
58
60
61 return true;
62 }
63
64 [[nodiscard]] auto NumParameters() const -> size_t { return numParameters_; }
65 [[nodiscard]] auto NumResiduals() const -> size_t { return y_true_.get().size(); }
66 inline auto GetProgram() const { return program_.get();};
67 inline auto GetDataset() const { return dataset_.get();};
68 inline auto GetTarget() const { return y_true_.get();};
69
70private:
71 std::reference_wrapper<PT> program_;
72 std::reference_wrapper<Dataset const> dataset_;
73 std::reference_wrapper<ArrayXf const> y_true_;
74 size_t numParameters_; // cache the number of parameters in the tree
75};
76
77// TODO: revisit this struct and work out how to implement it with non-templated classes.
79{
84 template<typename PT>
85 void update(PT& program, const Dataset& dataset)
86 {
87 if (program.get_n_weights() == 0)
88 return;
89
90 // fmt::print("number of weights: {}\n",program.get_n_weights());
91 auto init_weights = program.get_weights();
92
94 ResidualEvaluator<PT> evaluator(program, dataset);
95 CFType cost_function(evaluator);
96 ceres::TinySolver<CFType> solver;
97 solver.options.max_num_iterations = 10;
98
99 typename decltype(solver)::Parameters parameters = program.get_weights();
100 solver.Solve(cost_function, &parameters);
101
102 // fmt::print("Summary:\nInitial cost: {}\nFinal Cost: {}\nIterations: {}\n",
103 // solver.summary.initial_cost,
104 // solver.summary.final_cost,
105 // solver.summary.iterations
106 // );
107 // fmt::print("Initial weights: {}\nFinal weights: {}\n",
108 // init_weights,
109 // parameters
110 // );
111 if (solver.summary.final_cost < solver.summary.initial_cost)
112 {
113 program.set_weights(parameters);
114 }
115
116 }
117};
118
119} // namespace Brush
120#endif
void bind_engine(py::module &m, string name)
< nsga2 selection operator for getting the front
Definition data.cpp:12
ProgramType
Definition types.h:70
auto NumResiduals() const -> size_t
std::reference_wrapper< Dataset const > dataset_
ResidualEvaluator(PT &program, Dataset const &dataset)
auto operator()(Eigen::DenseBase< T > &parameters, Eigen::DenseBase< T > &residuals) const noexcept -> void
auto NumParameters() const -> size_t
std::reference_wrapper< PT > program_
auto operator()(T const *parameters, T *residuals) const -> bool
std::reference_wrapper< ArrayXf const > y_true_
void update(PT &program, const Dataset &dataset)
Update program weights using non-linear least squares.