22 #ifndef SUPERVISEDDESCENT_HPP_
23 #define SUPERVISEDDESCENT_HPP_
25 #include "superviseddescent/utils/ThreadPool.h"
27 #include "cereal/cereal.hpp"
28 #include "cereal/types/vector.hpp"
30 #include "opencv2/core/core.hpp"
44 namespace superviseddescent {
52 inline void no_eval(
const cv::Mat& current_predictions)
72 return cv::Mat::ones(1, params.cols, params.type());
// Core supervised-descent optimiser: holds a cascade of regressors (one per
// descent level, see `regressors` below) and a NormalisationStrategy that
// scales/unscales each update step. Defaults to NoNormalisation.
85 template<
class RegressorType,
class NormalisationStrategy = NoNormalisation>
140 template<
class ProjectionFunction>
141 void train(cv::Mat parameters, cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
143 return train(parameters, initialisations, templates, projection,
no_eval);
// Trains the cascade level by level. Per level: projection features are
// computed in parallel via a thread pool, the residual against the optional
// templates forms the observed values, the level's regressor is learned on
// normalised update steps, and all estimates are advanced before the
// per-epoch callback is invoked.
// NOTE(review): this listing is a garbled extraction — the function braces
// and the declarations of `features`, `observed_values`, `b`, `x_k` (and the
// `current_x = x_k;` re-assignment the loop implies) are missing lines of
// the listing, not defects of the underlying code.
165 template<
class ProjectionFunction,
class OnTrainingEpochCallback>
166 void train(cv::Mat parameters, cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection, OnTrainingEpochCallback on_training_epoch_callback)
// Start the descent from the caller-supplied initialisations.
169 Mat current_x = initialisations;
170 for (
size_t regressor_level = 0; regressor_level < regressors.size(); ++regressor_level) {
// Use all hardware threads; fall back to 4 when the count is unknown (0).
173 auto concurent_threads_supported = std::thread::hardware_concurrency();
174 if (concurent_threads_supported == 0) {
175 concurent_threads_supported = 4;
// Project every sample in parallel: projection(x_row, level, sample_index).
177 utils::ThreadPool thread_pool(concurent_threads_supported);
178 std::vector<std::future<typename std::result_of<ProjectionFunction(Mat, size_t, int)>::type>> results;
179 results.reserve(current_x.rows);
180 for (
int sample_index = 0; sample_index < current_x.rows; ++sample_index) {
181 results.emplace_back(
182 thread_pool.enqueue(projection, current_x.row(sample_index), regressor_level, sample_index)
// Collect the per-sample feature rows, in sample order.
187 for (
auto&& result : results) {
188 features.push_back(result.get());
// Observed values: raw features, or the residual against the templates.
192 if (templates.empty()) {
193 observed_values = features;
196 observed_values = features - templates;
// Regression targets b: normalised (current estimate - ground truth) steps.
201 for (
int sample_index = 0; sample_index < current_x.rows; ++sample_index) {
202 cv::Mat update_step = current_x.row(sample_index) - parameters.row(sample_index);
203 update_step = update_step.mul(normalisation_strategy(current_x.row(sample_index)));
204 b.push_back(update_step);
// Learn this level's mapping from observed values to update steps.
207 regressors[regressor_level].learn(observed_values, b);
// Apply the freshly learned regressor to advance every estimate, undoing
// the normalisation (element-wise multiply by its reciprocal).
210 for (
int sample_index = 0; sample_index < current_x.rows; ++sample_index) {
212 cv::Mat update_step = regressors[regressor_level].predict(observed_values.row(sample_index));
213 update_step = update_step.mul(1 / normalisation_strategy(current_x.row(sample_index)));
214 x_k.push_back(Mat(current_x.row(sample_index) - update_step));
// Let the caller inspect the updated estimates after this training epoch.
217 on_training_epoch_callback(current_x);
237 template<
class ProjectionFunction>
238 cv::Mat
test(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
240 return test(initialisations, templates, projection,
no_eval);
// Test-time counterpart of train(): same per-level pipeline (parallel
// projection, residual against optional templates, predict + un-normalise,
// advance estimates) but without the learn step — the regressors are used
// as already trained. The callback fires after each regressor level.
// NOTE(review): garbled extraction — function braces and the declarations
// of `features`, `observed_values`, `x_k` (plus the `current_x = x_k;`
// re-assignment and final return) are missing listing lines.
262 template<
class ProjectionFunction,
class OnRegressorIterationCallback>
263 cv::Mat
test(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection, OnRegressorIterationCallback on_regressor_iteration_callback)
266 Mat current_x = initialisations;
267 for (
size_t regressor_level = 0; regressor_level < regressors.size(); ++regressor_level) {
// Use all hardware threads; fall back to 4 when the count is unknown (0).
269 auto concurent_threads_supported = std::thread::hardware_concurrency();
270 if (concurent_threads_supported == 0) {
271 concurent_threads_supported = 4;
// Project every sample in parallel, same as in train().
273 utils::ThreadPool thread_pool(concurent_threads_supported);
274 std::vector<std::future<typename std::result_of<ProjectionFunction(Mat, size_t, int)>::type>> results;
275 results.reserve(current_x.rows);
276 for (
int sample_index = 0; sample_index < current_x.rows; ++sample_index) {
277 results.emplace_back(
278 thread_pool.enqueue(projection, current_x.row(sample_index), regressor_level, sample_index)
// Collect the per-sample feature rows, in sample order.
283 for (
auto&& result : results) {
284 features.push_back(result.get());
// Observed values: raw features, or the residual against the templates.
288 if (templates.empty()) {
289 observed_values = features;
292 observed_values = features - templates;
// Predict the update step per sample and undo the normalisation.
296 for (
int sample_index = 0; sample_index < current_x.rows; ++sample_index) {
297 cv::Mat update_step = regressors[regressor_level].predict(observed_values.row(sample_index));
298 update_step = update_step.mul(1 / normalisation_strategy(current_x.row(sample_index)));
299 x_k.push_back(Mat(current_x.row(sample_index) - update_step));
// Let the caller inspect the estimates after this regressor level.
303 on_regressor_iteration_callback(current_x);
// Optimises a single sample (whole-matrix variant): for each cascade level,
// projects the current estimate (here projection takes only (x, level) —
// no sample index, no thread pool), subtracts the optional templates,
// predicts an update step, un-normalises it and advances the estimate.
// NOTE(review): garbled extraction — braces, the `observed_values`
// declaration, the `current_x = x_k;` re-assignment and the final return
// are missing listing lines.
323 template<
class ProjectionFunction>
324 cv::Mat
predict(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
327 Mat current_x = initialisations;
328 for (
size_t r = 0; r < regressors.size(); ++r) {
// Observed values: raw projection output, or its residual to the templates.
331 if (templates.empty()) {
332 observed_values = projection(current_x, r);
335 observed_values = projection(current_x, r) - templates;
// Predict the update step and undo the normalisation before stepping.
337 cv::Mat update_step = regressors[r].predict(observed_values);
338 update_step = update_step.mul(1 / normalisation_strategy(current_x));
339 Mat x_k = current_x - update_step;
347 std::vector<RegressorType> regressors; // One regressor per cascade level / descent step.
348 NormalisationStrategy normalisation_strategy; // Scales update steps in train/test/predict.
350 friend class cereal::access; // Lets cereal reach the non-public serialize() member.
356 template<
class Archive>
357 void serialize(Archive& ar)
359 ar(regressors, normalisation_strategy);
Definition: superviseddescent.hpp:60
cv::Mat predict(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
Definition: superviseddescent.hpp:324
void train(cv::Mat parameters, cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
Definition: superviseddescent.hpp:141
Definition: superviseddescent.hpp:86
cv::Mat test(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection)
Definition: superviseddescent.hpp:238
cv::Mat operator()(cv::Mat params)
Definition: superviseddescent.hpp:71
cv::Mat test(cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection, OnRegressorIterationCallback on_regressor_iteration_callback)
Definition: superviseddescent.hpp:263
void no_eval(const cv::Mat &current_predictions)
Definition: superviseddescent.hpp:52
void train(cv::Mat parameters, cv::Mat initialisations, cv::Mat templates, ProjectionFunction projection, OnTrainingEpochCallback on_training_epoch_callback)
Definition: superviseddescent.hpp:166
SupervisedDescentOptimiser(std::vector< RegressorType > regressors, NormalisationStrategy normalisation=NoNormalisation())
Definition: superviseddescent.hpp:103
SupervisedDescentOptimiser()=default