Skip to content

Commit e9e5da0

Browse files
authored
Better initialization for implicit biases (#66)
* better initialization for implicit biases
* remove redundant calculations
* avoid unneeded initialization
* fix failing tests
* fix bias init for rows/columns with all-missing values
* remove unneeded header
* correct formula for bias initialization
* spacing
* another correction
1 parent 84690b8 commit e9e5da0

File tree

5 files changed

+85
-54
lines changed

5 files changed

+85
-54
lines changed

R/RcppExports.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ als_implicit_float <- function(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver,
101101
.Call(`_rsparse_als_implicit_float`, m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row, global_bias, global_bias_base_, initialize_bias_base)
102102
}
103103

104-
initialize_biases_double <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE, initialize_item_biases = FALSE) {
105-
.Call(`_rsparse_initialize_biases_double`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases)
104+
initialize_biases_double <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE) {
105+
.Call(`_rsparse_initialize_biases_double`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback)
106106
}
107107

108-
initialize_biases_float <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE, initialize_item_biases = FALSE) {
109-
.Call(`_rsparse_initialize_biases_float`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases)
108+
initialize_biases_float <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE) {
109+
.Call(`_rsparse_initialize_biases_float`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback)
110110
}
111111

R/model_WRMF.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ WRMF = R6::R6Class(
153153
initialize_biases_double,
154154
initialize_biases_float)
155155
FUN(c_ui, c_iu, user_bias, item_bias, private$lambda, private$dynamic_lambda,
156-
private$non_negative, private$with_global_bias, feedback == "explicit",
157-
private$solver_code != 1)
156+
private$non_negative, private$with_global_bias, feedback == "explicit")
158157
}
159158

160159
self$components = init

inst/include/wrmf_utils.hpp

+68-31
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
6262
}
6363
item_bias[col] /=
6464
lambda_use + static_cast<T>(ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col]);
65-
if (non_negative) item_bias[col] = std::fmax(0., item_bias[col]);
65+
if (non_negative) item_bias[col] = std::fmax((T)0, item_bias[col]);
6666
}
6767

6868
user_bias.zeros();
@@ -75,7 +75,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
7575
}
7676
user_bias[row] /=
7777
lambda_use + static_cast<T>(ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row]);
78-
if (non_negative) user_bias[row] = std::fmax(0., user_bias[row]);
78+
if (non_negative) user_bias[row] = std::fmax((T)0, user_bias[row]);
7979
}
8080
}
8181
return global_bias;
@@ -84,8 +84,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
8484
template <class T>
8585
double initialize_biases_implicit(dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
8686
arma::Col<T>& user_bias, arma::Col<T>& item_bias,
87-
T lambda, bool calculate_global_bias, bool non_negative,
88-
const bool initialize_item_biases)
87+
T lambda, bool calculate_global_bias, bool non_negative)
8988
{
9089
double global_bias = 0;
9190
if (calculate_global_bias) {
@@ -94,35 +93,74 @@ double initialize_biases_implicit(dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
9493
}
9594
if (non_negative) global_bias = std::fmax(0., global_bias); /* <- should not happen, but just in case */
9695

97-
user_bias.zeros();
98-
item_bias.zeros();
96+
const int n_users = ConfCSR.n_cols;
97+
const int n_items = ConfCSR.n_rows;
98+
std::vector<double> user_means(n_users);
99+
std::vector<double> item_means(n_items);
100+
std::vector<double> user_adjustment(n_users);
101+
std::vector<double> item_adjustment(n_items);
102+
for (int row = 0; row < n_users; row++) {
103+
if (ConfCSR.col_ptrs[row + 1] > ConfCSR.col_ptrs[row]) {
104+
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++)
105+
user_adjustment[row] += ConfCSR.values[ix];
106+
user_means[row] = user_adjustment[row] / (user_adjustment[row] + (double)(n_items - (ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row])));
107+
user_adjustment[row] += (double)(n_items - (ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row]));
108+
user_adjustment[row] /= user_adjustment[row] + lambda;
109+
} else {
110+
user_means[row] = 0;
111+
user_adjustment[row] = (double)n_items / ((double)n_items + lambda);
112+
}
113+
}
114+
for (int col = 0; col < n_items; col++) {
115+
if (ConfCSC.col_ptrs[col + 1] > ConfCSC.col_ptrs[col]) {
116+
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++)
117+
item_adjustment[col] += ConfCSC.values[ix];
118+
item_means[col] = item_adjustment[col] / (item_adjustment[col] + (double)(n_users - (ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col])));
119+
item_adjustment[col] += (double)(n_users - (ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col]));
120+
item_adjustment[col] /= item_adjustment[col] + lambda;
121+
} else {
122+
item_means[col] = 0;
123+
item_adjustment[col] = (double)n_users / ((double)n_users + lambda);
124+
}
125+
}
99126

