212 lines
7.5 KiB
C++
212 lines
7.5 KiB
C++
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_GRAPH_LaBELER_Hh_
|
|
#define DLIB_GRAPH_LaBELER_Hh_
|
|
|
|
#include "graph_labeler_abstract.h"
|
|
#include "../matrix.h"
|
|
#include "../string.h"
|
|
#include <vector>
|
|
#include "find_max_factor_graph_potts.h"
|
|
#include "../svm/sparse_vector.h"
|
|
#include "../graph.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename vector_type
|
|
>
|
|
class graph_labeler
|
|
{
|
|
|
|
public:
|
|
|
|
typedef std::vector<bool> label_type;
|
|
typedef label_type result_type;
|
|
|
|
graph_labeler()
|
|
{
|
|
}
|
|
|
|
graph_labeler(
|
|
const vector_type& edge_weights_,
|
|
const vector_type& node_weights_
|
|
) :
|
|
edge_weights(edge_weights_),
|
|
node_weights(node_weights_)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0,
|
|
"\t graph_labeler::graph_labeler()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t min(edge_weights): " << min(edge_weights)
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
|
|
const vector_type& get_edge_weights (
|
|
) const { return edge_weights; }
|
|
|
|
const vector_type& get_node_weights (
|
|
) const { return node_weights; }
|
|
|
|
template <typename graph_type>
|
|
void operator() (
|
|
const graph_type& sample,
|
|
std::vector<bool>& labels
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
#ifdef ENABLE_ASSERTS
|
|
DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false,
|
|
"\t void graph_labeler::operator()"
|
|
<< "\n\t Invalid inputs were given to this function."
|
|
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
|
|
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
|
|
<< "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample)
|
|
<< "\n\t this: " << this
|
|
);
|
|
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
|
|
{
|
|
if (is_matrix<vector_type>::value &&
|
|
is_matrix<typename graph_type::type>::value)
|
|
{
|
|
// check that dot() is legal.
|
|
DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(),
|
|
"\t void graph_labeler::operator()"
|
|
<< "\n\t The size of the node weight vector must match the one in the node."
|
|
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
|
|
<< "\n\t sample.node(i).data.size(): " << sample.node(i).data.size()
|
|
<< "\n\t i: " << i
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
|
|
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
|
|
{
|
|
if (is_matrix<vector_type>::value &&
|
|
is_matrix<typename graph_type::edge_type>::value)
|
|
{
|
|
// check that dot() is legal.
|
|
DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(),
|
|
"\t void graph_labeler::operator()"
|
|
<< "\n\t The size of the edge weight vector must match the one in graph's edge."
|
|
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
|
|
<< "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size()
|
|
<< "\n\t i: " << i
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
|
|
DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0,
|
|
"\t void graph_labeler::operator()"
|
|
<< "\n\t No edge vectors are allowed to have negative elements."
|
|
<< "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n))
|
|
<< "\n\t i: " << i
|
|
<< "\n\t n: " << n
|
|
<< "\n\t this: " << this
|
|
);
|
|
}
|
|
}
|
|
#endif
|
|
|
|
|
|
graph<double,double>::kernel_1a g;
|
|
copy_graph_structure(sample, g);
|
|
for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
|
|
{
|
|
g.node(i).data = dot(node_weights, sample.node(i).data);
|
|
|
|
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
|
|
{
|
|
const unsigned long j = g.node(i).neighbor(n).index();
|
|
// Don't compute an edge weight more than once.
|
|
if (i < j)
|
|
{
|
|
g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n));
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
labels.clear();
|
|
std::vector<node_label> temp;
|
|
find_max_factor_graph_potts(g, temp);
|
|
for (unsigned long i = 0; i < temp.size(); ++i)
|
|
{
|
|
if (temp[i] != 0)
|
|
labels.push_back(true);
|
|
else
|
|
labels.push_back(false);
|
|
}
|
|
}
|
|
|
|
template <typename graph_type>
|
|
std::vector<bool> operator() (
|
|
const graph_type& sample
|
|
) const
|
|
{
|
|
std::vector<bool> temp;
|
|
(*this)(sample, temp);
|
|
return temp;
|
|
}
|
|
|
|
private:
|
|
|
|
vector_type edge_weights;
|
|
vector_type node_weights;
|
|
};
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename vector_type
|
|
>
|
|
void serialize (
|
|
const graph_labeler<vector_type>& item,
|
|
std::ostream& out
|
|
)
|
|
{
|
|
int version = 1;
|
|
serialize(version, out);
|
|
serialize(item.get_edge_weights(), out);
|
|
serialize(item.get_node_weights(), out);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename vector_type
|
|
>
|
|
void deserialize (
|
|
graph_labeler<vector_type>& item,
|
|
std::istream& in
|
|
)
|
|
{
|
|
int version = 0;
|
|
deserialize(version, in);
|
|
if (version != 1)
|
|
{
|
|
throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " +
|
|
cast_to_string(version) + ".");
|
|
}
|
|
|
|
vector_type edge_weights, node_weights;
|
|
deserialize(edge_weights, in);
|
|
deserialize(node_weights, in);
|
|
|
|
item = graph_labeler<vector_type>(edge_weights, node_weights);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_GRAPH_LaBELER_Hh_
|
|
|
|
|