-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-multinomial_2cat.R
70 lines (54 loc) · 1.66 KB
/
test-multinomial_2cat.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
## Simulation settings ---------------------------------------------------------
## NOTE(review): `867 - 5309` is arithmetic and evaluates to -4442, not 8675309.
## It is kept as-is so the simulated data stay reproducible; change it
## deliberately if 8675309 was intended.
set.seed(867 - 5309)
N <- 1e3
## Intercept logits for the two categories (softmax gives probs 0.2 / 0.8 at x = 0).
aadm <- log(c(20, 80))
## Per-category slopes on x: first element used when x > 0, second otherwise.
bua <- list(
  c(-0.3, 0), # slopes for x > 0
  c(-0.3, 0)  # slopes for x <= 0
)
message("N = ", N)
message("aadm: ", toString(aadm))
message("bua1: ", toString(bua[[1]]))
message("bua2: ", toString(bua[[2]]))
### Random multinomial deviates, from logit values
## N: number of draws; must equal nrow(l).
## l: numeric matrix of logit values with N rows and k columns. Row i
##    determines the category probabilities p_i = softmax(l[i, ]), i.e.
##    p_i[j] = exp(l[i, j]) / sum(exp(l[i, ])); logits are therefore only
##    defined up to an additive per-row constant.
## Returns an integer vector of length N whose element i is a draw in 1..k
## from the categorical distribution p_i. Consumes exactly N uniforms via
## a single runif(N) call.
rmulti_logit <- function(N, l) {
  l <- as.matrix(l)
  stopifnot(nrow(l) == N, ncol(l) >= 1L)
  k_max <- ncol(l)
  # Shift each row so its largest logit is 0 before exponentiating. softmax
  # is invariant to this shift, and it stops exp() overflowing to Inf (which
  # would turn the probabilities into Inf/Inf = NaN for large logits).
  shifted <- sweep(l, 1, apply(l, 1, max))
  unprob <- exp(shifted)
  prob <- sweep(unprob, 1, rowSums(unprob), "/") # N x k matrix of probs.
  # NA/NaN logits, or a row that is entirely -Inf, cannot be normalized.
  stopifnot(!anyNA(prob))
  r <- runif(N)
  rslt <- rep(NA_integer_, N)
  # Inverse-CDF sampling: draw i lands in category j when r[i] falls in
  # (cumprob[j - 1], cumprob[j]]. Subtracting each column's probability from
  # r after testing it reduces that to a comparison against prob[, j] alone.
  for (j in seq_len(k_max - 1L)) {
    sel <- r > 0 & r <= prob[, j]
    rslt[sel] <- j
    r <- r - prob[, j]
  }
  # Draws not claimed by the first k - 1 categories belong to the last one.
  # This also absorbs draws that exceed a row's total by a rounding error
  # (rowSums(prob) can fall a few ULPs short of 1), which would otherwise be
  # left NA.
  rslt[is.na(rslt)] <- k_max
  rslt
}
## Simulate data ---------------------------------------------------------------
## Covariate x ~ N(0, 1). Each observation's logits are aadm + x * slope,
## with slope bua[[1]] when x > 0 and bua[[2]] otherwise.
x <- rnorm(N)
ylogit <- t(vapply(
  x,
  function(xx) {
    if (xx > 0) {
      aadm + xx * bua[[1]]
    } else {
      aadm + xx * bua[[2]]
    }
  },
  numeric(length(aadm))
))
stopifnot(dim(ylogit) == c(N, length(aadm)))
y <- rmulti_logit(N, ylogit)
## Empirical category frequencies in the lower (x < -1) and upper (x > 1)
## tails. Tabulating a factor with every category as a level keeps zero
## counts, so each vector always has one entry per category, in order. A
## plain table() would silently drop categories that were never drawn.
yfl <- table(factor(y[x < (-1)], levels = seq_len(ncol(ylogit)))) /
  sum(x < (-1))
yfh <- table(factor(y[x > 1], levels = seq_len(ncol(ylogit)))) / sum(x > 1)
message('yfl: ', paste(yfl, collapse=', '))
message('yfh: ', paste(yfh, collapse=', '))
## Data list for the Stan models. iy is y shifted to start at 0 (presumably
## the 0/1 outcome for the Bernoulli model -- TODO confirm against the .stan
## files); nk is the number of categories.
indata <- list(x = x, y = y, iy = y - 1L, N = N, nk = ncol(ylogit))
## Compile a Stan model with cmdstanr, sample from it, and convert the draws
## to an rstan fit via rstan::read_stan_csv().
## stanfile: path to the .stan file.
## indata: named list of data passed to the model's $sample() method.
## seed, chains, parallel_chains: forwarded to $sample(); defaults are seed
##   8675309, 4 chains, and as many parallel chains as chains.
## ...: further arguments forwarded to $sample() (e.g. iter_sampling).
## Returns the value of rstan::read_stan_csv() on the CmdStan output files.
cmdstan2rstan <- function(stanfile, indata, seed = 8675309, chains = 4,
                          parallel_chains = chains, ...) {
  # Namespace-qualified so the script does not rely on cmdstanr having been
  # attached with library() before it is run.
  mod <- cmdstanr::cmdstan_model(stanfile)
  fit <- mod$sample(
    data = indata,
    seed = seed,
    chains = chains,
    parallel_chains = parallel_chains,
    ...
  )
  rstan::read_stan_csv(fit$output_files())
}
## Fit the models in test-multi.stan and test-bern.stan to the same simulated
## data, then print each fit in turn.
multi_2 <- cmdstan2rstan("test-multi.stan", indata)
logr_2 <- cmdstan2rstan("test-bern.stan", indata)
invisible(lapply(list(multi_2, logr_2), print))