forked from AbsInt/CompCert
-
Notifications
You must be signed in to change notification settings - Fork 1
/
DenotationalSimulationChange.v
347 lines (304 loc) · 12.7 KB
/
DenotationalSimulationChange.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
(* Denotational Proofs for Reparameterization
The reparameterization transform remaps parameters to be
unconstrained and adds a jacobian correction factor to the target.
The proof here takes as input a forward simulation between a
program p and its transformed version tp, in which the parameters
have been remapped and the target has been corrected by a jacobian
by the time the program returns.
From this it proves that the probability distributions of p and tp
are equivalent. The core of the proof is a repeated change of
variables theorem showing that the integrals that define the
probability distribution are equal.
The change of variables theorem we use ends up causing us to
require that the reparameterization transform be monotone
increasing (see assumption gs_monotone), which turns out to be true
for the reparam transforms ProbCompCert uses, though it is not true
for Stan's transforms for some of the same constraints.
TODO: When establishing and using these lemmas, it is very easy to
confuse the directions of param_map/param_unmap, which convert
between constrained and unconstrained parameter values, and of
target_map/target_unmap, which convert between the uncorrected and
the Jacobian-corrected versions of the target value. Could we give
these functions richer types to rule that confusion out?
*)
Require Import Coqlib Errors Maps String.
Local Open Scope string_scope.
Require Import Integers Floats Values AST Memory Builtins Events Globalenvs.
Require Import Ctypes Cop Stanlight.
Require Import Smallstep.
Require Import Linking.
Require Import IteratedRInt.
Require Vector.
Require Import Clightdefs.
Import Clightdefs.ClightNotations.
Local Open Scope clight_scope.
Require ClassicalEpsilon.
Require Import Reals.
Require Import StanEnv.
From Coq Require Import Reals Psatz ssreflect ssrbool Utf8.
Require Import Ssemantics.
Section DENOTATIONAL_SIMULATION.
(* The source program and its reparameterized counterpart. Everything in
   this section is stated relative to this pair. *)
Variable prog: Stanlight.program.
Variable tprog: Stanlight.program.
(* unconstrain: sends points of [parameter_list_rect prog] to points of
   [parameter_list_rect tprog] (see param_unmap_in_dom). *)
Variable param_unmap : list R -> list R.
(* constrain: sends points of [parameter_list_rect tprog] back to points of
   [parameter_list_rect prog] (see param_map_in_dom). *)
Variable param_map : list R -> list R.
(* Input:
   - data
   - unconstrained parameters
   - original target
   Output:
   - corrected target *)
Variable target_map : list val -> list val -> R -> R.
(* Takes the same arguments as target_map. The only property assumed of it
   is target_map_unmap, stated next. *)
Variable target_unmap : list val -> list val -> R -> R.
(* For fixed data [d] and parameters [p], applying target_unmap and then
   target_map gives back the original value. NOTE(review): the opposite
   composition (target_unmap after target_map) is not assumed in the part
   of the section visible here. *)
Variable target_map_unmap : ∀ d p x, target_map d p (target_unmap d p x) = x.
(* A safe run of [prog] has an initial state in its small-step semantics,
   whatever target value [t] the semantics is instantiated with. *)
Lemma inhabited_initial :
∀ data params t, is_safe prog data params -> ∃ s, Smallstep.initial_state (semantics prog data params t) s.
Proof.
  (* Only the first component of the safety hypothesis is needed; the
     remaining components are discarded by the intro pattern. *)
  intros data params t [Hinit _].
  eapply Hinit.
Qed.
(* The transformation does not change the number of parameters. This
   equality is the witness supplied first in denotational_preserved. *)
Variable dimen_preserved: parameter_dimension tprog = parameter_dimension prog.
(* Well-formedness of the source parameter rectangle implies
   well-formedness of the target parameter rectangle. *)
Variable wf_paramter_rect_tprog :
wf_rectangle_list (parameter_list_rect prog) ->
wf_rectangle_list (parameter_list_rect tprog).
(* Param map/unmap are inverses on the parameter rectangle.
   param_map_unmap: for x in the source rectangle, unconstraining and then
   constraining is the identity. *)
Variable param_map_unmap :
∀ x, in_list_rectangle x (parameter_list_rect prog) ->
param_map (param_unmap x) = x.
(* param_unmap_map: for x in the target rectangle, constraining and then
   unconstraining is the identity. Unlike param_map_unmap, this direction
   additionally requires the source rectangle to be well-formed. *)
