416 lines
13 KiB
C++
416 lines
13 KiB
C++
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
|
|
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
|
|
|
|
#include "structural_object_detection_trainer_abstract.h"
|
|
#include "../algs.h"
|
|
#include "../optimization.h"
|
|
#include "structural_svm_object_detection_problem.h"
|
|
#include "../image_processing/object_detector.h"
|
|
#include "../image_processing/box_overlap_testing.h"
|
|
#include "../image_processing/full_object_detection.h"
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Customization hook invoked with a scanner and the structural SVM problem that is
// about to be solved.  This generic version intentionally leaves both arguments
// untouched; scanner types that need something done here provide their own overload.
template <
    typename image_scanner_type,
    typename svm_struct_prob_type
    >
void configure_nuclear_norm_regularizer (
    const image_scanner_type& /*scanner*/,
    svm_struct_prob_type& /*prob*/
)
{
    // Intentionally empty: no default configuration is applied.
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename image_scanner_type
|
|
>
|
|
class structural_object_detection_trainer : noncopyable
|
|
{
|
|
|
|
public:
|
|
typedef double scalar_type;
|
|
typedef default_memory_manager mem_manager_type;
|
|
typedef object_detector<image_scanner_type> trained_function_type;
|
|
|
|
|
|
explicit structural_object_detection_trainer (
|
|
const image_scanner_type& scanner_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
|
|
"\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
|
|
<< "\n\t You can't have zero detection templates"
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
C = 1;
|
|
verbose = false;
|
|
eps = 0.1;
|
|
num_threads = 2;
|
|
max_cache_size = 5;
|
|
match_eps = 0.5;
|
|
loss_per_missed_target = 1;
|
|
loss_per_false_alarm = 1;
|
|
|
|
scanner.copy_configuration(scanner_);
|
|
|
|
auto_overlap_tester = true;
|
|
}
|
|
|
|
const image_scanner_type& get_scanner (
|
|
) const
|
|
{
|
|
return scanner;
|
|
}
|
|
|
|
bool auto_set_overlap_tester (
|
|
) const
|
|
{
|
|
return auto_overlap_tester;
|
|
}
|
|
|
|
void set_overlap_tester (
|
|
const test_box_overlap& tester
|
|
)
|
|
{
|
|
overlap_tester = tester;
|
|
auto_overlap_tester = false;
|
|
}
|
|
|
|
test_box_overlap get_overlap_tester (
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(auto_set_overlap_tester() == false,
|
|
"\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()"
|
|
<< "\n\t You can't call this function if the overlap tester is generated dynamically."
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
return overlap_tester;
|
|
}
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
)
|
|
{
|
|
num_threads = num;
|
|
}
|
|
|
|
unsigned long get_num_threads (
|
|
) const
|
|
{
|
|
return num_threads;
|
|
}
|
|
|
|
void set_epsilon (
|
|
scalar_type eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void structural_object_detection_trainer::set_epsilon()"
|
|
<< "\n\t eps_ must be greater than 0"
|
|
<< "\n\t eps_: " << eps_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
eps = eps_;
|
|
}
|
|
|
|
scalar_type get_epsilon (
|
|
) const { return eps; }
|
|
|
|
void set_max_runtime (
|
|
const std::chrono::nanoseconds& max_runtime
|
|
)
|
|
{
|
|
solver.set_max_runtime(max_runtime);
|
|
}
|
|
|
|
std::chrono::nanoseconds get_max_runtime (
|
|
) const
|
|
{
|
|
return solver.get_max_runtime();
|
|
}
|
|
|
|
void set_max_cache_size (
|
|
unsigned long max_size
|
|
)
|
|
{
|
|
max_cache_size = max_size;
|
|
}
|
|
|
|
unsigned long get_max_cache_size (
|
|
) const
|
|
{
|
|
return max_cache_size;
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
solver = item;
|
|
}
|
|
|
|
const oca get_oca (
|
|
) const
|
|
{
|
|
return solver;
|
|
}
|
|
|
|
void set_c (
|
|
scalar_type C_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C_ > 0,
|
|
"\t void structural_object_detection_trainer::set_c()"
|
|
<< "\n\t C_ must be greater than 0"
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
C = C_;
|
|
}
|
|
|
|
scalar_type get_c (
|
|
) const
|
|
{
|
|
return C;
|
|
}
|
|
|
|
void set_match_eps (
|
|
double eps
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(0 < eps && eps < 1,
|
|
"\t void structural_object_detection_trainer::set_match_eps(eps)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t eps: " << eps
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
match_eps = eps;
|
|
}
|
|
|
|
double get_match_eps (
|
|
) const
|
|
{
|
|
return match_eps;
|
|
}
|
|
|
|
double get_loss_per_missed_target (
|
|
) const
|
|
{
|
|
return loss_per_missed_target;
|
|
}
|
|
|
|
void set_loss_per_missed_target (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_missed_target = loss;
|
|
}
|
|
|
|
double get_loss_per_false_alarm (
|
|
) const
|
|
{
|
|
return loss_per_false_alarm;
|
|
}
|
|
|
|
void set_loss_per_false_alarm (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_false_alarm = loss;
|
|
}
|
|
|
|
template <
|
|
typename image_array_type
|
|
>
|
|
const trained_function_type train (
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<full_object_detection> >& truth_object_detections
|
|
) const
|
|
{
|
|
std::vector<std::vector<rectangle> > empty_ignore(images.size());
|
|
return train_impl(images, truth_object_detections, empty_ignore, test_box_overlap());
|
|
}
|
|
|
|
template <
|
|
typename image_array_type
|
|
>
|
|
const trained_function_type train (
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<full_object_detection> >& truth_object_detections,
|
|
const std::vector<std::vector<rectangle> >& ignore,
|
|
const test_box_overlap& ignore_overlap_tester = test_box_overlap()
|
|
) const
|
|
{
|
|
return train_impl(images, truth_object_detections, ignore, ignore_overlap_tester);
|
|
}
|
|
|
|
template <
|
|
typename image_array_type
|
|
>
|
|
const trained_function_type train (
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<rectangle> >& truth_object_detections
|
|
) const
|
|
{
|
|
std::vector<std::vector<rectangle> > empty_ignore(images.size());
|
|
return train(images, truth_object_detections, empty_ignore, test_box_overlap());
|
|
}
|
|
|
|
template <
|
|
typename image_array_type
|
|
>
|
|
const trained_function_type train (
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<rectangle> >& truth_object_detections,
|
|
const std::vector<std::vector<rectangle> >& ignore,
|
|
const test_box_overlap& ignore_overlap_tester = test_box_overlap()
|
|
) const
|
|
{
|
|
std::vector<std::vector<full_object_detection> > truth_dets(truth_object_detections.size());
|
|
for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
|
|
{
|
|
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
|
|
{
|
|
truth_dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
|
|
}
|
|
}
|
|
|
|
return train_impl(images, truth_dets, ignore, ignore_overlap_tester);
|
|
}
|
|
|
|
private:
|
|
|
|
template <
|
|
typename image_array_type
|
|
>
|
|
const trained_function_type train_impl (
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<full_object_detection> >& truth_object_detections,
|
|
const std::vector<std::vector<rectangle> >& ignore,
|
|
const test_box_overlap& ignore_overlap_tester
|
|
) const
|
|
{
|
|
#ifdef ENABLE_ASSERTS
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true && images.size() == ignore.size(),
|
|
"\t trained_function_type structural_object_detection_trainer::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t images.size(): " << images.size()
|
|
<< "\n\t ignore.size(): " << ignore.size()
|
|
<< "\n\t truth_object_detections.size(): " << truth_object_detections.size()
|
|
<< "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
|
|
);
|
|
for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
|
|
{
|
|
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
|
|
{
|
|
DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
|
|
all_parts_in_rect(truth_object_detections[i][j]) == true,
|
|
"\t trained_function_type structural_object_detection_trainer::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
|
|
truth_object_detections[i][j].num_parts()
|
|
<< "\n\t get_scanner().get_num_movable_components_per_detection_template(): " <<
|
|
get_scanner().get_num_movable_components_per_detection_template()
|
|
<< "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
|
|
);
|
|
}
|
|
}
|
|
#endif
|
|
|
|
structural_svm_object_detection_problem<image_scanner_type,image_array_type >
|
|
svm_prob(scanner, overlap_tester, auto_overlap_tester, images,
|
|
truth_object_detections, ignore, ignore_overlap_tester, num_threads);
|
|
|
|
if (verbose)
|
|
svm_prob.be_verbose();
|
|
|
|
svm_prob.set_c(C);
|
|
svm_prob.set_epsilon(eps);
|
|
svm_prob.set_max_cache_size(max_cache_size);
|
|
svm_prob.set_match_eps(match_eps);
|
|
svm_prob.set_loss_per_missed_target(loss_per_missed_target);
|
|
svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
|
|
configure_nuclear_norm_regularizer(scanner, svm_prob);
|
|
matrix<double,0,1> w;
|
|
|
|
// Run the optimizer to find the optimal w.
|
|
solver(svm_prob,w);
|
|
|
|
// report the results of the training.
|
|
return object_detector<image_scanner_type>(scanner, svm_prob.get_overlap_tester(), w);
|
|
}
|
|
|
|
image_scanner_type scanner;
|
|
test_box_overlap overlap_tester;
|
|
|
|
double C;
|
|
oca solver;
|
|
double eps;
|
|
double match_eps;
|
|
bool verbose;
|
|
unsigned long num_threads;
|
|
unsigned long max_cache_size;
|
|
double loss_per_missed_target;
|
|
double loss_per_false_alarm;
|
|
bool auto_overlap_tester;
|
|
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
|
|
|
|
|