API
 
Loading...
Searching...
No Matches
recursive_least_squares.cpp
Go to the documentation of this file.
2#include <string>
3
4namespace DDSPC
5{
6
// NOTE(review): this text appears to be a scraped source listing. Every line
// carries its original line number as a prefix, and several original lines
// are absent (1, 10, 17, 20, 23, 26, 36-38, 40, 45, 48, 61, 63, 67, 76, 78).
// Restore it from the real recursive_least_squares.cpp before building. The
// notes below flag where a gap changes the meaning of the visible code.

// Construct a recursive-least-squares (RLS) estimator.
//   num_predictors     - stored in _num_predictors; sets the row count of
//                        `err` (err is _num_predictors x 1).
//   num_features       - stored in _num_features; sets the column count of
//                        `K` (K is 1 x _num_features) and the bound of the
//                        covariance-initialisation loop.
//   forgetting_factor  - stored in _gamma.
//   initial_covariance - stored in _initial_covariance.
// NOTE(review): the header declaration (see the tooltip text at the end of
// this page) names the 4th parameter `inverse_covariance`, which collides
// with the member of the same name. `initial_covariance`, as used here, is
// the clearer name.
7RecursiveLeastSquares::RecursiveLeastSquares(int num_predictors, int num_features, realT forgetting_factor, realT initial_covariance){
8
9 _gamma = forgetting_factor;
// NOTE(review): original line 10 is missing. update() reads _inverse_gamma,
// but no visible line assigns it. Presumably line 10 sets it to 1/_gamma;
// confirm.
11
12 _initial_covariance = initial_covariance;
13 _num_features = num_features;
14 _num_predictors = num_predictors;
15
16 // Size every state array, then zero it.
// NOTE(review): the resize calls (original lines 17, 20, 23) are missing
// from this listing; only the setZero() calls are visible.
18 prediction_matrix.setZero();
19
21 prediction_output.setZero();
22
24 inverse_covariance.setZero();
// NOTE(review): the body of this loop (original line 26) is missing. As
// shown, the next statement, err.resize(...), would become the loop body.
// Presumably the loop writes _initial_covariance onto the diagonal of
// inverse_covariance; confirm against the real source.
25 for(int i=0; i < _num_features; i++)
27
// Prediction error: one entry per predictor (column vector).
28 err.resize(_num_predictors, 1);
29 err.setZero();
30
// Gain: stored as a row vector with one entry per feature.
31 K.resize(1, _num_features);
32 K.setZero();
// The ';' after the closing brace is a redundant empty declaration.
33};
34
35
// NOTE(review): the signature of the following function (original lines
// 36-38) is missing from this listing, so its name cannot be confirmed.
// The visible body zeroes prediction_matrix, err, K and inverse_covariance,
// then runs the same per-feature loop header as the constructor. That
// loop's body (original line 48) is also missing; presumably it re-seeds
// the diagonal of inverse_covariance. Confirm.
39
41 prediction_matrix.setZero();
42 err.setZero();
43 K.setZero();
44
46 inverse_covariance.setZero();
47 for(int i=0; i < _num_features; i++)
49}
50
51// TODO(author): change this interface to make it easier to use.
// Perform one RLS step for a single sample.
//   x - input sample. Converted with .matrix() into the column `_x`, so it
//       presumably holds _num_features elements (confirm).
//   y - observed target. Converted with .matrix() into `err`, so it
//       presumably holds _num_predictors elements (confirm).
// Steps visible here, writing P for inverse_covariance:
//   err  = y - prediction_matrix * x        (a-priori prediction error)
//   xtP  = (_inverse_gamma * x)^T * P
//   cn   = 1 + xtP * x                      (scalar)
//   K    = xtP / cn                         (gain, kept as a row)
//   P   -= K^T * xtP
// Both pointers are dereferenced without a null check.
52void RecursiveLeastSquares::update(eigenImage<realT> *x, eigenImage<realT> *y){
53 Matrix _x = (*x).matrix();
54 err = (*y).matrix();
55 err -= prediction_matrix * _x;
56
57 xtP = (_inverse_gamma * _x).transpose() * inverse_covariance;
58 realT cn = 1 + (xtP * _x)(0,0);
59 K = xtP;
60 K /= cn;
// NOTE(review): original lines 61 and 63 are missing here. The standard
// exponentially-weighted RLS recursion also needs two more steps:
//   - the weight update, prediction_matrix += err * K;
//   - rescaling P itself by _inverse_gamma, i.e.
//     P_new = _inverse_gamma * P - K^T * xtP.
// Neither appears in the visible lines; presumably they sat on the missing
// lines. Confirm. Otherwise prediction_matrix never changes here, and P is
// never rescaled by the forgetting factor.
62
64 inverse_covariance -= K.transpose() * xtP;
65}
66
// NOTE(review): the signature of this overload (original line 67) is
// missing from the listing. Its body assigns *x directly to a Matrix and *y
// directly to err, so presumably it takes Matrix pointers rather than
// eigenImage pointers; confirm.
// The body repeats the same RLS step as the eigenImage overload above.
// Original lines 76 and 78 (likely the same weight update and P rescaling)
// are missing here too. Consider having both overloads share one helper so
// the update math lives in a single place.
68 Matrix _x = (*x);
69 err = (*y);
70 err -= prediction_matrix * _x;
71
72 xtP = (_inverse_gamma * _x).transpose() * inverse_covariance;
73 realT cn = 1 + (xtP * _x)(0,0);
74 K = xtP;
75 K /= cn;
77
79 inverse_covariance -= K.transpose() * xtP;
80}
81
82
// Disabled prediction helper, kept for reference. It would return
// prediction_matrix * x.
83// Matrix RecursiveLeastSquares::predict(eigenImage<realT> *x){
84// return prediction_matrix * (*x).matrix();
85//}
86
// Intended (judging by its name) to save the estimator state to `filename`.
// NOTE(review): the body is empty. Nothing is written, `filename` is unused,
// and callers get no indication that the call did nothing.
87void RecursiveLeastSquares::save_state(std::string filename){
88
89}
90
91}
void update(eigenImage< realT > *x, eigenImage< realT > *y)
void save_state(std::string filename)
RecursiveLeastSquares(int num_predictors, int num_features, realT forgetting_factor, realT inverse_covariance)
Eigen::Matrix< realT, Eigen::Dynamic, Eigen::Dynamic > Matrix
Definition utils.hpp:12
float realT
Definition utils.hpp:11