@@ -118,36 +118,37 @@ arma::mat construct_zero_restriction(const NumericMatrix::ConstColumn& spec) {
118
118
return zero_restriction_matrix;
119
119
}
120
120
121
- vec calculate_sign_restriction_scores (
122
- const NumericMatrix::ConstColumn& spec, // rows: transformation size
123
- const mat& rotated_params // rows: transformation size, cols: candidates for p_j
124
- ) {
125
- vec scores (rotated_params.n_cols );
126
- for (uword i = 0 ; i < rotated_params.n_rows ; i++) {
127
- if (NumericMatrix::is_na (spec[i]) || spec[i] == 0 )
128
- continue ;
129
- for (uword j = 0 ; j < rotated_params.n_cols ; j++) {
130
- scores[j] += std::copysign (spec[i], rotated_params (i, j) * spec[i]);
131
- // if the rotated params and its sign specification spec[i] have the same sign
132
- // the score increases by spec[i] otherwise the score will decrease by that amount.
133
- // if the score is negative overall, we will choose -p_j instead of p_j.
121
+ arma::mat construct_sign_restriction (const NumericMatrix::ConstColumn& spec) {
122
+ const uword n_sign_restrictions = count (spec != 0 );
123
+
124
+ mat sign_restriction_matrix (n_sign_restrictions, spec.size ());
125
+ uword i = 0 ;
126
+ for (R_xlen_t j = 0 ; j < spec.size (); j++) {
127
+ if ((spec[j] > 0 ) == TRUE ) {
128
+ sign_restriction_matrix (i++, j) = 1 ;
129
+ }
130
+ else if ((spec[j] < 0 ) == TRUE ) {
131
+ sign_restriction_matrix (i++, j) = -1 ;
134
132
}
135
133
}
136
- return scores ;
134
+ return sign_restriction_matrix ;
137
135
}
138
136
139
-
140
137
// [[Rcpp::export]]
141
138
arma::cube find_rotation_cpp (
142
139
const arma::field<arma::cube>& parameter_transformations, // each field element: rows: transformation size, cols: variables, slices: draws
143
140
const arma::field<Rcpp::NumericMatrix>& restriction_specs, // each field element: rows: transformation size, cols: variables
144
- const double tolerance = 0.0
141
+ const double tolerance = 0.0 ,
142
+ const double sign_epsilon = 1e-6
145
143
) {
146
144
// algorithm from RUBIO-RAMÍREZ ET AL. (doi: 10.1111/j.1467-937X.2009.00578.x)
147
145
148
146
if (restriction_specs.n_elem != parameter_transformations.n_elem ) {
149
147
throw std::logic_error (" Number of restrictions does not match number of parameter transformations." );
150
148
}
149
+ if (!(sign_epsilon > 0 )) {
150
+ throw std::logic_error (" sign_epsilon must be a positive number" );
151
+ }
151
152
152
153
const uword n_variables = parameter_transformations (0 ).n_cols ;
153
154
const uword n_posterior_draws = parameter_transformations (0 ).n_slices ;
@@ -156,46 +157,44 @@ arma::cube find_rotation_cpp(
156
157
// field rows: tranformations, field cols: cols of the transformation
157
158
// each field element: rows: number of restrictions, cols: transformation size
158
159
arma::field<arma::mat> zero_restrictions (restriction_specs.n_elem , n_variables);
159
- arma::vec n_sign_restrictions (n_variables);
160
+ arma::field<arma::mat> sign_restrictions (restriction_specs.n_elem , n_variables);
161
+ arma::uvec n_sign_restrictions (n_variables);
160
162
for (uword i = 0 ; i < restriction_specs.n_elem ; i++) {
161
163
for (uword j = 0 ; j < n_variables; j++) {
162
164
const NumericMatrix::ConstColumn column_restriction_spec = restriction_specs (i).column (j);
163
165
zero_restrictions (i, j) = construct_zero_restriction (column_restriction_spec);
164
- n_sign_restrictions (j) += count (column_restriction_spec != 0 );
166
+ sign_restrictions (i, j) = construct_sign_restriction (column_restriction_spec);
167
+ n_sign_restrictions (j) += sign_restrictions (i, j).n_rows ;
165
168
}
166
169
}
167
170
168
171
for (uword r = 0 ; r < n_posterior_draws; r++) {
169
172
for (uword j = 0 ; j < n_variables; j++) {
170
- arma::mat Q_tilde (rotation.slice (r).head_cols (j).t ());
173
+ arma::mat Q_zero (rotation.slice (r).head_cols (j).t ());
171
174
for (uword i = 0 ; i < parameter_transformations.n_elem ; i++) {
172
- Q_tilde .insert_rows (0 , zero_restrictions (i, j) * parameter_transformations (i).slice (r));
175
+ Q_zero .insert_rows (0 , zero_restrictions (i, j) * parameter_transformations (i).slice (r));
173
176
}
174
- arma::mat nullspace_Q_tilde = arma::null (Q_tilde , tolerance);
175
- if (nullspace_Q_tilde .n_cols == 0 ) {
176
- throw std::logic_error (" Could not satisfy restrictions. Increase the tolerance for approximate results." );
177
+ arma::mat nullspace_Q_zero = arma::null (Q_zero , tolerance);
178
+ if (nullspace_Q_zero .n_cols == 0 ) {
179
+ throw std::logic_error (" Could not satisfy zero restrictions. Increase the tolerance for approximate results." );
177
180
}
178
181
179
182
colvec p_j;
180
183
if (n_sign_restrictions (j) > 0 ) {
181
- // find the vector in the nullspace of Q which scores best in the sign restrictions
182
- vec sign_restriction_scores (nullspace_Q_tilde. n_cols , arma::fill::zeros );
184
+ // find the vector in the nullspace of Q_zero which satisfies the sign restrictions
185
+ arma::mat Q_sign ( 0 , n_variables );
183
186
for (uword i = 0 ; i < parameter_transformations.n_elem ; i++) {
184
- const NumericMatrix::ConstColumn column_restriction_spec = restriction_specs (i).column (j);
185
- const mat rotated_params = parameter_transformations (i).slice (r) * nullspace_Q_tilde;
186
- sign_restriction_scores += calculate_sign_restriction_scores (column_restriction_spec, rotated_params);
187
- }
188
- uword index_of_best_score = abs (sign_restriction_scores).index_max ();
189
- p_j = nullspace_Q_tilde.col (index_of_best_score);
190
- if (sign_restriction_scores[index_of_best_score] < 0 ) {
191
- p_j = -p_j;
187
+ Q_sign.insert_rows (0 , sign_restrictions (i, j) * parameter_transformations (i).slice (r));
192
188
}
189
+ const vec small_positive_vector (Q_sign.n_rows , arma::fill::value (sign_epsilon));
190
+ p_j = nullspace_Q_zero * arma::solve (Q_sign * nullspace_Q_zero, small_positive_vector);
191
+ p_j = normalise (p_j);
193
192
}
194
193
else {
195
194
// any vector from the null space is fine
196
195
// vector with corresponding to the smallest singular value should be the last one
197
196
// however this is an not guranteed by the public armadillo API!
198
- p_j = nullspace_Q_tilde .col (nullspace_Q_tilde .n_cols - 1 );
197
+ p_j = nullspace_Q_zero .col (nullspace_Q_zero .n_cols - 1 );
199
198
}
200
199
201
200
rotation.slice (r).col (j) = p_j;
0 commit comments