158 lines
4.6 KiB
C++
158 lines
4.6 KiB
C++
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
|
|
|
|
#include <dlib/optimization.h>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <cstdlib>
|
|
#include <ctime>
|
|
#include <vector>
|
|
#include "../rand.h"
|
|
|
|
#include "tester.h"
|
|
|
|
|
|
namespace
|
|
{
|
|
|
|
using namespace test;
|
|
using namespace dlib;
|
|
using namespace std;
|
|
|
|
logger dlog("test.max_cost_assignment");
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
std::vector<std::vector<long> > permutations (
|
|
matrix<long,1,0> vals
|
|
)
|
|
{
|
|
if (vals.size() == 0)
|
|
{
|
|
return std::vector<std::vector<long> >();
|
|
}
|
|
else if (vals.size() == 1)
|
|
{
|
|
return std::vector<std::vector<long> >(1,std::vector<long>(1,vals(0)));
|
|
}
|
|
|
|
|
|
std::vector<std::vector<long> > temp;
|
|
|
|
|
|
for (long i = 0; i < vals.size(); ++i)
|
|
{
|
|
const std::vector<std::vector<long> >& res = permutations(remove_col(vals,i));
|
|
|
|
for (unsigned long j = 0; j < res.size(); ++j)
|
|
{
|
|
temp.resize(temp.size()+1);
|
|
std::vector<long>& part = temp.back();
|
|
part.push_back(vals(i));
|
|
part.insert(part.end(), res[j].begin(), res[j].end());
|
|
}
|
|
}
|
|
|
|
|
|
return temp;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename T>
|
|
std::vector<long> brute_force_max_cost_assignment (
|
|
matrix<T> cost
|
|
)
|
|
{
|
|
if (cost.size() == 0)
|
|
return std::vector<long>();
|
|
|
|
const std::vector<std::vector<long> >& perms = permutations(range(0,cost.nc()-1));
|
|
|
|
T best_cost = std::numeric_limits<T>::min();
|
|
unsigned long best_idx = 0;
|
|
for (unsigned long i = 0; i < perms.size(); ++i)
|
|
{
|
|
const T temp = assignment_cost(cost, perms[i]);
|
|
if (temp > best_cost)
|
|
{
|
|
best_idx = i;
|
|
best_cost = temp;
|
|
}
|
|
}
|
|
|
|
return perms[best_idx];
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class test_max_cost_assignment : public tester
|
|
{
|
|
public:
|
|
test_max_cost_assignment (
|
|
) :
|
|
tester ("test_max_cost_assignment",
|
|
"Runs tests on the max_cost_assignment function.")
|
|
{}
|
|
|
|
dlib::rand rnd;
|
|
|
|
template <typename T>
|
|
void test_hungarian()
|
|
{
|
|
long size = rnd.get_random_32bit_number()%7;
|
|
long range = rnd.get_random_32bit_number()%100;
|
|
matrix<T> cost = matrix_cast<T>(randm(size,size,rnd)*range) - range/2;
|
|
|
|
// use a uniform cost matrix sometimes
|
|
if ((rnd.get_random_32bit_number()%100) == 0)
|
|
cost = rnd.get_random_32bit_number()%100;
|
|
|
|
// negate the cost matrix every now and then
|
|
if ((rnd.get_random_32bit_number()%100) == 0)
|
|
cost = -cost;
|
|
|
|
|
|
std::vector<long> assign = brute_force_max_cost_assignment(cost);
|
|
T true_eval = assignment_cost(cost, assign);
|
|
assign = max_cost_assignment(cost);
|
|
DLIB_TEST(assignment_cost(cost,assign) == true_eval);
|
|
assign = max_cost_assignment(matrix_cast<signed char>(cost));
|
|
DLIB_TEST(assignment_cost(cost,assign) == true_eval);
|
|
|
|
|
|
cost = matrix_cast<T>(randm(size,size,rnd)*range);
|
|
assign = brute_force_max_cost_assignment(cost);
|
|
true_eval = assignment_cost(cost, assign);
|
|
assign = max_cost_assignment(cost);
|
|
DLIB_TEST(assignment_cost(cost,assign) == true_eval);
|
|
assign = max_cost_assignment(matrix_cast<unsigned char>(cost));
|
|
DLIB_TEST(assignment_cost(cost,assign) == true_eval);
|
|
assign = max_cost_assignment(matrix_cast<typename unsigned_type<T>::type>(cost));
|
|
DLIB_TEST(assignment_cost(cost,assign) == true_eval);
|
|
}
|
|
|
|
void perform_test (
|
|
)
|
|
{
|
|
for (long i = 0; i < 1000; ++i)
|
|
{
|
|
if ((i%100) == 0)
|
|
print_spinner();
|
|
|
|
test_hungarian<short>();
|
|
test_hungarian<int>();
|
|
test_hungarian<long>();
|
|
test_hungarian<int64>();
|
|
}
|
|
}
|
|
} a;
|
|
|
|
}
|
|
|
|
|
|
|