Variable param_unmap_map :
∀ x,
wf_rectangle_list (parameter_list_rect prog) ->
in_list_rectangle x (parameter_list_rect tprog) ->
param_unmap (param_map x) = x.
(* param_unmap maps the source rectangle into the target rectangle. *)
Variable param_unmap_in_dom :
∀ x, in_list_rectangle x (parameter_list_rect prog) ->
in_list_rectangle (param_unmap x) (parameter_list_rect tprog).
(* param_map maps the target rectangle into the source rectangle, again
   under well-formedness of the source rectangle. *)
Variable param_map_in_dom :
∀ x,
wf_rectangle_list (parameter_list_rect prog) ->
in_list_rectangle x (parameter_list_rect tprog) ->
in_list_rectangle (param_map x) (parameter_list_rect prog).
(* This is the key forward simulation assumption that the proof will use.
   Premises: the global environment of [prog] has the math library, the
   parameters [params] lie in the source parameter rectangle, and [prog]
   is safe on them.
   Conclusion: a forward simulation
   - from [prog] run on [params] with target [IRF t]
   - to [tprog] run on [param_unmap params] with target
     [IRF (target_map data (map R2val (param_unmap params)) t)].
   Note that the second argument of target_map is the unconstrained
   parameter list, matching the Input description given for target_map. *)
Variable transf_correct:
forall data params t,
genv_has_mathlib (globalenv prog) ->
in_list_rectangle params (parameter_list_rect prog) ->
is_safe prog data (map R2val params) ->
forward_simulation (Ssemantics.semantics prog data (map R2val params) (IRF t))
(Ssemantics.semantics tprog data (map R2val (param_unmap params))
(IRF (target_map data (map R2val (param_unmap params)) t))).
(* If the math library is available in the global environment of [prog],
   it is also available in that of [tprog]. *)
Variable genv_mathlib_pres :
genv_has_mathlib (globalenv prog) -> genv_has_mathlib (globalenv tprog).
Section has_math.
(* Throughout the has_math subsection the math library is assumed to be
   present in the global environment of [prog]. This is exactly the first
   premise of transf_correct, so lemmas in this subsection can invoke the
   simulation without threading that premise through their statements. *)
Variable genv_has_math :
genv_has_mathlib (globalenv prog).
(* Forward direction: if [prog], run on in-rectangle parameters on which it
   is safe, returns the target [IRF t], then [tprog] run on the
   unconstrained parameters returns the corrected target
   [IRF (target_map data (map R2val (param_unmap params)) t)]. *)
Lemma returns_target_value_fsim data params t:
in_list_rectangle params (parameter_list_rect prog) ->
is_safe prog data (map R2val params) ->
returns_target_value prog data (map R2val params) (IRF t) ->
returns_target_value tprog data
(map R2val (param_unmap params))
(IRF (target_map data (map R2val (param_unmap params)) t)).
Proof.
intros Hrect Hsafe.
(* Unpack the source run: an initial state s1, an execution from s1 to s2,
   and the fact that s2 is final. *)
intros (s1&s2&Hinit&Hstar&Hfinal).
(* Obtain the forward simulation and open its record; the premises of
   transf_correct are discharged by eauto from the context. *)
destruct (transf_correct data params t) as [index order match_states props]; eauto.
(* Match the initial state, then transport the whole execution, then the
   finality of the last state, across the simulation. *)