100-
double sweight;
101-
const double n_items = ConfCSR.n_rows;
102127

103-
for (int row = 0; row < ConfCSR.n_cols; row++) {
104-
sweight = 0;
105-
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++) {
106-
user_bias[row] += ConfCSR.values[ix] + global_bias * (1. - ConfCSR.values[ix]);
107-
sweight += ConfCSR.values[ix] - 1.;
128+
double bias_mean;
129+
double bias_this;
130+
double wsum;
131+
for (int iter = 0; iter < 5; iter++) {
132+
/* item biases */
133+
bias_mean = 0;
134+
if (iter > 0) {
135+
for (int row = 0; row < n_users; row++)
136+
bias_mean += (user_bias[row] - bias_mean) / (T)(row + 1);
108137
}
109-
user_bias[row] -= global_bias * n_items;
110-
user_bias[row] /= sweight + n_items + lambda;
111-
user_bias[row] /= 3; /* <- item biases are unaccounted for, don't want to assign everything to the user */
112-
if (non_negative) user_bias[row] = std::fmax(0., user_bias[row]);
113-
}
138+
for (int col = 0; col < n_items; col++) {
139+
wsum = n_users;
140+
bias_this = bias_mean;
141+
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++)
142+
bias_this += ((ConfCSC.values[ix] - 1) * (user_bias[ConfCSC.row_indices[ix]] - bias_this)) / (wsum += (ConfCSC.values[ix] - 1));
143+
item_bias[col] = (item_means[col] - bias_this - global_bias) * item_adjustment[col];
144+
}
145+
146+
if (non_negative)
147+
for (int col = 0; col < n_items; col++) item_bias[col] = std::fmax((T)0, item_bias[col]);
114148

115-
const double n_users = ConfCSC.n_rows;
116-
for (int col = 0; col < ConfCSC.n_cols; col++) {
117-
sweight = 0;
118-
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++) {
119-
item_bias[col] += ConfCSC.values[ix] + global_bias * (1. - ConfCSC.values[ix]);
120-
sweight += ConfCSC.values[ix] - 1.;
149+
/* user biases */
150+
bias_mean = 0;
151+
for (int col = 0; col < n_items; col++)
152+
bias_mean += (item_bias[col] - bias_mean) / (T)(col + 1);
153+
154+
for (int row = 0; row < n_users; row++) {
155+
wsum = n_items;
156+
bias_this = bias_mean;
157+
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++)
158+
bias_this += ((ConfCSR.values[ix] - 1) * (item_bias[ConfCSR.row_indices[ix]] - bias_this)) / (wsum += (ConfCSR.values[ix] - 1));
159+
user_bias[row] = (user_means[row] - bias_this - global_bias) * user_adjustment[row];
121160
}
122-
item_bias[col] -= global_bias * n_users;
123-
item_bias[col] /= sweight + n_users + lambda;
124-
item_bias[col] /= 3; /* <- user biases are unaccounted for */
125-
if (non_negative) item_bias[col] = std::fmax(0., item_bias[col]);
161+
162+
if (non_negative)
163+
for (int row = 0; row < n_users; row++) user_bias[row] = std::fmax((T)0, user_bias[row]);
126164
}
127165

128166
return global_bias;
@@ -134,13 +172,12 @@ double initialize_biases(dMappedCSC& ConfCSC, // modified in place
134172
dMappedCSC& ConfCSR, // modified in place
135173
arma::Col<T>& user_bias, arma::Col<T>& item_bias, T lambda,
136174
bool dynamic_lambda, bool non_negative,
137-
bool calculate_global_bias, bool is_explicit_feedback,
138-
const bool initialize_item_biases) {
175+
bool calculate_global_bias, bool is_explicit_feedback) {
139176
if (is_explicit_feedback)
140177
return initialize_biases_explicit(ConfCSC, ConfCSR, user_bias, item_bias,
141178
lambda, dynamic_lambda, non_negative,
142179
calculate_global_bias);
143180
else
144181
return initialize_biases_implicit(ConfCSC, ConfCSR, user_bias, item_bias, lambda,
145-
calculate_global_bias,non_negative, initialize_item_biases);
182+
calculate_global_bias,non_negative);
146183
}

src/RcppExports.cpp

