// Listing metadata: 167 lines, 5.9 KiB, C++
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#undef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
|
|
#ifdef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
|
|
|
|
|
|
#include "one_vs_one_decision_function_abstract.h"
|
|
#include <vector>
|
|
|
|
#include "../any/any_trainer_abstract.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename any_trainer,
|
|
typename label_type_ = double
|
|
>
|
|
class one_vs_one_trainer
|
|
{
|
|
/*!
|
|
REQUIREMENTS ON any_trainer
|
|
must be an instantiation of the dlib::any_trainer template.
|
|
|
|
REQUIREMENTS ON label_type_
|
|
label_type_ must be default constructable, copyable, and comparable using
|
|
operator < and ==. It must also be possible to write it to an std::ostream
|
|
using operator<<.
|
|
|
|
WHAT THIS OBJECT REPRESENTS
|
|
This object is a tool for turning a bunch of binary classifiers
|
|
into a multiclass classifier. It does this by training the binary
|
|
classifiers in a one vs. one fashion. That is, if you have N possible
|
|
classes then it trains N*(N-1)/2 binary classifiers which are then used
|
|
to vote on the identity of a test sample.
|
|
|
|
This object works with any kind of binary classification trainer object
|
|
capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer)
|
|
!*/
|
|
|
|
public:
|
|
|
|
|
|
typedef label_type_ label_type;
|
|
|
|
typedef typename any_trainer::sample_type sample_type;
|
|
typedef typename any_trainer::scalar_type scalar_type;
|
|
typedef typename any_trainer::mem_manager_type mem_manager_type;
|
|
|
|
typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type;
|
|
|
|
one_vs_one_trainer (
|
|
);
|
|
/*!
|
|
ensures
|
|
- This object is properly initialized
|
|
- This object will not be verbose unless be_verbose() is called.
|
|
- No binary trainers are associated with *this. I.e. you have to
|
|
call set_trainer() before calling train().
|
|
- #get_num_threads() == 4
|
|
!*/
|
|
|
|
void set_trainer (
|
|
const any_trainer& trainer
|
|
);
|
|
/*!
|
|
ensures
|
|
- sets the trainer used for all pairs of training. Any previous
|
|
calls to set_trainer() are overridden by this function. Even the
|
|
more specific set_trainer(trainer, l1, l2) form.
|
|
!*/
|
|
|
|
void set_trainer (
|
|
const any_trainer& trainer,
|
|
const label_type& l1,
|
|
const label_type& l2
|
|
);
|
|
/*!
|
|
requires
|
|
- l1 != l2
|
|
ensures
|
|
- Sets the trainer object used to create a binary classifier to
|
|
distinguish l1 labeled samples from l2 labeled samples.
|
|
!*/
|
|
|
|
void be_verbose (
|
|
);
|
|
/*!
|
|
ensures
|
|
- This object will print status messages to standard out so that a
|
|
user can observe the progress of the algorithm.
|
|
!*/
|
|
|
|
void be_quiet (
|
|
);
|
|
/*!
|
|
ensures
|
|
- this object will not print anything to standard out
|
|
!*/
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
);
|
|
/*!
|
|
ensures
|
|
- #get_num_threads() == num
|
|
!*/
|
|
|
|
unsigned long get_num_threads (
|
|
) const;
|
|
/*!
|
|
ensures
|
|
- returns the number of threads used during training. You should
|
|
usually set this equal to the number of processing cores on your
|
|
machine.
|
|
!*/
|
|
|
|
struct invalid_label : public dlib::error
|
|
{
|
|
/*!
|
|
This is the exception thrown by the train() function below.
|
|
!*/
|
|
label_type l1, l2;
|
|
};
|
|
|
|
trained_function_type train (
|
|
const std::vector<sample_type>& all_samples,
|
|
const std::vector<label_type>& all_labels
|
|
) const;
|
|
/*!
|
|
requires
|
|
- is_learning_problem(all_samples, all_labels)
|
|
ensures
|
|
- trains a bunch of binary classifiers in a one vs one fashion to solve the given
|
|
multiclass classification problem.
|
|
- returns a one_vs_one_decision_function F with the following properties:
|
|
- F contains all the learned binary classifiers and can be used to predict
|
|
the labels of new samples.
|
|
- if (new_x is a sample predicted to have a label of L) then
|
|
- F(new_x) == L
|
|
- F.get_labels() == select_all_distinct_labels(all_labels)
|
|
- F.number_of_classes() == select_all_distinct_labels(all_labels).size()
|
|
throws
|
|
- invalid_label
|
|
This exception is thrown if there are labels in all_labels which don't have
|
|
any corresponding trainer object. This will never happen if set_trainer(trainer)
|
|
has been called. However, if only the set_trainer(trainer,l1,l2) form has been
|
|
used then this exception is thrown if not all necessary label pairs have been
|
|
given a trainer.
|
|
|
|
invalid_label::l1 and invalid_label::l2 will contain the label pair which is
|
|
missing a trainer object. Additionally, the exception will contain an
|
|
informative error message available via invalid_label::what().
|
|
!*/
|
|
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
|
|
|
|
|