395 lines
12 KiB
C++
395 lines
12 KiB
C++
// Copyright (C) 2007 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_MLp_KERNEL_1_
|
|
#define DLIB_MLp_KERNEL_1_
|
|
|
|
#include "../algs.h"
|
|
#include "../serialize.h"
|
|
#include "../matrix.h"
|
|
#include "../rand.h"
|
|
#include "mlp_kernel_abstract.h"
|
|
#include <ctime>
|
|
#include <sstream>
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
class mlp_kernel_1 : noncopyable
|
|
{
|
|
/*!
|
|
INITIAL VALUE
|
|
The network is initially initialized with random weights
|
|
|
|
CONVENTION
|
|
- input_layer_nodes() == input_nodes
|
|
- first_hidden_layer_nodes() == first_hidden_nodes
|
|
- second_hidden_layer_nodes() == second_hidden_nodes
|
|
- output_layer_nodes() == output_nodes
|
|
- get_alpha == alpha
|
|
- get_momentum() == momentum
|
|
|
|
|
|
- if (second_hidden_nodes == 0) then
|
|
- for all i and j:
|
|
- w1(i,j) == the weight on the link from node i in the first hidden layer
|
|
to input node j
|
|
- w3(i,j) == the weight on the link from node i in the output layer
|
|
to first hidden layer node j
|
|
- for all i and j:
|
|
- w1m == the momentum terms for w1 from the previous update
|
|
- w3m == the momentum terms for w3 from the previous update
|
|
- else
|
|
- for all i and j:
|
|
- w1(i,j) == the weight on the link from node i in the first hidden layer
|
|
to input node j
|
|
- w2(i,j) == the weight on the link from node i in the second hidden layer
|
|
to first hidden layer node j
|
|
- w3(i,j) == the weight on the link from node i in the output layer
|
|
to second hidden layer node j
|
|
- for all i and j:
|
|
- w1m == the momentum terms for w1 from the previous update
|
|
- w2m == the momentum terms for w2 from the previous update
|
|
- w3m == the momentum terms for w3 from the previous update
|
|
!*/
|
|
|
|
public:
|
|
|
|
mlp_kernel_1 (
|
|
long nodes_in_input_layer,
|
|
long nodes_in_first_hidden_layer,
|
|
long nodes_in_second_hidden_layer = 0,
|
|
long nodes_in_output_layer = 1,
|
|
double alpha_ = 0.1,
|
|
double momentum_ = 0.8
|
|
) :
|
|
input_nodes(nodes_in_input_layer),
|
|
first_hidden_nodes(nodes_in_first_hidden_layer),
|
|
second_hidden_nodes(nodes_in_second_hidden_layer),
|
|
output_nodes(nodes_in_output_layer),
|
|
alpha(alpha_),
|
|
momentum(momentum_)
|
|
{
|
|
|
|
// seed the random number generator
|
|
std::ostringstream sout;
|
|
sout << time(0);
|
|
rand_nums.set_seed(sout.str());
|
|
|
|
w1.set_size(first_hidden_nodes+1, input_nodes+1);
|
|
w1m.set_size(first_hidden_nodes+1, input_nodes+1);
|
|
z.set_size(input_nodes+1,1);
|
|
|
|
if (second_hidden_nodes != 0)
|
|
{
|
|
w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
|
|
w3.set_size(output_nodes, second_hidden_nodes+1);
|
|
|
|
w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
|
|
w3m.set_size(output_nodes, second_hidden_nodes+1);
|
|
}
|
|
else
|
|
{
|
|
w3.set_size(output_nodes, first_hidden_nodes+1);
|
|
|
|
w3m.set_size(output_nodes, first_hidden_nodes+1);
|
|
}
|
|
|
|
reset();
|
|
}
|
|
|
|
virtual ~mlp_kernel_1 (
|
|
) {}
|
|
|
|
void reset (
|
|
)
|
|
{
|
|
// randomize the weights for the first layer
|
|
for (long r = 0; r < w1.nr(); ++r)
|
|
for (long c = 0; c < w1.nc(); ++c)
|
|
w1(r,c) = rand_nums.get_random_double();
|
|
|
|
// randomize the weights for the second layer
|
|
for (long r = 0; r < w2.nr(); ++r)
|
|
for (long c = 0; c < w2.nc(); ++c)
|
|
w2(r,c) = rand_nums.get_random_double();
|
|
|
|
// randomize the weights for the third layer
|
|
for (long r = 0; r < w3.nr(); ++r)
|
|
for (long c = 0; c < w3.nc(); ++c)
|
|
w3(r,c) = rand_nums.get_random_double();
|
|
|
|
// zero all the momentum terms
|
|
set_all_elements(w1m,0);
|
|
set_all_elements(w2m,0);
|
|
set_all_elements(w3m,0);
|
|
}
|
|
|
|
long input_layer_nodes (
|
|
) const { return input_nodes; }
|
|
|
|
long first_hidden_layer_nodes (
|
|
) const { return first_hidden_nodes; }
|
|
|
|
long second_hidden_layer_nodes (
|
|
) const { return second_hidden_nodes; }
|
|
|
|
long output_layer_nodes (
|
|
) const { return output_nodes; }
|
|
|
|
double get_alpha (
|
|
) const { return alpha; }
|
|
|
|
double get_momentum (
|
|
) const { return momentum; }
|
|
|
|
template <typename EXP>
|
|
const matrix<double> operator() (
|
|
const matrix_exp<EXP>& in
|
|
) const
|
|
{
|
|
for (long i = 0; i < in.nr(); ++i)
|
|
z(i) = in(i);
|
|
// insert the bias
|
|
z(z.nr()-1) = -1;
|
|
|
|
tmp1 = sigmoid(w1*z);
|
|
// insert the bias
|
|
tmp1(tmp1.nr()-1) = -1;
|
|
|
|
if (second_hidden_nodes == 0)
|
|
{
|
|
return sigmoid(w3*tmp1);
|
|
}
|
|
else
|
|
{
|
|
tmp2 = sigmoid(w2*tmp1);
|
|
// insert the bias
|
|
tmp2(tmp2.nr()-1) = -1;
|
|
|
|
return sigmoid(w3*tmp2);
|
|
}
|
|
}
|
|
|
|
template <typename EXP1, typename EXP2>
|
|
void train (
|
|
const matrix_exp<EXP1>& example_in,
|
|
const matrix_exp<EXP2>& example_out
|
|
)
|
|
{
|
|
for (long i = 0; i < example_in.nr(); ++i)
|
|
z(i) = example_in(i);
|
|
// insert the bias
|
|
z(z.nr()-1) = -1;
|
|
|
|
tmp1 = sigmoid(w1*z);
|
|
// insert the bias
|
|
tmp1(tmp1.nr()-1) = -1;
|
|
|
|
|
|
if (second_hidden_nodes == 0)
|
|
{
|
|
o = sigmoid(w3*tmp1);
|
|
|
|
// now compute the errors and propagate them backwards though the network
|
|
e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
|
|
e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );
|
|
|
|
// compute the new weight updates
|
|
w3m = alpha * e3*trans(tmp1) + w3m*momentum;
|
|
w1m = alpha * e1*trans(z) + w1m*momentum;
|
|
|
|
// now update the weights
|
|
w1 += w1m;
|
|
w3 += w3m;
|
|
}
|
|
else
|
|
{
|
|
tmp2 = sigmoid(w2*tmp1);
|
|
// insert the bias
|
|
tmp2(tmp2.nr()-1) = -1;
|
|
|
|
o = sigmoid(w3*tmp2);
|
|
|
|
|
|
// now compute the errors and propagate them backwards though the network
|
|
e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
|
|
e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
|
|
e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );
|
|
|
|
// compute the new weight updates
|
|
w3m = alpha * e3*trans(tmp2) + w3m*momentum;
|
|
w2m = alpha * e2*trans(tmp1) + w2m*momentum;
|
|
w1m = alpha * e1*trans(z) + w1m*momentum;
|
|
|
|
// now update the weights
|
|
w1 += w1m;
|
|
w2 += w2m;
|
|
w3 += w3m;
|
|
}
|
|
}
|
|
|
|
template <typename EXP>
|
|
void train (
|
|
const matrix_exp<EXP>& example_in,
|
|
double example_out
|
|
)
|
|
{
|
|
matrix<double,1,1> e_out;
|
|
e_out(0) = example_out;
|
|
train(example_in,e_out);
|
|
}
|
|
|
|
double get_average_change (
|
|
) const
|
|
{
|
|
// sum up all the weight changes
|
|
double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));
|
|
|
|
// divide by the number of weights
|
|
delta /= w1m.nr()*w1m.nc() +
|
|
w2m.nr()*w2m.nc() +
|
|
w3m.nr()*w3m.nc();
|
|
|
|
return delta;
|
|
}
|
|
|
|
void swap (
|
|
mlp_kernel_1& item
|
|
)
|
|
{
|
|
exchange(input_nodes, item.input_nodes);
|
|
exchange(first_hidden_nodes, item.first_hidden_nodes);
|
|
exchange(second_hidden_nodes, item.second_hidden_nodes);
|
|
exchange(output_nodes, item.output_nodes);
|
|
exchange(alpha, item.alpha);
|
|
exchange(momentum, item.momentum);
|
|
|
|
w1.swap(item.w1);
|
|
w2.swap(item.w2);
|
|
w3.swap(item.w3);
|
|
|
|
w1m.swap(item.w1m);
|
|
w2m.swap(item.w2m);
|
|
w3m.swap(item.w3m);
|
|
|
|
// even swap the temporary matrices because this may ultimately result in
|
|
// fewer calls to new and delete.
|
|
e1.swap(item.e1);
|
|
e2.swap(item.e2);
|
|
e3.swap(item.e3);
|
|
z.swap(item.z);
|
|
tmp1.swap(item.tmp1);
|
|
tmp2.swap(item.tmp2);
|
|
o.swap(item.o);
|
|
}
|
|
|
|
|
|
friend void serialize (
|
|
const mlp_kernel_1& item,
|
|
std::ostream& out
|
|
);
|
|
|
|
friend void deserialize (
|
|
mlp_kernel_1& item,
|
|
std::istream& in
|
|
);
|
|
|
|
private:
|
|
|
|
long input_nodes;
|
|
long first_hidden_nodes;
|
|
long second_hidden_nodes;
|
|
long output_nodes;
|
|
double alpha;
|
|
double momentum;
|
|
|
|
matrix<double> w1;
|
|
matrix<double> w2;
|
|
matrix<double> w3;
|
|
|
|
matrix<double> w1m;
|
|
matrix<double> w2m;
|
|
matrix<double> w3m;
|
|
|
|
|
|
rand rand_nums;
|
|
|
|
// temporary storage
|
|
mutable matrix<double> e1, e2, e3;
|
|
mutable matrix<double> z, tmp1, tmp2, o;
|
|
};
|
|
|
|
// Namespace-level swap for mlp_kernel_1: simply forwards to the member
// function a.swap(b).
inline void swap (
    mlp_kernel_1& a,
    mlp_kernel_1& b
)
{
    a.swap(b);
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
inline void serialize (
|
|
const mlp_kernel_1& item,
|
|
std::ostream& out
|
|
)
|
|
{
|
|
try
|
|
{
|
|
serialize(item.input_nodes, out);
|
|
serialize(item.first_hidden_nodes, out);
|
|
serialize(item.second_hidden_nodes, out);
|
|
serialize(item.output_nodes, out);
|
|
serialize(item.alpha, out);
|
|
serialize(item.momentum, out);
|
|
|
|
serialize(item.w1, out);
|
|
serialize(item.w2, out);
|
|
serialize(item.w3, out);
|
|
|
|
serialize(item.w1m, out);
|
|
serialize(item.w2m, out);
|
|
serialize(item.w3m, out);
|
|
}
|
|
catch (serialization_error& e)
|
|
{
|
|
throw serialization_error(e.info + "\n while serializing object of type mlp_kernel_1");
|
|
}
|
|
}
|
|
|
|
// Reads into item the fields written by serialize(), in the same order: the
// four layer sizes, alpha, momentum, w1..w3 and w1m..w3m.  Afterwards z is
// resized to match the newly loaded input layer size.
//
// If a serialization_error is raised part way through, item is replaced by a
// freshly constructed mlp_kernel_1(1,1) so that it is left in a usable state,
// and the error is rethrown with a note naming this type appended.
inline void deserialize (
    mlp_kernel_1& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.input_nodes, in);
        deserialize(item.first_hidden_nodes, in);
        deserialize(item.second_hidden_nodes, in);
        deserialize(item.output_nodes, in);
        deserialize(item.alpha, in);
        deserialize(item.momentum, in);

        deserialize(item.w1, in);
        deserialize(item.w2, in);
        deserialize(item.w3, in);

        deserialize(item.w1m, in);
        deserialize(item.w2m, in);
        deserialize(item.w3m, in);

        // z is scratch storage that is not serialized, but the forward pass
        // indexes into it, so it must agree with the loaded input_nodes
        item.z.set_size(item.input_nodes+1,1);
    }
    catch (const serialization_error& e)
    {
        // give item a reasonable value since the deserialization failed
        mlp_kernel_1(1,1).swap(item);
        throw serialization_error(e.info + "\n while deserializing object of type mlp_kernel_1");
    }
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_MLp_KERNEL_1_
|
|
|