edestruct (fsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
edestruct (simulation_star) as (?&s2'&Hstar'&Hmatch2); eauto.
eapply (fsim_match_final_states) in Hmatch2; eauto.
exists s1', s2'; auto.
Qed.
(* Backward direction, the converse of returns_target_value_fsim: if
   [tprog] on the unconstrained parameters returns the corrected target,
   then [prog] on the original parameters returns [IRF t]. The hypotheses
   on [params] (in the source rectangle, [prog] safe) are the same. *)
Lemma returns_target_value_bsim data params t:
in_list_rectangle params (parameter_list_rect prog) ->
is_safe prog data (map R2val params) ->
returns_target_value tprog data (map R2val (param_unmap params))
(IRF (target_map data (map R2val (param_unmap params)) t)) ->
returns_target_value prog data (map R2val params) (IRF t).
Proof.
intros ? Hsafe (s1&s2&Hinit&Hstar&Hfinal).
(* Turn the assumed forward simulation into a backward one; the side
   conditions are solved using semantics_determinate and
   semantics_receptive. *)
specialize (transf_correct data params t) as Hfsim.
apply forward_to_backward_simulation in Hfsim as Hbsim;
auto using semantics_determinate, semantics_receptive.
destruct Hbsim as [index order match_states props].
(* bsim_match_initial_states is used below with some initial state of the
   source program in hand; inhabited_initial provides one. *)
assert (∃ s10, Smallstep.initial_state (semantics prog data (map R2val params) (IRF t)) s10) as (s10&?).
{ apply inhabited_initial; eauto. }
edestruct (bsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
(* Transport the target execution back to the source; the remaining
   premise of bsim_E0_star is discharged from Hsafe. *)
edestruct (bsim_E0_star) as (?&s2'&Hstar'&Hmatch2); eauto.
{ eapply Hsafe; eauto. }
(* Transport finality. [last first] brings the extra premise to the front;
   it is obtained from star_safe together with Hsafe. *)
eapply (bsim_match_final_states) in Hmatch2 as (s2''&?&?); eauto; last first.
{ eapply star_safe; last eapply Hsafe; eauto. }
(* The source run is the execution to s2' followed by the extra steps to
   the final state s2'', glued together with star_trans. *)
exists s1', s2''. intuition eauto.
{ eapply star_trans; eauto. }
Qed.
(* Applying target_map to the log density of [prog] at [params] gives the
   log density of [tprog] at the unconstrained parameters, provided
   [params] lies in the source rectangle and [prog] is safe there. *)
Lemma log_density_map data params :
in_list_rectangle params (parameter_list_rect prog) ->
is_safe prog data (map R2val params) ->
target_map data (map R2val (param_unmap params)) (log_density_of_program prog data (map R2val params)) =
log_density_of_program tprog data (map R2val (param_unmap params)).
Proof.
intros ? HP.
(* Unfold only the first occurrence of log_density_of_program, i.e. the one
   for [prog] on the left-hand side. *)
rewrite {1}/log_density_of_program.
rewrite /pred_to_default_fun.
(* Case split on whether the predicate being turned into a function has a
   witness. *)
destruct (ClassicalEpsilon.excluded_middle_informative) as [(v&Hreturns)|Hne].
{ (* A witness exists: name the value x chosen by the description operator
     and show that the [tprog] side evaluates to its image. *)
destruct (ClassicalEpsilon.constructive_indefinite_description) as [x Hx].
symmetry. erewrite log_density_of_program_trace; last first.
{ apply returns_target_value_fsim; auto.
(* Present x in the form IRF (IFR x) expected by the forward lemma. *)
assert (x = IRF (IFR x)) as <-.
{ rewrite IRF_IFR_inv //. }
eauto.
}
rewrite IFR_IRF_inv //.
}
(* No witness: impossible, because the safety hypothesis HP provides one. *)
exfalso. eapply Hne. eapply HP.
Qed.
(* If [data] is safe for [prog] and the source parameter rectangle is
   well-formed, then [data] is safe for [tprog]. *)
Lemma safe_data_preserved :
∀ data, wf_rectangle_list (parameter_list_rect prog) -> safe_data prog data -> safe_data tprog data.
Proof.
intros data Hwf Hsafe.
rewrite /safe_data. intros params Hin.
(* [params] lives in the target rectangle. Move it to the source rectangle
   with param_map so that the source safety hypothesis can be used there. *)
assert (Hin': in_list_rectangle (param_map params) (parameter_list_rect prog)).
{ apply param_map_in_dom; auto. }
specialize (Hsafe _ Hin').
(* is_safe is proved in three parts, handled by the three braces below. *)
rewrite /is_safe. split.
{ (* Part 1: an initial state exists for every target value t. Take an
     initial state of [prog] and push it through the forward simulation,
     instantiated at the un-corrected target. *)
intros t.
edestruct Hsafe as ((s&Hinit)&_).
specialize (transf_correct data (param_map params) (target_unmap data (map R2val params) (IFR t))
genv_has_math Hin' Hsafe)
as Hfsim.
destruct Hfsim. edestruct fsim_match_initial_states as (ind&s'&?); eauto.
(* NOTE(review): H is the auto-generated name of the hypothesis introduced
   by the [?] pattern on the previous line; this step breaks if that
   naming changes. The rewrite collapses param_unmap (param_map params)
   to params. *)
exists s'. rewrite param_unmap_map in H; intuition.
}
split.
{
(* Part 2: the property required of every initial state s of [tprog],
   obtained through the backward simulation and bsim_safe. *)
intros t s Hinit.
epose proof (transf_correct data (param_map params) ((target_unmap data (map R2val params)
(IFR t)))) as Hfsim.
apply forward_to_backward_simulation in Hfsim as Hbsim;
auto using semantics_determinate, semantics_receptive.
edestruct Hbsim as [index order match_states props].
(* The target value of the source semantics is left as a hole [_] to be
   filled in by unification. *)
eassert (∃ s10, Smallstep.initial_state (semantics prog data (map (λ r, Vfloat (IRF r)) (param_map params)) _) s10)
as (s10&?).
{ apply inhabited_initial; eauto. }
edestruct (bsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
{ rewrite param_unmap_map //. eauto. }
eapply bsim_safe; eauto.
(* Normalise the simulation properties so they mention [params] and [t]
   directly: cancel unmap/map on the parameters, then map/unmap on the
   target and the IRF/IFR round trip. *)
rewrite param_unmap_map in props; auto.
rewrite target_map_unmap IRF_IFR_inv /R2val in props; eauto.
apply Hsafe; eauto.
}
{
(* Part 3: [tprog] returns some target value. Take the value t returned by
   [prog] and use its corrected image as the witness. *)
edestruct Hsafe as (?&?&Hret). destruct Hret as (t&?).
exists ((IRF (target_map data (map R2val (param_unmap (param_map params))) (IFR t)))).
(* Rewrite only the first occurrence of [params] so that the goal has the
   shape expected by returns_target_value_fsim. *)
replace params with (param_unmap (param_map params)) at 1.
{ eapply returns_target_value_fsim; eauto. rewrite IRF_IFR_inv; eauto. }
{ rewrite param_unmap_map //. }
}
Qed.
End has_math.
(* The last lemma assumes that the transformation really is a change of
   variables whose Jacobian has been accounted for. The assumptions below
   describe that change of variables coordinate by coordinate. *)
(* gs: one real function per parameter. *)
Variable gs : list (R -> R).
(* log_dgs: one real function per parameter. It is used below only through
   [exp (f x)] (see gs_deriv) and through the sum of its values (see
   target_map_dgs). *)
Variable log_dgs : list (R -> R).
(* On the target rectangle, param_map is gs applied pointwise. *)
Variable param_map_gs :
∀ x, in_list_rectangle x (parameter_list_rect tprog) ->
list_apply gs x = param_map x.
(* On the target rectangle, target_map adds the sum of the log_dgs values
   at x to the source log density evaluated at the constrained point
   [param_map x]. *)
Variable target_map_dgs :
∀ data x, in_list_rectangle x (parameter_list_rect tprog) ->
target_map data (map R2val x) (log_density_of_program prog data (map R2val (param_map x))) =
list_plus (list_apply log_dgs x) + log_density_of_program prog data (map R2val (param_map x)).
(* Each g is strictly monotone on the corresponding interval of the target
   rectangle. This is the monotonicity requirement mentioned in the header
   comment of this file. *)
Variable gs_monotone :
wf_rectangle_list (parameter_list_rect prog) ->
Forall2 strict_monotone_on_interval (parameter_list_rect tprog) gs.
(* Relates each g, its interval in the target rectangle and the
   corresponding interval in the source rectangle via is_interval_image
   (presumably: g maps the former onto the latter; confirm against the
   definition of is_interval_image). *)
Variable gs_image :
wf_rectangle_list (parameter_list_rect prog) ->
Forall3 is_interval_image gs (parameter_list_rect tprog) (parameter_list_rect prog).
(* Relates each interval of the target rectangle, each g and the function
   [exp (log_dg x)] via continuous_derive_on_interval (presumably: the
   latter is the continuous derivative of g on that interval; confirm
   against the definition of continuous_derive_on_interval). *)
Variable gs_deriv :
wf_rectangle_list (parameter_list_rect prog) ->
Forall3 continuous_derive_on_interval (parameter_list_rect tprog) gs
(map (λ (f : R → R) (x : R), exp (f x)) log_dgs).
(* Evaluating the parameter map list of [tprog] at x agrees with evaluating
   that of [prog] at the constrained point [param_map x]. *)
Variable eval_param_map_list_preserved :
∀ x,
genv_has_mathlib (globalenv prog) ->
in_list_rectangle x (parameter_list_rect tprog) ->
eval_param_map_list tprog x = eval_param_map_list prog (param_map x).
(* NOTE(review): none of the proofs visible after this point states a lemma
   while another proof is open, so this flag may be a leftover; confirm
   before removing it. *)
Set Nested Proofs Allowed.
(* The exponential of the sum of a list is the product of the exponentials
   of its elements. *)
Lemma exp_list_plus l :
exp (list_plus l) = list_mult (map exp l).
Proof.
  induction l as [|r rs IHrs].
  - (* Empty list: after simplification the goal is closed using exp_0. *)
    by rewrite //= exp_0.
  - (* Cons: split the exponential of the sum with exp_plus, then use the
       induction hypothesis on the tail. *)
    by rewrite //= exp_plus IHrs.
Qed.
(* Mapping [g] over the result of list_apply is the same as post-composing
   every function in the list with [g] before applying. *)
Lemma map_list_apply {A B C} (g : B -> C) (fs : list (A -> B)) xs :
map g (list_apply fs xs) = list_apply (map (λ f, λ x, g (f x)) fs) xs.
Proof.
  (* Induct on the function list with the argument list kept general. *)
  revert xs.
  induction fs as [|f fs' IH]; intros xs.
  - by rewrite //=.
  - (* Case on the argument list; the subgoal not needing the induction
       hypothesis is closed by the trailing //=. *)
    destruct xs as [|x xs'] => //=.
    by rewrite IH /= /list_apply /= //.
Qed.
(* Main result of the section: [tprog] denotationally refines [prog]. The
   heart of the proof is a change of variables along [gs] in the two
   iterated integrals (normalizing constant and unnormalized distribution). *)
Lemma denotational_preserved :
denotational_refinement tprog prog.
Proof.
(* The witness is the equality of parameter dimensions; four obligations
   remain, one per bullet below. *)
exists (dimen_preserved).
split; [| split; [| split]].
- (* Safe data for [prog] is safe data for [tprog]. *)
intros data Hwf Hsafe. apply safe_data_preserved; auto.
- (* A program distribution of [prog] is one of [tprog], with the same
     numerator vnum and normalizing constant vnorm. *)
intros data rt vt Hsafe Hmath Hwf.
rewrite /is_program_distribution/is_program_normalizing_constant/is_unnormalized_program_distribution.
intros (vnum&vnorm&Hneq0&His_norm&His_num&Hdiv).
eexists vnum, vnorm; repeat split; auto.
{
(* Normalizing constant. First identify vnorm with the iterated integral
   over the source rectangle, by uniqueness of the integral. *)
assert (vnorm = IIRInt_list (program_normalizing_constant_integrand prog data) (parameter_list_rect prog))
as ->.
{ symmetry. apply is_IIRInt_list_unique. auto. }
(* Change variables along gs, with [exp (log_dg x)] as the derivative
   factors. is_IIRInt_list_ext then reduces the goal to a pointwise
   equality of integrands on the target rectangle. *)
eapply is_IIRInt_list_ext; last first.
{ eapply (is_IIRInt_list_comp_noncont _ gs (map (λ f, λ x, exp (f x)) log_dgs));
last (by (eexists; eauto)); eauto.
}
2: { eauto. }
(* Pointwise equality of the integrands at a point x of the rectangle. *)
intros x Hin. simpl.
rewrite /program_normalizing_constant_integrand/density_of_program.
symmetry.
(* Write the first x as param_unmap (param_map x) so that log_density_map
   can be used right-to-left to express the [tprog] density through
   target_map and the [prog] density. *)
replace x with (param_unmap (param_map x)) at 1; last first.
{ rewrite param_unmap_map //. }
rewrite -log_density_map; eauto.
rewrite param_unmap_map //.
(* Expose the Jacobian term as a sum of log_dgs values, then turn the
   exponential of that sum into a product of exponentials. *)
rewrite target_map_dgs; eauto.
rewrite exp_plus.
rewrite exp_list_plus.
rewrite map_list_apply -param_map_gs //.
}
{
(* Unnormalized distribution: the same argument as for the normalizing
   constant, with one extra factor handled at the end. *)
assert (vnum = IIRInt_list (unnormalized_program_distribution_integrand prog data rt)
(parameter_list_rect prog))
as ->.
{ symmetry. apply is_IIRInt_list_unique. auto. }
eapply is_IIRInt_list_ext; last first.
{ eapply (is_IIRInt_list_comp_noncont _ gs (map (λ f, λ x, exp (f x)) log_dgs));
last (by (eexists; eauto)); eauto.
}
2: { eauto. }
intros x Hin. simpl.
rewrite /unnormalized_program_distribution_integrand/density_of_program.
symmetry.
replace x with (param_unmap (param_map x)) at 1; last first.
{ rewrite param_unmap_map //. }
rewrite -log_density_map; eauto.
rewrite param_unmap_map //.
rewrite target_map_dgs; eauto.
rewrite exp_plus.
rewrite exp_list_plus.
rewrite map_list_apply -param_map_gs //.
(* Re-associate the product so that f_equal can split off the last factor;
   the remaining goal is closed by rewriting with
   eval_param_map_list_preserved and param_map_gs. *)
rewrite -?Rmult_assoc.
f_equal.
rewrite eval_param_map_list_preserved //.
rewrite param_map_gs //.
}
- (* The last two obligations follow from hypotheses in the context. *)
eauto.
- eauto.
Qed.
End DENOTATIONAL_SIMULATION.