Transforming User Code To Stan Math Specialized Functions? #879
-
Say a user writes

```stan
parameters {
  vector[N] B;
  matrix[N, N] A;
}
transformed parameters {
  real quaded = B' * A * B;
}
```

In Stan Math we have a specialized function for this:

```stan
real quaded = quad_form(A, B);
```

And if we know that the matrix is symmetric, for instance in the case of

```stan
parameters {
  vector[N] B;
  cov_matrix[N] A;
}
transformed parameters {
  real quaded = B' * A * B;
}
```
we can use

```stan
real quaded = quad_form_sym(A, B);
```

It would be nice for users if they could just write their math, and then in the compiler we check whether we have specializations for certain kinds of statements. There are actually a lot of these. For example, a dot product of a vector with itself can use `dot_self()`, and for a loop like

```stan
for (i in 1:N) {
  a[i] = b[i] * c[i] + d[i];
}
```

we would spit out

```stan
a = fma(b, c, d);
```

I'm not sure how to go about this, though. It feels like what I'd like to do is look over the list of statements in … Something like

```stan
a = b * c + d + e * f + g;
```

can become

```stan
a = fma(b, c, d + fma(e, f, g));
```

The mix of #877 and this stuff here is pretty common inside of brms. For a random effects model, for example, brms prints the following:

```stan
// add more terms to the linear predictor
for (n in 1:N) {
  mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_2_1[J_2[n]] * Z_2_1[n] + r_2_2[J_2[n]] * Z_2_2[n];
}
```

It would be a bit more advanced, but we could turn the above into something like

```stan
mu += fma(r_1_1[J_1], Z_1_1, fma(r_2_1[J_2], Z_2_1, r_2_2[J_2] .* Z_2_2));
```
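The flatten-then-nest rewrite proposed above can be sketched on a toy expression tree. This is a minimal illustration in OCaml, the language stanc3 is written in. It assumes a hypothetical `expr` type and hypothetical helpers (`terms`, `nest`, `fuse_sum`), and it is not stanc3's MIR or its actual optimization code. The sketch flattens a sum into its terms. It then nests each product term into an `fma` whose third argument is the rest of the sum.

```ocaml
(* Hypothetical toy AST for illustration only; not stanc3's MIR. *)
type expr =
  | Var of string
  | Add of expr * expr
  | Mul of expr * expr
  | Fma of expr * expr * expr

(* Print in Stan-like syntax; sums used as factors get parentheses. *)
let rec to_string = function
  | Var v -> v
  | Add (a, b) -> to_string a ^ " + " ^ to_string b
  | Mul (a, b) -> factor a ^ " * " ^ factor b
  | Fma (a, b, c) ->
      "fma(" ^ to_string a ^ ", " ^ to_string b ^ ", " ^ to_string c ^ ")"
and factor e =
  match e with
  | Add _ -> "(" ^ to_string e ^ ")"
  | _ -> to_string e

(* Flatten a (possibly nested) sum into its list of terms, left to right. *)
let rec terms = function
  | Add (a, b) -> terms a @ terms b
  | e -> [e]

(* Rebuild the sum right-nested: each product term becomes an fma whose
   addend is the rest of the sum; other terms stay plain additions. *)
let rec nest = function
  | [] -> invalid_arg "nest: empty sum"
  | [t] -> t
  | Mul (x, y) :: rest -> Fma (x, y, nest rest)
  | t :: rest -> Add (t, nest rest)

let fuse_sum e = nest (terms e)

let () =
  let v s = Var s in
  (* b * c + d + e * f + g, parsed left-associatively *)
  let e = Add (Add (Add (Mul (v "b", v "c"), v "d"), Mul (v "e", v "f")), v "g") in
  print_endline (to_string (fuse_sum e))
  (* prints: fma(b, c, d + fma(e, f, g)) *)
```

Consider the body of the brms loop, with each indexed scalar such as `r_1_1[J_1[n]]` treated as an opaque variable. The same function turns it into `fma(r_1_1[J_1[n]], Z_1_1[n], fma(r_2_1[J_2[n]], Z_2_1[n], r_2_2[J_2[n]] * Z_2_2[n]))`. Turning the loop into one vectorized statement would be a separate step.

Flattening also reassociates the additions. The result evaluates `b * c + (d + (e * f + g))` instead of the left-to-right `((b * c + d) + e * f) + g`. In addition, `fma` rounds once instead of twice. So the result can differ from the original expression in the last few bits.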
-
Your first `quad_form` example almost works already. Compile the following model with `stanc --O` and you'll see a `quad_form` call in the C++.

```stan
data { int N; }
parameters {
  vector[N] B;
  matrix[N, N] A;
}
transformed parameters {
  real quaded = transpose(B) * A * B;
}
```

This is implemented in the optimization suite: stanc3/src/analysis_and_optimization/Partial_evaluator.ml, lines 595 to 603 (commit 8626779). It doesn't work with your original example because, apparently, the operator `'` is not the same thing to the compiler as the function `transpose`, so the match in those lines fails to recognize it.
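Here is a minimal sketch of why the match misses `B'`, and one possible fix. The fix desugars the postfix operator into the `transpose(...)` call before matching, so the rule only has to recognize one spelling. Everything here is hypothetical and does not mirror the code in Partial_evaluator.ml: the toy `expr` type, `normalize`, `rewrite`, `optimize`, and the `symmetric_vars` list. The list stands in for declared types such as `cov_matrix`.

```ocaml
(* Hypothetical toy AST for illustration only; not stanc3's MIR. *)
type expr =
  | Var of string
  | Transpose of expr             (* postfix operator, as in B' *)
  | FunApp of string * expr list  (* function call, as in transpose(B) *)
  | Mul of expr * expr

let rec to_string = function
  | Var v -> v
  | Transpose (Var v) -> v ^ "'"
  | Transpose e -> "(" ^ to_string e ^ ")'"
  | FunApp (f, args) -> f ^ "(" ^ String.concat ", " (List.map to_string args) ^ ")"
  | Mul (a, b) -> to_string a ^ " * " ^ to_string b

(* Desugar every postfix transpose into the equivalent function call,
   so later rules only need to match one spelling. *)
let rec normalize = function
  | Transpose e -> FunApp ("transpose", [normalize e])
  | FunApp (f, args) -> FunApp (f, List.map normalize args)
  | Mul (a, b) -> Mul (normalize a, normalize b)
  | Var _ as v -> v

(* transpose(B) * A * B parses as (transpose(B) * A) * B. When both B's are
   the same expression, emit quad_form, or quad_form_sym if A is known
   to be symmetric. *)
let rec rewrite ~is_symmetric e =
  match e with
  | Mul (Mul (FunApp ("transpose", [b1]), a), b2) when b1 = b2 ->
      let f = if is_symmetric a then "quad_form_sym" else "quad_form" in
      FunApp (f, [rewrite ~is_symmetric a; rewrite ~is_symmetric b1])
  | Mul (x, y) -> Mul (rewrite ~is_symmetric x, rewrite ~is_symmetric y)
  | FunApp (f, args) -> FunApp (f, List.map (rewrite ~is_symmetric) args)
  | (Var _ | Transpose _) as leaf -> leaf

(* Hypothetical: symmetric_vars stands in for type info such as cov_matrix. *)
let optimize ~symmetric_vars e =
  let is_symmetric = function
    | Var v -> List.mem v symmetric_vars
    | _ -> false
  in
  rewrite ~is_symmetric (normalize e)

let () =
  (* B' * A * B, written with the postfix operator *)
  let e = Mul (Mul (Transpose (Var "B"), Var "A"), Var "B") in
  print_endline (to_string (optimize ~symmetric_vars:[] e));
  print_endline (to_string (optimize ~symmetric_vars:["A"] e))
  (* prints: quad_form(A, B)
             quad_form_sym(A, B) *)
```

Calling `rewrite` directly on the postfix form, without `normalize`, leaves `B' * A * B` untouched. That mirrors the behavior described above. Normalizing once up front keeps every later pattern from having to list both spellings.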
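The partial evaluator also knows how to create `fma` calls; for your loop it produces

```stan
for (i in 1:N)
  a[i] = fma(b[i], c[i], d[i]);
```

There are many ways to regroup `a = b * c + d + e * f + g;`. One of them is

```stan
a = fma(e, f, fma(b, c, d)) + g;
```

Extending the partial evaluator to produce …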
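For comparison, the grouping `fma(e, f, fma(b, c, d)) + g` falls out of a purely local rule. The rule is applied bottom-up to the left-associative parse `((b * c + d) + e * f) + g`: whenever an addition has a product on either side, fuse the two into an `fma`. This is a sketch under assumptions, not a claim about how Partial_evaluator.ml implements it. The toy `expr` type and the helpers `fuse` and `to_string` are hypothetical.

```ocaml
(* Hypothetical toy AST for illustration only; not stanc3's MIR. *)
type expr =
  | Var of string
  | Add of expr * expr
  | Mul of expr * expr
  | Fma of expr * expr * expr

(* Print in Stan-like syntax; sums used as factors get parentheses. *)
let rec to_string = function
  | Var v -> v
  | Add (a, b) -> to_string a ^ " + " ^ to_string b
  | Mul (a, b) -> factor a ^ " * " ^ factor b
  | Fma (a, b, c) ->
      "fma(" ^ to_string a ^ ", " ^ to_string b ^ ", " ^ to_string c ^ ")"
and factor e =
  match e with
  | Add _ -> "(" ^ to_string e ^ ")"
  | _ -> to_string e

(* Local rule, applied bottom-up: if either operand of an addition is
   (after rewriting) a product x * y, turn the addition into
   fma(x, y, other operand). The sum is never flattened or reordered. *)
let rec fuse e =
  match e with
  | Add (a, b) ->
      (match fuse a, fuse b with
       | Mul (x, y), z -> Fma (x, y, z)
       | z, Mul (x, y) -> Fma (x, y, z)
       | l, r -> Add (l, r))
  | Mul (a, b) -> Mul (fuse a, fuse b)
  | Fma (a, b, c) -> Fma (fuse a, fuse b, fuse c)
  | Var _ -> e

let () =
  let v s = Var s in
  (* b * c + d + e * f + g, parsed left-associatively *)
  let e = Add (Add (Add (Mul (v "b", v "c"), v "d"), Mul (v "e", v "f")), v "g") in
  print_endline (to_string (fuse e))
  (* prints: fma(e, f, fma(b, c, d)) + g *)
```

Unlike flattening the whole sum first, this rule keeps the original order of the additions. Only the rounding of each product is folded into the add that follows it. The trade-off is that which products end up outermost depends on how the sum happened to be parsed.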