340 lines
10 KiB
C++
340 lines
10 KiB
C++
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_SEQUENCE_LAbELER_H_h_
|
|
#define DLIB_SEQUENCE_LAbELER_H_h_
|
|
|
|
#include "sequence_labeler_abstract.h"
|
|
#include "../matrix.h"
|
|
#include <vector>
|
|
#include "../optimization/find_max_factor_graph_viterbi.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace fe_helpers
|
|
{
|
|
template <typename EXP>
|
|
struct dot_functor
|
|
{
|
|
dot_functor(const matrix_exp<EXP>& lambda_) : lambda(lambda_), value(0) {}
|
|
|
|
inline void operator() (
|
|
unsigned long feat_index
|
|
)
|
|
{
|
|
value += lambda(feat_index);
|
|
}
|
|
|
|
inline void operator() (
|
|
unsigned long feat_index,
|
|
double feat_value
|
|
)
|
|
{
|
|
value += feat_value*lambda(feat_index);
|
|
}
|
|
|
|
const matrix_exp<EXP>& lambda;
|
|
double value;
|
|
};
|
|
|
|
template <typename feature_extractor, typename EXP, typename sequence_type, typename EXP2>
|
|
double dot(
|
|
const matrix_exp<EXP>& lambda,
|
|
const feature_extractor& fe,
|
|
const sequence_type& sequence,
|
|
const matrix_exp<EXP2>& candidate_labeling,
|
|
unsigned long position
|
|
)
|
|
{
|
|
dot_functor<EXP> dot(lambda);
|
|
fe.get_features(dot, sequence, candidate_labeling, position);
|
|
return dot.value;
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(
|
|
has_reject_labeling,
|
|
bool,
|
|
template reject_labeling<matrix<unsigned long> >,
|
|
(const typename T::sequence_type&, const matrix_exp<matrix<unsigned long> >&, unsigned long)const
|
|
)
|
|
|
|
template <typename feature_extractor, typename EXP, typename sequence_type>
|
|
typename enable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
|
|
const feature_extractor& fe,
|
|
const sequence_type& x,
|
|
const matrix_exp<EXP>& y,
|
|
unsigned long position
|
|
)
|
|
{
|
|
return fe.reject_labeling(x, y, position);
|
|
}
|
|
|
|
template <typename feature_extractor, typename EXP, typename sequence_type>
|
|
typename disable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
|
|
const feature_extractor& ,
|
|
const sequence_type& ,
|
|
const matrix_exp<EXP>& ,
|
|
unsigned long
|
|
)
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
typename enable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
|
|
const feature_extractor& fe,
|
|
const typename feature_extractor::sequence_type& x,
|
|
const std::vector<unsigned long>& y
|
|
)
|
|
{
|
|
if (x.size() != y.size())
|
|
return true;
|
|
|
|
matrix<unsigned long,0,1> node_states;
|
|
|
|
for (unsigned long i = 0; i < x.size(); ++i)
|
|
{
|
|
node_states.set_size(std::min(fe.order(),i) + 1);
|
|
for (unsigned long j = 0; j < (unsigned long)node_states.size(); ++j)
|
|
node_states(j) = y[i-j];
|
|
|
|
if (fe.reject_labeling(x, node_states, i))
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
typename disable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
|
|
const feature_extractor& ,
|
|
const typename feature_extractor::sequence_type& x,
|
|
const std::vector<unsigned long>& y
|
|
)
|
|
{
|
|
if (x.size() != y.size())
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
    typename feature_extractor
    >
bool contains_invalid_labeling (
    const feature_extractor& fe,
    const std::vector<typename feature_extractor::sequence_type>& x,
    const std::vector<std::vector<unsigned long> >& y
)
/*!
    Returns true if x and y differ in length, or if any pair (x[n], y[n]) is
    an invalid labeling according to the single-sequence overloads.
!*/
{
    if (x.size() != y.size())
        return true;

    // Stop scanning at the first invalid pair.
    bool found_invalid = false;
    for (unsigned long n = 0; n < x.size() && !found_invalid; ++n)
        found_invalid = contains_invalid_labeling(fe, x[n], y[n]);

    return found_invalid;
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
class sequence_labeler
|
|
{
|
|
public:
|
|
typedef typename feature_extractor::sequence_type sample_sequence_type;
|
|
typedef std::vector<unsigned long> labeled_sequence_type;
|
|
|
|
private:
|
|
class map_prob
|
|
{
|
|
public:
|
|
unsigned long order() const { return fe.order(); }
|
|
unsigned long num_states() const { return fe.num_labels(); }
|
|
|
|
map_prob(
|
|
const sample_sequence_type& x_,
|
|
const feature_extractor& fe_,
|
|
const matrix<double,0,1>& weights_
|
|
) :
|
|
sequence(x_),
|
|
fe(fe_),
|
|
weights(weights_)
|
|
{
|
|
}
|
|
|
|
unsigned long number_of_nodes(
|
|
) const
|
|
{
|
|
return sequence.size();
|
|
}
|
|
|
|
template <
|
|
typename EXP
|
|
>
|
|
double factor_value (
|
|
unsigned long node_id,
|
|
const matrix_exp<EXP>& node_states
|
|
) const
|
|
{
|
|
if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
|
|
return -std::numeric_limits<double>::infinity();
|
|
|
|
return fe_helpers::dot(weights, fe, sequence, node_states, node_id);
|
|
}
|
|
|
|
const sample_sequence_type& sequence;
|
|
const feature_extractor& fe;
|
|
const matrix<double,0,1>& weights;
|
|
};
|
|
public:
|
|
|
|
sequence_labeler()
|
|
{
|
|
weights.set_size(fe.num_features());
|
|
weights = 0;
|
|
}
|
|
|
|
explicit sequence_labeler(
|
|
const matrix<double,0,1>& weights_
|
|
) :
|
|
weights(weights_)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
|
|
"\t sequence_labeler::sequence_labeler(weights_)"
|
|
<< "\n\t These sizes should match"
|
|
<< "\n\t fe.num_features(): " << fe.num_features()
|
|
<< "\n\t weights_.size(): " << weights_.size()
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
|
|
sequence_labeler(
|
|
const matrix<double,0,1>& weights_,
|
|
const feature_extractor& fe_
|
|
) :
|
|
fe(fe_),
|
|
weights(weights_)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
|
|
"\t sequence_labeler::sequence_labeler(weights_,fe_)"
|
|
<< "\n\t These sizes should match"
|
|
<< "\n\t fe_.num_features(): " << fe_.num_features()
|
|
<< "\n\t weights_.size(): " << weights_.size()
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
|
|
const feature_extractor& get_feature_extractor (
|
|
) const { return fe; }
|
|
|
|
const matrix<double,0,1>& get_weights (
|
|
) const { return weights; }
|
|
|
|
unsigned long num_labels (
|
|
) const { return fe.num_labels(); }
|
|
|
|
labeled_sequence_type operator() (
|
|
const sample_sequence_type& x
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(num_labels() > 0,
|
|
"\t labeled_sequence_type sequence_labeler::operator()(x)"
|
|
<< "\n\t You can't have no labels."
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
labeled_sequence_type y;
|
|
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
|
|
return y;
|
|
}
|
|
|
|
void label_sequence (
|
|
const sample_sequence_type& x,
|
|
labeled_sequence_type& y
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(num_labels() > 0,
|
|
"\t void sequence_labeler::label_sequence(x,y)"
|
|
<< "\n\t You can't have no labels."
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
|
|
}
|
|
|
|
private:
|
|
|
|
feature_extractor fe;
|
|
matrix<double,0,1> weights;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
void serialize (
|
|
const sequence_labeler<feature_extractor>& item,
|
|
std::ostream& out
|
|
)
|
|
{
|
|
serialize(item.get_feature_extractor(), out);
|
|
serialize(item.get_weights(), out);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
void deserialize (
|
|
sequence_labeler<feature_extractor>& item,
|
|
std::istream& in
|
|
)
|
|
{
|
|
feature_extractor fe;
|
|
matrix<double,0,1> weights;
|
|
|
|
deserialize(fe, in);
|
|
deserialize(weights, in);
|
|
|
|
item = sequence_labeler<feature_extractor>(weights, fe);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_SEQUENCE_LAbELER_H_h_
|
|
|