-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsim_tolman_latent.m
114 lines (92 loc) · 2.36 KB
/
sim_tolman_latent.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
function h = sim_tolman_latent(plt_nr,plt_nc,plt_np)
% SIM_TOLMAN_LATENT  Draw the Tolman latent-learning maze and bar plot.
%
%   h = sim_tolman_latent() closes all figures, opens a new figure and uses
%   a 1x2 subplot layout.
%
%   h = sim_tolman_latent(plt_nr,plt_nc,plt_np) draws into the current
%   figure at subplot(plt_nr,plt_nc,plt_np(1)) (maze image) and
%   subplot(plt_nr,plt_nc,plt_np(2)) (bar plot). Omitted or empty trailing
%   arguments default to plt_nr = 1, plt_nc = 2, plt_np = [1 2].
%
%   h is a 1x2 array of axes handles: [maze, bar plot].

% run_tolman_latent also returns a colour map and an alpha value, but this
% function takes both from def() instead, so only the first four outputs
% are requested.
[mx_tll, tll_labels, TB, AA] = run_tolman_latent;
%--------------------------------------------------------------------------
if nargin<1
    % Stand-alone call: start from a clean, fixed-size figure.
    close all;
    fsiz = [0.1375 0.6907 0.45 0.2139];
    figure; set(gcf,'units','normalized'); set(gcf,'position',fsiz);
end
% Default any layout argument the caller left out, so partial calls such
% as sim_tolman_latent(2) work instead of failing on an undefined input.
if nargin<1 || isempty(plt_nr), plt_nr = 1;     end
if nargin<2 || isempty(plt_nc), plt_nc = 2;     end
if nargin<3 || isempty(plt_np), plt_np = [1 2]; end
%----------------------
fsy = def('fsy');          % font size for the y label and x tick labels
alf = def('alf');          % transparency applied to the bar plot
bw = .15;                  % bar width passed to errorbarKxN
cols = def('col');
% Put rows 2 and 3 first: run_tolman_latent gives these same rows to the
% two highlighted maze cells.
cols = cols([2 3 1],:);

h(1) = subplot(plt_nr,plt_nc,plt_np(1));
imagesc(TB,'AlphaData',AA);
set(gca,'box','on','ytick',[],'xtick',[]);
% title('Tolman maze');

h(2) = subplot(plt_nr,plt_nc,plt_np(2));
% Error bars are all zero (0*mx_tll): only the values are shown.
errorbarKxN(mx_tll',0*mx_tll',tll_labels,struct('colmap',cols,'barwidth',bw));
alpha(alf);
ylabel('Probability','fontsize',fsy);
% % % legend(tll_labels,'fontsize',fsy,'location','north','box','off');
hax = get(gca,'XAxis');
set(hax,'fontsize',fsy);
end
function [mx, tll_labels, TB2, AA, cols, alf] = run_tolman_latent
% RUN_TOLMAN_LATENT  Build the 9x9 grid task, solve it twice with core_lrl
% and return what sim_tolman_latent plots.
%
% Outputs
%   mx         - 2x2 matrix. Row 1 holds the 1st and 3rd positive entries
%                of the core_lrl solution's row for state 37 when both
%                terminal costs are -5. Row 2 holds the same entries after
%                the first terminal's cost is flipped to +5.
%   tll_labels - {'Training','Test'}.
%   TB2        - RGB array: channel 1 is the image from plot_grids, while
%                channels 2 and 3 are NaN. The two cells where TB==0 and
%                AA==0.5 (presumably the terminals - TODO confirm) are
%                painted with colours from def('col').
%   AA         - alpha map for TB2.
%   cols       - def('col') with its first row removed.
%   alf        - alpha (0.8) assigned to the two highlighted cells.

% ---- task definition ---------------------------------------------------
nRows      = 9;
nCols      = 9;
stepCost   = 1;
termCost   = [-5 -5];
wallStates = [2:9 11:18 20:27 29:36 47:54 56:63 65:72 74:81];
endStates  = [1 73];
choiceState = 37;

[P0, ~, lij] = core_griding(nRows,nCols);
[lij0, P, c] = make_blocks(lij,P0,stepCost,wallStates,endStates,termCost);
% make_blocks deleted the wall states, so look up where the choice state
% now sits in the reduced state list.
choiceIdx = find(lij0(:,1) == choiceState);

U1 = core_lrl(P,c);

% ---- maze image --------------------------------------------------------
% plot_grids draws into a temporary figure that is closed right away;
% only the returned image and alpha data are kept.
config.add_arrow  = 0;
config.add_labels = 0;
figure;
[TB, AA] = plot_grids(U1,config,nRows,nCols,[],[],lij0);
close;

TB2 = nan([size(TB),3]);
TB2(:,:,1) = TB;
highlight = TB==0 & AA==.5;
[rr,kk] = find(highlight);
cols = def('col');
cols(1,:) = [];
alf = .8;
for n = 1:2
    TB2(rr(n),kk(n),:) = cols(n,:);
end
% Order matters: the highlighted cells start with AA==.5 and must receive
% alf before the remaining .5 entries are faded to .1.
AA(AA==1) = 0;
AA(highlight) = alf;
AA(AA==.5) = .1;

% ---- probabilities at the choice state ---------------------------------
positive = @(v) v(v>0);
prTrain = positive(U1(choiceIdx,:));

testCost = termCost;
testCost(1) = -testCost(1);
% Absorbing states (self-transition probability 1) get the new costs.
c(diag(P)==1) = testCost;
U2 = core_lrl(P,c);
prTest = positive(U2(choiceIdx,:));

mx = [prTrain([1 3]); prTest([1 3])];
tll_labels = {'Training','Test'};
end
function [lij, P, c] = make_blocks(lij,P,c,blocked,terminals,goal)
% MAKE_BLOCKS  Make terminal states absorbing, delete blocked states and
% renormalise the transition matrix.
%
% Inputs
%   lij       - one row per state of P; rows of blocked states are removed.
%   P         - ns x ns transition matrix. Only its sparsity pattern (P>0)
%               is kept, so each returned row is uniform over the allowed
%               moves.
%   c         - state cost: a scalar applied to every state, or a vector
%               with one entry per state.
%   blocked   - indices (original numbering) of states to delete.
%   terminals - indices (original numbering) of states made absorbing.
%   goal      - cost given to the terminal states (a scalar, or one value
%               per terminal).
%
% Outputs
%   lij, P, c restricted to the unblocked states. Every row of P sums to 1
%   and c is a column vector.
%
% Errors
%   make_blocks:isolatedState if a remaining state has no transitions left
%   (row normalisation would otherwise produce NaN).

ns = size(P,1);
if isscalar(c)
    c = c*ones(ns,1);
else
    c = c(:);
end
c(terminals) = goal;

% Terminal rows: drop all outgoing moves, keep only the self-loop.
P(terminals,:) = 0;
P(sub2ind(size(P),terminals(:),terminals(:))) = 1;

P(blocked,:) = [];
P(:,blocked) = [];
c(blocked) = [];
lij(blocked,:) = [];

A = double(P>0);
outDeg = sum(A,2);
if any(outDeg==0)
    % Report the offending states in the caller's original numbering.
    kept = 1:ns;
    kept(blocked) = [];
    error('make_blocks:isolatedState', ...
        'State(s) %s have no remaining transitions after blocking.', ...
        mat2str(kept(outDeg==0)));
end
% Row-normalise without forming and inverting a diagonal matrix.
P = bsxfun(@rdivide, A, outDeg);
end