+8-10
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,8 @@ BEGIN_RCPP
410410
END_RCPP
411411
}
412412
// initialize_biases_double
413-
double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, arma::Col<double>& user_bias, arma::Col<double>& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback, const bool initialize_item_biases);
414-
RcppExport SEXP _rsparse_initialize_biases_double(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP, SEXP initialize_item_biasesSEXP) {
413+
double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, arma::Col<double>& user_bias, arma::Col<double>& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback);
414+
RcppExport SEXP _rsparse_initialize_biases_double(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP) {
415415
BEGIN_RCPP
416416
Rcpp::RObject rcpp_result_gen;
417417
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -424,14 +424,13 @@ BEGIN_RCPP
424424
Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
425425
Rcpp::traits::input_parameter< bool >::type calculate_global_bias(calculate_global_biasSEXP);
426426
Rcpp::traits::input_parameter< bool >::type is_explicit_feedback(is_explicit_feedbackSEXP);
427-
Rcpp::traits::input_parameter< const bool >::type initialize_item_biases(initialize_item_biasesSEXP);
428-
rcpp_result_gen = Rcpp::wrap(initialize_biases_double(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases));
427+
rcpp_result_gen = Rcpp::wrap(initialize_biases_double(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback));
429428
return rcpp_result_gen;
430429
END_RCPP
431430
}
432431
// initialize_biases_float
433-
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback, const bool initialize_item_biases);
434-
RcppExport SEXP _rsparse_initialize_biases_float(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP, SEXP initialize_item_biasesSEXP) {
432+
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback);
433+
RcppExport SEXP _rsparse_initialize_biases_float(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP) {
435434
BEGIN_RCPP
436435
Rcpp::RObject rcpp_result_gen;
437436
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -444,8 +443,7 @@ BEGIN_RCPP
444443
Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
445444
Rcpp::traits::input_parameter< bool >::type calculate_global_bias(calculate_global_biasSEXP);
446445
Rcpp::traits::input_parameter< bool >::type is_explicit_feedback(is_explicit_feedbackSEXP);
447-
Rcpp::traits::input_parameter< const bool >::type initialize_item_biases(initialize_item_biasesSEXP);
448-
rcpp_result_gen = Rcpp::wrap(initialize_biases_float(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases));
446+
rcpp_result_gen = Rcpp::wrap(initialize_biases_float(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback));
449447
return rcpp_result_gen;
450448
END_RCPP
451449
}
@@ -476,8 +474,8 @@ static const R_CallMethodDef CallEntries[] = {
476474
{"_rsparse_als_explicit_float", (DL_FUNC) &_rsparse_als_explicit_float, 11},
477475
{"_rsparse_als_implicit_double", (DL_FUNC) &_rsparse_als_implicit_double, 13},
478476
{"_rsparse_als_implicit_float", (DL_FUNC) &_rsparse_als_implicit_float, 13},
479-
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 10},
480-
{"_rsparse_initialize_biases_float", (DL_FUNC) &_rsparse_initialize_biases_float, 10},
477+
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 9},
478+
{"_rsparse_initialize_biases_float", (DL_FUNC) &_rsparse_initialize_biases_float, 9},
481479
{NULL, NULL, 0}
482480
};
483481

src/wrmf_init.cpp

+4-7
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,20 @@ double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r
88
arma::Col<double>& item_bias, double lambda,
99
bool dynamic_lambda, bool non_negative,
1010
bool calculate_global_bias = false,
11-
bool is_explicit_feedback = false,
12-
const bool initialize_item_biases = false) {
11+
bool is_explicit_feedback = false) {
1312
dMappedCSC ConfCSC = extract_mapped_csc(m_csc_r);
1413
dMappedCSC ConfCSR = extract_mapped_csc(m_csr_r);
1514
return initialize_biases<double>(ConfCSC, ConfCSR, user_bias, item_bias, lambda,
1615
dynamic_lambda, non_negative, calculate_global_bias,
17-
is_explicit_feedback, initialize_item_biases);
16+
is_explicit_feedback);
1817
}
1918

2019
// [[Rcpp::export]]
2120
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r,
2221
Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda,
2322
bool dynamic_lambda, bool non_negative,
2423
bool calculate_global_bias = false,
25-
bool is_explicit_feedback = false,
26-
const bool initialize_item_biases = false) {
24+
bool is_explicit_feedback = false) {
2725
dMappedCSC ConfCSC = extract_mapped_csc(m_csc_r);
2826
dMappedCSC ConfCSR = extract_mapped_csc(m_csr_r);
2927

@@ -32,6 +30,5 @@ double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r,
3230

3331
return initialize_biases<float>(ConfCSC, ConfCSR, user_bias_arma, item_bias_arma,
3432
lambda, dynamic_lambda, non_negative,
35-
calculate_global_bias, is_explicit_feedback,
36-
initialize_item_biases);
33+
calculate_global_bias, is_explicit_feedback);
3734
}

0 commit comments

Comments
 (0)