415 lines
13 KiB
C++
415 lines
13 KiB
C++
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_FIND_MAX_PaRSE_CKY_Hh_
|
|
#define DLIB_FIND_MAX_PaRSE_CKY_Hh_
|
|
|
|
#include "find_max_parse_cky_abstract.h"
|
|
#include <vector>
|
|
#include <string>
|
|
#include <sstream>
|
|
#include "../serialize.h"
|
|
#include "../array2d.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
// One candidate application of a binary production rule.  The words in the
// half-open range [begin, end) are split at k: the left part [begin, k) carries
// left_tag and the right part [k, end) carries right_tag.
template <typename T>
struct constituent
{
    unsigned long begin;  // index of the first word covered
    unsigned long end;    // one past the index of the last word covered
    unsigned long k;      // split point: index of the first word of the right part
    T left_tag;           // tag of the left part, [begin, k)
    T right_tag;          // tag of the right part, [k, end)
};
|
|
|
|
// Writes a constituent<T> to out as the sequence begin, end, k, left_tag,
// right_tag.  This field order *is* the serialized format: the matching
// deserialize() overload reads the fields back in exactly this order, so the
// two functions must be changed together.  T must itself have a serialize()
// overload.
template <typename T>
void serialize(
    const constituent<T>& item,
    std::ostream& out
)
{
    serialize(item.begin, out);
    serialize(item.end, out);
    serialize(item.k, out);
    serialize(item.left_tag, out);
    serialize(item.right_tag, out);
}
|
|
|
|
// Reads a constituent<T> from in, expecting the field order written by the
// matching serialize() overload: begin, end, k, left_tag, right_tag.  T must
// itself have a deserialize() overload.  Error reporting, if any, comes from
// the underlying deserialize() calls (not shown in this file).
template <typename T>
void deserialize(
    constituent<T>& item,
    std::istream& in
)
{
    deserialize(item.begin, in);
    deserialize(item.end, in);
    deserialize(item.k, in);
    deserialize(item.left_tag, in);
    deserialize(item.right_tag, in);
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
// Sentinel stored in parse_tree_element::left / right meaning "this child is a
// single word, not another node".  Its value is 2^32-1 regardless of how wide
// unsigned long is on the platform.
const unsigned long END_OF_TREE = 4294967295UL;
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
template <typename T>
|
|
struct parse_tree_element
|
|
{
|
|
constituent<T> c;
|
|
T tag; // id for the constituent corresponding to this level of the tree
|
|
|
|
unsigned long left;
|
|
unsigned long right;
|
|
double score;
|
|
};
|
|
|
|
// Writes a parse_tree_element<T> to out as: the constituent c (via the
// constituent serialize() overload), then tag, left, right and score.  The
// matching deserialize() overload reads the fields back in exactly this order,
// so the two must be kept in sync.
template <typename T>
void serialize (
    const parse_tree_element<T>& item,
    std::ostream& out
)
{
    serialize(item.c, out);
    serialize(item.tag, out);
    serialize(item.left, out);
    serialize(item.right, out);
    serialize(item.score, out);
}
|
|
|
|
// Reads a parse_tree_element<T> from in, expecting the field order written by
// the matching serialize() overload: c, tag, left, right, score.  Error
// reporting, if any, comes from the underlying deserialize() calls.
template <typename T>
void deserialize (
    parse_tree_element<T>& item,
    std::istream& in
)
{
    deserialize(item.c, in);
    deserialize(item.tag, in);
    deserialize(item.left, in);
    deserialize(item.right, in);
    deserialize(item.score, in);
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
template <typename T>
|
|
unsigned long fill_parse_tree(
|
|
std::vector<parse_tree_element<T> >& parse_tree,
|
|
const T& tag,
|
|
const array2d<std::map<T, parse_tree_element<T> > >& back,
|
|
long r, long c
|
|
)
|
|
/*!
|
|
requires
|
|
- back[r][c].size() == 0 || back[r][c].count(tag) != 0
|
|
!*/
|
|
{
|
|
// base case of the recursion
|
|
if (back[r][c].size() == 0)
|
|
{
|
|
return END_OF_TREE;
|
|
}
|
|
|
|
const unsigned long idx = parse_tree.size();
|
|
const parse_tree_element<T>& item = back[r][c].find(tag)->second;
|
|
parse_tree.push_back(item);
|
|
|
|
const long k = item.c.k;
|
|
const unsigned long idx_left = fill_parse_tree(parse_tree, item.c.left_tag, back, r, k-1);
|
|
const unsigned long idx_right = fill_parse_tree(parse_tree, item.c.right_tag, back, k, c);
|
|
parse_tree[idx].left = idx_left;
|
|
parse_tree[idx].right = idx_right;
|
|
return idx;
|
|
}
|
|
}
|
|
|
|
template <typename T, typename production_rule_function>
|
|
void find_max_parse_cky (
|
|
const std::vector<T>& sequence,
|
|
const production_rule_function& production_rules,
|
|
std::vector<parse_tree_element<T> >& parse_tree
|
|
)
|
|
{
|
|
parse_tree.clear();
|
|
if (sequence.size() == 0)
|
|
return;
|
|
|
|
array2d<std::map<T,double> > table(sequence.size(), sequence.size());
|
|
array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size());
|
|
typedef typename std::map<T,double>::iterator itr;
|
|
typedef typename std::map<T,parse_tree_element<T> >::iterator itr_b;
|
|
|
|
for (long r = 0; r < table.nr(); ++r)
|
|
table[r][r][sequence[r]] = 0;
|
|
|
|
std::vector<std::pair<T,double> > possible_tags;
|
|
|
|
for (long r = table.nr()-2; r >= 0; --r)
|
|
{
|
|
for (long c = r+1; c < table.nc(); ++c)
|
|
{
|
|
for (long k = r; k < c; ++k)
|
|
{
|
|
for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i)
|
|
{
|
|
for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j)
|
|
{
|
|
constituent<T> con;
|
|
con.begin = r;
|
|
con.end = c+1;
|
|
con.k = k+1;
|
|
con.left_tag = j->first;
|
|
con.right_tag = i->first;
|
|
possible_tags.clear();
|
|
production_rules(sequence, con, possible_tags);
|
|
for (unsigned long m = 0; m < possible_tags.size(); ++m)
|
|
{
|
|
const double score = possible_tags[m].second + i->second + j->second;
|
|
itr match = table[r][c].find(possible_tags[m].first);
|
|
if (match == table[r][c].end() || score > match->second)
|
|
{
|
|
table[r][c][possible_tags[m].first] = score;
|
|
parse_tree_element<T> item;
|
|
item.c = con;
|
|
item.score = score;
|
|
item.tag = possible_tags[m].first;
|
|
item.left = END_OF_TREE;
|
|
item.right = END_OF_TREE;
|
|
back[r][c][possible_tags[m].first] = item;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
// now use back pointers to build the parse trees
|
|
const long r = 0;
|
|
const long c = back.nc()-1;
|
|
if (back[r][c].size() != 0)
|
|
{
|
|
|
|
// find the max scoring element in back[r][c]
|
|
itr_b max_i = back[r][c].begin();
|
|
itr_b i = max_i;
|
|
++i;
|
|
for (; i != back[r][c].end(); ++i)
|
|
{
|
|
if (i->second.score > max_i->second.score)
|
|
max_i = i;
|
|
}
|
|
|
|
parse_tree.reserve(c);
|
|
impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c);
|
|
}
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
class parse_tree_to_string_error : public error
|
|
{
|
|
public:
|
|
parse_tree_to_string_error(const std::string& str): error(str) {}
|
|
};
|
|
|
|
namespace impl
|
|
{
|
|
template <bool enabled, typename T>
|
|
typename enable_if_c<enabled>::type conditional_print(
|
|
const T& item,
|
|
std::ostream& out
|
|
) { out << item << " "; }
|
|
|
|
template <bool enabled, typename T>
|
|
typename disable_if_c<enabled>::type conditional_print(
|
|
const T& ,
|
|
std::ostream&
|
|
) { }
|
|
|
|
template <bool print_tag, bool skip_tag, typename T, typename U >
|
|
void print_parse_tree_helper (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const std::vector<U>& words,
|
|
unsigned long i,
|
|
const T& tag_to_skip,
|
|
std::ostream& out
|
|
)
|
|
{
|
|
if (!skip_tag || tree[i].tag != tag_to_skip)
|
|
out << "[";
|
|
|
|
bool left_recurse = false;
|
|
|
|
// Only print if we are supposed to. Doing it this funny way avoids compiler
|
|
// errors in parse_tree_to_string() for the case where tag isn't
|
|
// printable.
|
|
if (!skip_tag || tree[i].tag != tag_to_skip)
|
|
conditional_print<print_tag>(tree[i].tag, out);
|
|
|
|
if (tree[i].left < tree.size())
|
|
{
|
|
left_recurse = true;
|
|
print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].left, tag_to_skip, out);
|
|
}
|
|
else
|
|
{
|
|
if ((tree[i].c.begin) < words.size())
|
|
{
|
|
out << words[tree[i].c.begin] << " ";
|
|
}
|
|
else
|
|
{
|
|
std::ostringstream sout;
|
|
sout << "Parse tree refers to element " << tree[i].c.begin
|
|
<< " of sequence which is only of size " << words.size() << ".";
|
|
throw parse_tree_to_string_error(sout.str());
|
|
}
|
|
}
|
|
|
|
if (left_recurse == true)
|
|
out << " ";
|
|
|
|
if (tree[i].right < tree.size())
|
|
{
|
|
print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].right, tag_to_skip, out);
|
|
}
|
|
else
|
|
{
|
|
if (tree[i].c.k < words.size())
|
|
{
|
|
out << words[tree[i].c.k];
|
|
}
|
|
else
|
|
{
|
|
std::ostringstream sout;
|
|
sout << "Parse tree refers to element " << tree[i].c.k
|
|
<< " of sequence which is only of size " << words.size() << ".";
|
|
throw parse_tree_to_string_error(sout.str());
|
|
}
|
|
}
|
|
|
|
|
|
if (!skip_tag || tree[i].tag != tag_to_skip)
|
|
out << "]";
|
|
}
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
template <typename T, typename U>
|
|
std::string parse_tree_to_string (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const std::vector<U>& words,
|
|
const unsigned long root_idx = 0
|
|
)
|
|
{
|
|
if (root_idx >= tree.size())
|
|
return "";
|
|
|
|
std::ostringstream sout;
|
|
impl::print_parse_tree_helper<false,false>(tree, words, root_idx, tree[root_idx].tag, sout);
|
|
return sout.str();
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
template <typename T, typename U>
|
|
std::string parse_tree_to_string_tagged (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const std::vector<U>& words,
|
|
const unsigned long root_idx = 0
|
|
)
|
|
{
|
|
if (root_idx >= tree.size())
|
|
return "";
|
|
|
|
std::ostringstream sout;
|
|
impl::print_parse_tree_helper<true,false>(tree, words, root_idx, tree[root_idx].tag, sout);
|
|
return sout.str();
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
template <typename T, typename U>
|
|
std::string parse_trees_to_string (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const std::vector<U>& words,
|
|
const T& tag_to_skip
|
|
)
|
|
{
|
|
if (tree.size() == 0)
|
|
return "";
|
|
|
|
std::ostringstream sout;
|
|
impl::print_parse_tree_helper<false,true>(tree, words, 0, tag_to_skip, sout);
|
|
return sout.str();
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
template <typename T, typename U>
|
|
std::string parse_trees_to_string_tagged (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const std::vector<U>& words,
|
|
const T& tag_to_skip
|
|
)
|
|
{
|
|
if (tree.size() == 0)
|
|
return "";
|
|
|
|
std::ostringstream sout;
|
|
impl::print_parse_tree_helper<true,true>(tree, words, 0, tag_to_skip, sout);
|
|
return sout.str();
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
template <typename T>
|
|
void helper_find_trees_without_tag (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const T& tag,
|
|
std::vector<unsigned long>& tree_roots,
|
|
unsigned long idx
|
|
)
|
|
{
|
|
if (idx < tree.size())
|
|
{
|
|
if (tree[idx].tag != tag)
|
|
{
|
|
tree_roots.push_back(idx);
|
|
}
|
|
else
|
|
{
|
|
helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].left);
|
|
helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].right);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void find_trees_not_rooted_with_tag (
|
|
const std::vector<parse_tree_element<T> >& tree,
|
|
const T& tag,
|
|
std::vector<unsigned long>& tree_roots
|
|
)
|
|
{
|
|
tree_roots.clear();
|
|
impl::helper_find_trees_without_tag(tree, tag, tree_roots, 0);
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_FIND_MAX_PaRSE_CKY_Hh_
|
|
|