// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include #include "tester.h" #include #include namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.sequence_segmenter"); // ---------------------------------------------------------------------------------------- dlib::rand rnd; template class unigram_extractor { public: const static bool use_BIO_model = use_BIO_model_; const static bool use_high_order_features = use_high_order_features_; const static bool allow_negative_weights = allow_negative_weights_; typedef std::vector sequence_type; std::map > feats; unigram_extractor() { matrix v1, v2, v3; v1 = randm(num_features(), 1, rnd); v2 = randm(num_features(), 1, rnd); v3 = randm(num_features(), 1, rnd); v1(0) = 1; v2(1) = 1; v3(2) = 1; v1(3) = -1; v2(4) = -1; v3(5) = -1; for (unsigned long i = 0; i < num_features(); ++i) { if ( i < 3) feats[i] = v1; else if (i < 6) feats[i] = v2; else feats[i] = v3; } } unsigned long num_features() const { return 10; } unsigned long window_size() const { return 3; } template void get_features ( feature_setter& set_feature, const sequence_type& x, unsigned long position ) const { const matrix& m = feats.find(x[position])->second; for (unsigned long i = 0; i < num_features(); ++i) { set_feature(i, m(i)); } } }; template void serialize(const unigram_extractor& item , std::ostream& out ) { serialize(item.feats, out); } template void deserialize(unigram_extractor& item, std::istream& in) { deserialize(item.feats, in); } // ---------------------------------------------------------------------------------------- void make_dataset ( std::vector >& samples, std::vector >& labels, unsigned long dataset_size ) { samples.clear(); labels.clear(); samples.resize(dataset_size); labels.resize(dataset_size); unigram_extractor fe; dlib::rand rnd; for (unsigned long iter = 0; iter < dataset_size; ++iter) { samples[iter].resize(10); labels[iter].resize(10); for (unsigned long i = 0; i < samples[iter].size(); ++i) { samples[iter][i] = rnd.get_random_32bit_number()%fe.num_features(); if (samples[iter][i] < 3) { labels[iter][i] = impl_ss::BEGIN; } else if (samples[iter][i] < 6) { labels[iter][i] = impl_ss::INSIDE; } else { labels[iter][i] = impl_ss::OUTSIDE; } if (i != 0) { // do rejection sampling to avoid impossible labels if (labels[iter][i] == impl_ss::INSIDE && labels[iter][i-1] == impl_ss::OUTSIDE) { --i; } } } } } // ---------------------------------------------------------------------------------------- void make_dataset2 ( std::vector >& samples, std::vector > >& segments, unsigned long dataset_size ) { segments.clear(); std::vector > labels; make_dataset(samples, labels, dataset_size); segments.resize(samples.size()); // Convert from BIO tagging to the explicit segments representation. for (unsigned long k = 0; k < labels.size(); ++k) { for (unsigned long i = 0; i < labels[k].size(); ++i) { if (labels[k][i] == impl_ss::BEGIN) { const unsigned long begin = i; ++i; while (i < labels[k].size() && labels[k][i] == impl_ss::INSIDE) ++i; segments[k].push_back(std::make_pair(begin, i)); --i; } } } } // ---------------------------------------------------------------------------------------- template void do_test() { dlog << LINFO << "use_BIO_model: "<< use_BIO_model; dlog << LINFO << "use_high_order_features: "<< use_high_order_features; dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights; std::vector > samples; std::vector > > segments; make_dataset2( samples, segments, 100); print_spinner(); typedef unigram_extractor fe_type; fe_type fe_temp; fe_type fe_temp2; structural_sequence_segmentation_trainer trainer(fe_temp2); trainer.set_c(5); trainer.set_num_threads(1); sequence_segmenter labeler = trainer.train(samples, segments); print_spinner(); const std::vector > predicted_labels = labeler(samples[1]); const std::vector > true_labels = segments[1]; /* for (unsigned long i = 0; i < predicted_labels.size(); ++i) cout << "["< 0); DLIB_TEST(predicted_labels.size() == true_labels.size()); for (unsigned long i = 0; i < predicted_labels.size(); ++i) { DLIB_TEST(predicted_labels[i].first == true_labels[i].first); DLIB_TEST(predicted_labels[i].second == true_labels[i].second); } matrix res; res = cross_validate_sequence_segmenter(trainer, samples, segments, 3); dlog << LINFO << "cv res: "<< res; DLIB_TEST(min(res) > 0.98); make_dataset2( samples, segments, 100); res = test_sequence_segmenter(labeler, samples, segments); dlog << LINFO << "test res: "<< res; DLIB_TEST(min(res) > 0.98); print_spinner(); ostringstream sout; serialize(labeler, sout); istringstream sin(sout.str()); sequence_segmenter labeler2; deserialize(labeler2, sin); res = test_sequence_segmenter(labeler2, samples, segments); dlog << LINFO << "test res2: "<< res; DLIB_TEST(min(res) > 0.98); long N; if (use_BIO_model) N = 3*3+3; else N = 5*5+5; const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N)); const double min_trans_weight = min(labeler2.get_weights()); dlog << LINFO << "min_normal_weight: " << min_normal_weight; dlog << LINFO << "min_trans_weight: " << min_trans_weight; if (allow_negative_weights) { DLIB_TEST(min_normal_weight < 0); DLIB_TEST(min_trans_weight < 0); } else { DLIB_TEST(min_normal_weight == 0); DLIB_TEST(min_trans_weight < 0); } } // ---------------------------------------------------------------------------------------- class unit_test_sequence_segmenter : public tester { public: unit_test_sequence_segmenter ( ) : tester ("test_sequence_segmenter", "Runs tests on the sequence segmenting code.") {} void perform_test ( ) { do_test(); do_test(); do_test(); do_test(); do_test(); do_test(); do_test(); do_test(); } } a; }