246 lines
6.9 KiB
C++
246 lines
6.9 KiB
C++
// Copyright (C) 2017 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
|
|
#define DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
|
|
|
|
#include "global_function_search_abstract.h"
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
#include <mutex>
|
|
#include "../rand.h"
|
|
#include "upper_bound_function.h"
|
|
#include "../test_for_odr_violations.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
struct function_spec
|
|
{
|
|
function_spec(
|
|
matrix<double,0,1> bound1,
|
|
matrix<double,0,1> bound2
|
|
);
|
|
|
|
function_spec(
|
|
matrix<double,0,1> bound1,
|
|
matrix<double,0,1> bound2,
|
|
std::vector<bool> is_integer
|
|
);
|
|
|
|
matrix<double,0,1> lower;
|
|
matrix<double,0,1> upper;
|
|
std::vector<bool> is_integer_variable;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace gopt_impl
|
|
{
|
|
struct outstanding_function_eval_request
|
|
{
|
|
size_t request_id = 0; // unique id for this eval request
|
|
matrix<double,0,1> x; // function x to evaluate
|
|
|
|
// trust region specific stuff
|
|
bool was_trust_region_generated_request = false;
|
|
double predicted_improvement = std::numeric_limits<double>::quiet_NaN();
|
|
double anchor_objective_value = std::numeric_limits<double>::quiet_NaN(); // objective value at center of TR step
|
|
|
|
bool operator==(const outstanding_function_eval_request& item) const { return request_id == item.request_id; }
|
|
};
|
|
|
|
struct funct_info
|
|
{
|
|
funct_info() = delete;
|
|
funct_info(const funct_info&) = delete;
|
|
funct_info& operator=(const funct_info&) = delete;
|
|
|
|
funct_info(
|
|
const function_spec& spec,
|
|
size_t function_idx,
|
|
const std::shared_ptr<std::mutex>& m
|
|
) :
|
|
spec(spec), function_idx(function_idx), m(m)
|
|
{
|
|
best_x = zeros_matrix(spec.lower);
|
|
}
|
|
|
|
upper_bound_function build_upper_bound_with_all_function_evals (
|
|
) const;
|
|
|
|
static double find_nn (
|
|
const std::vector<function_evaluation>& evals,
|
|
const matrix<double,0,1>& x
|
|
);
|
|
|
|
|
|
function_spec spec;
|
|
size_t function_idx = 0;
|
|
std::shared_ptr<std::mutex> m;
|
|
upper_bound_function ub;
|
|
std::vector<outstanding_function_eval_request> outstanding_evals;
|
|
matrix<double,0,1> best_x;
|
|
double best_objective_value = -std::numeric_limits<double>::infinity();
|
|
double radius = 0;
|
|
};
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class function_evaluation_request
|
|
{
|
|
public:
|
|
|
|
function_evaluation_request() = delete;
|
|
function_evaluation_request(const function_evaluation_request&) = delete;
|
|
function_evaluation_request& operator=(const function_evaluation_request&) = delete;
|
|
|
|
|
|
function_evaluation_request(function_evaluation_request&& item);
|
|
function_evaluation_request& operator=(function_evaluation_request&& item);
|
|
|
|
~function_evaluation_request();
|
|
|
|
size_t function_idx (
|
|
) const;
|
|
|
|
const matrix<double,0,1>& x (
|
|
) const;
|
|
|
|
bool has_been_evaluated (
|
|
) const;
|
|
|
|
void set (
|
|
double y
|
|
);
|
|
|
|
void swap(function_evaluation_request& item);
|
|
|
|
private:
|
|
|
|
friend class global_function_search;
|
|
|
|
explicit function_evaluation_request(
|
|
const gopt_impl::outstanding_function_eval_request& req,
|
|
const std::shared_ptr<gopt_impl::funct_info>& info
|
|
) : req(req), info(info) {}
|
|
|
|
bool m_has_been_evaluated = false;
|
|
gopt_impl::outstanding_function_eval_request req;
|
|
std::shared_ptr<gopt_impl::funct_info> info;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class global_function_search
|
|
{
|
|
public:
|
|
|
|
global_function_search() = default;
|
|
|
|
explicit global_function_search(
|
|
const function_spec& function
|
|
);
|
|
|
|
explicit global_function_search(
|
|
const std::vector<function_spec>& functions_
|
|
);
|
|
|
|
global_function_search(
|
|
const std::vector<function_spec>& functions_,
|
|
const std::vector<std::vector<function_evaluation>>& initial_function_evals,
|
|
const double relative_noise_magnitude = 0.001
|
|
);
|
|
|
|
global_function_search(const global_function_search&) = delete;
|
|
global_function_search& operator=(const global_function_search& item) = delete;
|
|
|
|
global_function_search(global_function_search&& item) = default;
|
|
global_function_search& operator=(global_function_search&& item) = default;
|
|
|
|
size_t num_functions(
|
|
) const;
|
|
|
|
void set_seed (
|
|
time_t seed
|
|
);
|
|
|
|
void get_function_evaluations (
|
|
std::vector<function_spec>& specs,
|
|
std::vector<std::vector<function_evaluation>>& function_evals
|
|
) const;
|
|
|
|
void get_best_function_eval (
|
|
matrix<double,0,1>& x,
|
|
double& y,
|
|
size_t& function_idx
|
|
) const;
|
|
|
|
function_evaluation_request get_next_x (
|
|
);
|
|
|
|
double get_pure_random_search_probability (
|
|
) const;
|
|
|
|
void set_pure_random_search_probability (
|
|
double prob
|
|
);
|
|
|
|
double get_solver_epsilon (
|
|
) const;
|
|
|
|
void set_solver_epsilon (
|
|
double eps
|
|
);
|
|
|
|
double get_relative_noise_magnitude (
|
|
) const;
|
|
|
|
void set_relative_noise_magnitude (
|
|
double value
|
|
);
|
|
|
|
size_t get_monte_carlo_upper_bound_sample_num (
|
|
) const;
|
|
|
|
void set_monte_carlo_upper_bound_sample_num (
|
|
size_t num
|
|
);
|
|
|
|
private:
|
|
|
|
std::shared_ptr<gopt_impl::funct_info> best_function(
|
|
) const;
|
|
|
|
std::shared_ptr<gopt_impl::funct_info> best_function(
|
|
size_t& idx
|
|
) const;
|
|
|
|
bool has_outstanding_trust_region_request (
|
|
) const;
|
|
|
|
|
|
dlib::rand rnd;
|
|
double pure_random_search_probability = 0.02;
|
|
double min_trust_region_epsilon = 0;
|
|
double relative_noise_magnitude = 0.001;
|
|
size_t num_random_samples = 5000;
|
|
bool do_trust_region_step = true;
|
|
|
|
size_t next_request_id = 1;
|
|
|
|
std::vector<std::shared_ptr<gopt_impl::funct_info>> functions;
|
|
std::shared_ptr<std::mutex> m;
|
|
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
|
|
|