282 lines
8.8 KiB
C++
282 lines
8.8 KiB
C++
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
|
|
#define DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
|
|
|
|
#include "structural_sequence_segmentation_trainer_abstract.h"
|
|
#include "structural_sequence_labeling_trainer.h"
|
|
#include "sequence_segmenter.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
class structural_sequence_segmentation_trainer
|
|
{
|
|
public:
|
|
typedef typename feature_extractor::sequence_type sample_sequence_type;
|
|
typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
|
|
|
|
typedef sequence_segmenter<feature_extractor> trained_function_type;
|
|
|
|
explicit structural_sequence_segmentation_trainer (
|
|
const feature_extractor& fe_
|
|
) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_))
|
|
{
|
|
loss_per_missed_segment = 1;
|
|
loss_per_false_alarm = 1;
|
|
}
|
|
|
|
structural_sequence_segmentation_trainer (
|
|
)
|
|
{
|
|
loss_per_missed_segment = 1;
|
|
loss_per_false_alarm = 1;
|
|
}
|
|
|
|
const feature_extractor& get_feature_extractor (
|
|
) const { return trainer.get_feature_extractor().fe; }
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
)
|
|
{
|
|
trainer.set_num_threads(num);
|
|
}
|
|
|
|
unsigned long get_num_threads (
|
|
) const
|
|
{
|
|
return trainer.get_num_threads();
|
|
}
|
|
|
|
void set_epsilon (
|
|
double eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void structural_sequence_segmentation_trainer::set_epsilon()"
|
|
<< "\n\t eps_ must be greater than 0"
|
|
<< "\n\t eps_: " << eps_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
trainer.set_epsilon(eps_);
|
|
}
|
|
|
|
double get_epsilon (
|
|
) const { return trainer.get_epsilon(); }
|
|
|
|
unsigned long get_max_iterations (
|
|
) const { return trainer.get_max_iterations(); }
|
|
|
|
void set_max_iterations (
|
|
unsigned long max_iter
|
|
)
|
|
{
|
|
trainer.set_max_iterations(max_iter);
|
|
}
|
|
|
|
void set_max_cache_size (
|
|
unsigned long max_size
|
|
)
|
|
{
|
|
trainer.set_max_cache_size(max_size);
|
|
}
|
|
|
|
unsigned long get_max_cache_size (
|
|
) const
|
|
{
|
|
return trainer.get_max_cache_size();
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
trainer.be_verbose();
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
trainer.be_quiet();
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
trainer.set_oca(item);
|
|
}
|
|
|
|
const oca get_oca (
|
|
) const
|
|
{
|
|
return trainer.get_oca();
|
|
}
|
|
|
|
void set_c (
|
|
double C_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C_ > 0,
|
|
"\t void structural_sequence_segmentation_trainer::set_c()"
|
|
<< "\n\t C_ must be greater than 0"
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
trainer.set_c(C_);
|
|
}
|
|
|
|
double get_c (
|
|
) const
|
|
{
|
|
return trainer.get_c();
|
|
}
|
|
|
|
void set_loss_per_missed_segment (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss >= 0,
|
|
"\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_missed_segment = loss;
|
|
|
|
if (feature_extractor::use_BIO_model)
|
|
{
|
|
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
|
|
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
|
|
}
|
|
else
|
|
{
|
|
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
|
|
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
|
|
trainer.set_loss(impl_ss::LAST, loss_per_missed_segment);
|
|
trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment);
|
|
}
|
|
}
|
|
|
|
double get_loss_per_missed_segment (
|
|
) const
|
|
{
|
|
return loss_per_missed_segment;
|
|
}
|
|
|
|
void set_loss_per_false_alarm (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss >= 0,
|
|
"\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_false_alarm = loss;
|
|
|
|
trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm);
|
|
}
|
|
|
|
double get_loss_per_false_alarm (
|
|
) const
|
|
{
|
|
return loss_per_false_alarm;
|
|
}
|
|
|
|
const sequence_segmenter<feature_extractor> train(
|
|
const std::vector<sample_sequence_type>& x,
|
|
const std::vector<segmented_sequence_type>& y
|
|
) const
|
|
{
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_sequence_segmentation_problem(x,y) == true,
|
|
"\t sequence_segmenter structural_sequence_segmentation_trainer::train(x,y)"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t x.size(): " << x.size()
|
|
<< "\n\t is_sequence_segmentation_problem(x,y): " << is_sequence_segmentation_problem(x,y)
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
std::vector<std::vector<unsigned long> > labels(y.size());
|
|
if (feature_extractor::use_BIO_model)
|
|
{
|
|
// convert y into tagged BIO labels
|
|
for (unsigned long i = 0; i < labels.size(); ++i)
|
|
{
|
|
labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
|
|
for (unsigned long j = 0; j < y[i].size(); ++j)
|
|
{
|
|
const unsigned long begin = y[i][j].first;
|
|
const unsigned long end = y[i][j].second;
|
|
if (begin != end)
|
|
{
|
|
labels[i][begin] = impl_ss::BEGIN;
|
|
for (unsigned long k = begin+1; k < end; ++k)
|
|
labels[i][k] = impl_ss::INSIDE;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// convert y into tagged BILOU labels
|
|
for (unsigned long i = 0; i < labels.size(); ++i)
|
|
{
|
|
labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
|
|
for (unsigned long j = 0; j < y[i].size(); ++j)
|
|
{
|
|
const unsigned long begin = y[i][j].first;
|
|
const unsigned long end = y[i][j].second;
|
|
if (begin != end)
|
|
{
|
|
if (begin+1==end)
|
|
{
|
|
labels[i][begin] = impl_ss::UNIT;
|
|
}
|
|
else
|
|
{
|
|
labels[i][begin] = impl_ss::BEGIN;
|
|
for (unsigned long k = begin+1; k+1 < end; ++k)
|
|
labels[i][k] = impl_ss::INSIDE;
|
|
labels[i][end-1] = impl_ss::LAST;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
sequence_labeler<impl_ss::feature_extractor<feature_extractor> > temp;
|
|
temp = trainer.train(x, labels);
|
|
return sequence_segmenter<feature_extractor>(temp.get_weights(), trainer.get_feature_extractor().fe);
|
|
}
|
|
|
|
private:
|
|
|
|
structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer;
|
|
double loss_per_missed_segment;
|
|
double loss_per_false_alarm;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
|
|
|