123 lines
4.0 KiB
C++
123 lines
4.0 KiB
C++
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_DNn_VALIDATION_H_
|
|
#define DLIB_DNn_VALIDATION_H_
|
|
|
|
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
|
|
#include "../svm/cross_validate_object_detection_trainer.h"
|
|
#include "layers.h"
|
|
#include <set>
|
|
|
|
namespace dlib
|
|
{
|
|
namespace impl
|
|
{
|
|
inline std::set<std::string> get_labels (
|
|
const std::vector<mmod_rect>& rects1,
|
|
const std::vector<mmod_rect>& rects2
|
|
)
|
|
{
|
|
std::set<std::string> labels;
|
|
for (auto& rr : rects1)
|
|
labels.insert(rr.label);
|
|
for (auto& rr : rects2)
|
|
labels.insert(rr.label);
|
|
return labels;
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename SUBNET,
|
|
typename image_array_type
|
|
>
|
|
const matrix<double,1,3> test_object_detection_function (
|
|
loss_mmod<SUBNET>& detector,
|
|
const image_array_type& images,
|
|
const std::vector<std::vector<mmod_rect>>& truth_dets,
|
|
const test_box_overlap& overlap_tester = test_box_overlap(),
|
|
const double adjust_threshold = 0,
|
|
const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true ,
|
|
"\t matrix test_object_detection_function()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
|
|
<< "\n\t images.size(): " << images.size()
|
|
);
|
|
|
|
|
|
|
|
double correct_hits = 0;
|
|
double total_true_targets = 0;
|
|
|
|
std::vector<std::pair<double,bool> > all_dets;
|
|
unsigned long missing_detections = 0;
|
|
|
|
resizable_tensor temp;
|
|
|
|
for (unsigned long i = 0; i < images.size(); ++i)
|
|
{
|
|
std::vector<mmod_rect> hits;
|
|
detector.to_tensor(&images[i], &images[i]+1, temp);
|
|
detector.subnet().forward(temp);
|
|
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
|
|
|
|
|
|
for (auto& label : impl::get_labels(truth_dets[i], hits))
|
|
{
|
|
std::vector<full_object_detection> truth_boxes;
|
|
std::vector<rectangle> ignore;
|
|
std::vector<std::pair<double,rectangle>> boxes;
|
|
// copy hits and truth_dets into the above three objects
|
|
for (auto&& b : truth_dets[i])
|
|
{
|
|
if (b.ignore)
|
|
{
|
|
ignore.push_back(b);
|
|
}
|
|
else if (b.label == label)
|
|
{
|
|
truth_boxes.push_back(full_object_detection(b.rect));
|
|
++total_true_targets;
|
|
}
|
|
}
|
|
for (auto&& b : hits)
|
|
{
|
|
if (b.label == label)
|
|
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
|
|
}
|
|
|
|
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
|
|
}
|
|
}
|
|
|
|
std::sort(all_dets.rbegin(), all_dets.rend());
|
|
|
|
double precision, recall;
|
|
|
|
double total_hits = all_dets.size();
|
|
|
|
if (total_hits == 0)
|
|
precision = 1;
|
|
else
|
|
precision = correct_hits / total_hits;
|
|
|
|
if (total_true_targets == 0)
|
|
recall = 1;
|
|
else
|
|
recall = correct_hits / total_true_targets;
|
|
|
|
matrix<double, 1, 3> res;
|
|
res = precision, recall, average_precision(all_dets, missing_detections);
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_DNn_VALIDATION_H_
|
|
|