// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
|
|
#define DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
|
|
|
|
#include "structural_graph_labeling_trainer_abstract.h"
|
|
#include "../algs.h"
|
|
#include "../optimization.h"
|
|
#include "structural_svm_graph_labeling_problem.h"
|
|
#include "../graph_cuts/graph_labeler.h"
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename vector_type
|
|
>
|
|
class structural_graph_labeling_trainer
|
|
{
|
|
public:
|
|
typedef std::vector<bool> label_type;
|
|
typedef graph_labeler<vector_type> trained_function_type;
|
|
|
|
structural_graph_labeling_trainer (
|
|
)
|
|
{
|
|
C = 10;
|
|
verbose = false;
|
|
eps = 0.1;
|
|
num_threads = 2;
|
|
max_cache_size = 5;
|
|
loss_pos = 1.0;
|
|
loss_neg = 1.0;
|
|
}
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
)
|
|
{
|
|
num_threads = num;
|
|
}
|
|
|
|
unsigned long get_num_threads (
|
|
) const
|
|
{
|
|
return num_threads;
|
|
}
|
|
|
|
void set_epsilon (
|
|
double eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void structural_graph_labeling_trainer::set_epsilon()"
|
|
<< "\n\t eps_ must be greater than 0"
|
|
<< "\n\t eps_: " << eps_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
eps = eps_;
|
|
}
|
|
|
|
double get_epsilon (
|
|
) const { return eps; }
|
|
|
|
void set_max_cache_size (
|
|
unsigned long max_size
|
|
)
|
|
{
|
|
max_cache_size = max_size;
|
|
}
|
|
|
|
unsigned long get_max_cache_size (
|
|
) const
|
|
{
|
|
return max_cache_size;
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
solver = item;
|
|
}
|
|
|
|
const oca get_oca (
|
|
) const
|
|
{
|
|
return solver;
|
|
}
|
|
|
|
void set_c (
|
|
double C_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C_ > 0,
|
|
"\t void structural_graph_labeling_trainer::set_c()"
|
|
<< "\n\t C_ must be greater than 0"
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
C = C_;
|
|
}
|
|
|
|
double get_c (
|
|
) const
|
|
{
|
|
return C;
|
|
}
|
|
|
|
|
|
void set_loss_on_positive_class (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss >= 0,
|
|
"\t structural_graph_labeling_trainer::set_loss_on_positive_class()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this );
|
|
|
|
loss_pos = loss;
|
|
}
|
|
|
|
void set_loss_on_negative_class (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss >= 0,
|
|
"\t structural_graph_labeling_trainer::set_loss_on_negative_class()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this );
|
|
|
|
loss_neg = loss;
|
|
}
|
|
|
|
double get_loss_on_negative_class (
|
|
) const { return loss_neg; }
|
|
|
|
double get_loss_on_positive_class (
|
|
) const { return loss_pos; }
|
|
|
|
|
|
template <
|
|
typename graph_type
|
|
>
|
|
const graph_labeler<vector_type> train (
|
|
const dlib::array<graph_type>& samples,
|
|
const std::vector<label_type>& labels,
|
|
const std::vector<std::vector<double> >& losses
|
|
) const
|
|
{
|
|
#ifdef ENABLE_ASSERTS
|
|
std::string reason_for_failure;
|
|
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
|
|
"\t void structural_graph_labeling_trainer::train()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t reason_for_failure: " << reason_for_failure
|
|
<< "\n\t samples.size(): " << samples.size()
|
|
<< "\n\t labels.size(): " << labels.size()
|
|
<< "\n\t this: " << this );
|
|
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
|
|
all_values_are_nonnegative(losses) == true,
|
|
"\t void structural_graph_labeling_trainer::train()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t labels.size(): " << labels.size()
|
|
<< "\n\t losses.size(): " << losses.size()
|
|
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
|
|
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
|
|
<< "\n\t this: " << this );
|
|
#endif
|
|
|
|
|
|
structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
|
|
|
|
if (verbose)
|
|
prob.be_verbose();
|
|
|
|
prob.set_c(C);
|
|
prob.set_epsilon(eps);
|
|
prob.set_max_cache_size(max_cache_size);
|
|
if (prob.get_losses().size() == 0)
|
|
{
|
|
prob.set_loss_on_positive_class(loss_pos);
|
|
prob.set_loss_on_negative_class(loss_neg);
|
|
}
|
|
|
|
matrix<double,0,1> w;
|
|
solver(prob, w, prob.get_num_edge_weights());
|
|
|
|
vector_type edge_weights;
|
|
vector_type node_weights;
|
|
populate_weights(w, edge_weights, node_weights, prob.get_num_edge_weights());
|
|
return graph_labeler<vector_type>(edge_weights, node_weights);
|
|
}
|
|
|
|
template <
|
|
typename graph_type
|
|
>
|
|
const graph_labeler<vector_type> train (
|
|
const dlib::array<graph_type>& samples,
|
|
const std::vector<label_type>& labels
|
|
) const
|
|
{
|
|
std::vector<std::vector<double> > losses;
|
|
return train(samples, labels, losses);
|
|
}
|
|
|
|
private:
|
|
|
|
template <typename T>
|
|
typename enable_if<is_matrix<T> >::type populate_weights (
|
|
const matrix<double,0,1>& w,
|
|
T& edge_weights,
|
|
T& node_weights,
|
|
long split_idx
|
|
) const
|
|
{
|
|
edge_weights = rowm(w,range(0, split_idx-1));
|
|
node_weights = rowm(w,range(split_idx,w.size()-1));
|
|
}
|
|
|
|
template <typename T>
|
|
typename disable_if<is_matrix<T> >::type populate_weights (
|
|
const matrix<double,0,1>& w,
|
|
T& edge_weights,
|
|
T& node_weights,
|
|
long split_idx
|
|
) const
|
|
{
|
|
edge_weights.clear();
|
|
node_weights.clear();
|
|
for (long i = 0; i < split_idx; ++i)
|
|
{
|
|
if (w(i) != 0)
|
|
edge_weights.insert(edge_weights.end(), std::make_pair(i,w(i)));
|
|
}
|
|
for (long i = split_idx; i < w.size(); ++i)
|
|
{
|
|
if (w(i) != 0)
|
|
node_weights.insert(node_weights.end(), std::make_pair(i-split_idx,w(i)));
|
|
}
|
|
}
|
|
|
|
|
|
double C;
|
|
oca solver;
|
|
double eps;
|
|
bool verbose;
|
|
unsigned long num_threads;
|
|
unsigned long max_cache_size;
|
|
double loss_pos;
|
|
double loss_neg;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
|
|
|