209 lines
6.9 KiB
C++
209 lines
6.9 KiB
C++
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
|
|
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
#include "cross_validate_multiclass_trainer_abstract.h"
|
|
#include <sstream>
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename dec_funct_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
const matrix<double> test_multiclass_decision_function (
|
|
const dec_funct_type& dec_funct,
|
|
const std::vector<sample_type>& x_test,
|
|
const std::vector<label_type>& y_test
|
|
)
|
|
{
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
|
|
"\tmatrix test_multiclass_decision_function()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_learning_problem(x_test,y_test): "
|
|
<< is_learning_problem(x_test,y_test));
|
|
|
|
|
|
const std::vector<label_type> all_labels = dec_funct.get_labels();
|
|
|
|
// make a lookup table that maps from labels to their index in all_labels
|
|
std::map<label_type,unsigned long> label_to_int;
|
|
for (unsigned long i = 0; i < all_labels.size(); ++i)
|
|
label_to_int[all_labels[i]] = i;
|
|
|
|
matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
|
|
res.set_size(all_labels.size(), all_labels.size());
|
|
|
|
res = 0;
|
|
|
|
typename std::map<label_type,unsigned long>::const_iterator iter;
|
|
|
|
// now test this trained object
|
|
for (unsigned long i = 0; i < x_test.size(); ++i)
|
|
{
|
|
iter = label_to_int.find(y_test[i]);
|
|
// ignore samples with labels that the decision function doesn't know about.
|
|
if (iter == label_to_int.end())
|
|
continue;
|
|
|
|
const unsigned long truth = iter->second;
|
|
const unsigned long pred = label_to_int[dec_funct(x_test[i])];
|
|
|
|
res(truth,pred) += 1;
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class cross_validation_error : public dlib::error
|
|
{
|
|
public:
|
|
cross_validation_error(const std::string& msg) : dlib::error(msg){};
|
|
};
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
const matrix<double> cross_validate_multiclass_trainer (
|
|
const trainer_type& trainer,
|
|
const std::vector<sample_type>& x,
|
|
const std::vector<label_type>& y,
|
|
const long folds
|
|
)
|
|
{
|
|
typedef typename trainer_type::mem_manager_type mem_manager_type;
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_learning_problem(x,y) == true &&
|
|
1 < folds && folds <= static_cast<long>(x.size()),
|
|
"\tmatrix cross_validate_multiclass_trainer()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t x.size(): " << x.size()
|
|
<< "\n\t folds: " << folds
|
|
<< "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
|
|
);
|
|
|
|
const std::vector<label_type> all_labels = select_all_distinct_labels(y);
|
|
|
|
// count the number of times each label shows up
|
|
std::map<label_type,long> label_counts;
|
|
for (unsigned long i = 0; i < y.size(); ++i)
|
|
label_counts[y[i]] += 1;
|
|
|
|
|
|
// figure out how many samples from each class will be in the test and train splits
|
|
std::map<label_type,long> num_in_test, num_in_train;
|
|
for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
|
|
{
|
|
const long in_test = i->second/folds;
|
|
if (in_test == 0)
|
|
{
|
|
std::ostringstream sout;
|
|
sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
|
|
sout << "than the number of elements of one of the training classes." << std::endl;
|
|
sout << " folds: "<< folds << std::endl;
|
|
sout << " size of class " << i->first << ": "<< i->second << std::endl;
|
|
throw cross_validation_error(sout.str());
|
|
}
|
|
num_in_test[i->first] = in_test;
|
|
num_in_train[i->first] = i->second - in_test;
|
|
}
|
|
|
|
|
|
|
|
std::vector<sample_type> x_test, x_train;
|
|
std::vector<label_type> y_test, y_train;
|
|
|
|
matrix<double, 0, 0, mem_manager_type> res;
|
|
|
|
std::map<label_type,long> next_test_idx;
|
|
for (unsigned long i = 0; i < all_labels.size(); ++i)
|
|
next_test_idx[all_labels[i]] = 0;
|
|
|
|
label_type label;
|
|
|
|
for (long i = 0; i < folds; ++i)
|
|
{
|
|
x_test.clear();
|
|
y_test.clear();
|
|
x_train.clear();
|
|
y_train.clear();
|
|
|
|
// load up the test samples
|
|
for (unsigned long j = 0; j < all_labels.size(); ++j)
|
|
{
|
|
label = all_labels[j];
|
|
long next = next_test_idx[label];
|
|
|
|
long cur = 0;
|
|
const long num_needed = num_in_test[label];
|
|
while (cur < num_needed)
|
|
{
|
|
if (y[next] == label)
|
|
{
|
|
x_test.push_back(x[next]);
|
|
y_test.push_back(label);
|
|
++cur;
|
|
}
|
|
next = (next + 1)%x.size();
|
|
}
|
|
|
|
next_test_idx[label] = next;
|
|
}
|
|
|
|
// load up the training samples
|
|
for (unsigned long j = 0; j < all_labels.size(); ++j)
|
|
{
|
|
label = all_labels[j];
|
|
long next = next_test_idx[label];
|
|
|
|
long cur = 0;
|
|
const long num_needed = num_in_train[label];
|
|
while (cur < num_needed)
|
|
{
|
|
if (y[next] == label)
|
|
{
|
|
x_train.push_back(x[next]);
|
|
y_train.push_back(label);
|
|
++cur;
|
|
}
|
|
next = (next + 1)%x.size();
|
|
}
|
|
}
|
|
|
|
|
|
try
|
|
{
|
|
// do the training and testing
|
|
res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test);
|
|
}
|
|
catch (invalid_nu_error&)
|
|
{
|
|
// just ignore cases which result in an invalid nu
|
|
}
|
|
|
|
} // for (long i = 0; i < folds; ++i)
|
|
|
|
return res;
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
|
|
|