-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnnet.v
318 lines (278 loc) · 8.99 KB
/
nnet.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
Require Import Bool.
Require Import List.
Require Import ZArith.
Require Import Psatz.
Require Import RelationClasses.
Open Scope Z.
(** Network definition *)
(* Corresponds to
z0 = x
ẑi = zi-1 Wi + bi
zi = max(ẑi, 0)
Really x, Wi should be matrices/vectors, but currently are integers.
Assumes n = length W = length B = number of layers - 1
Layer order: the head of W (resp. B) is the weight (resp. offset) of the
LAST layer; the tails are evaluated first by the recursive call. For
example, with W = 3 :: 2 :: 1 :: nil and B = -3 :: -2 :: 3 :: nil the
input is first multiplied by 1 and offset by 3, and last multiplied by 3
and offset by -3.
If n, W and B stop matching in shape, the catch-all branch substitutes 0
for the output of the remaining inner layers instead of signalling an
error; the outer layers are still applied to that 0.
*)
Fixpoint network_loop x n W B :=
match n, W, B with
(* No layers left: the output is the input itself. *)
| O, nil, nil => x
| S n', w :: W', b :: B' =>
(* Output of the earlier layers, described by the tails W' and B'. *)
let z := network_loop x n' W' B' in
let zhat := z * w + b in
(* Activation: negative values are clamped to 0. *)
Z.max zhat 0
(* Shape mismatch between n, W and B. *)
| _, _, _ => 0
end.
(* Run the network with one layer per element of W. The depth is taken
from W, so B is expected to have the same length; if it does not, the
catch-all branch of network_loop yields 0 where the lists stop matching. *)
Definition network x W B :=
network_loop x (length W) W B.
(* Parameters of a classification problem built on top of the network:
- label: the type of output labels,
- label_eq: a boolean equality test on labels,
- classify: maps a network output to a label,
- label_eq_equiv: label_eq returns true exactly when its arguments are
equal. This lets the checkers below compute with label_eq while the
proofs reason about = by rewriting with label_eq_equiv. *)
Class NetworkParams := {
label : Set;
label_eq : label -> label -> bool;
classify : Z -> label;
label_eq_equiv l1 l2 : label_eq l1 l2 = true <-> l1 = l2
}.
(* Everything in this section is parameterized by an arbitrary
NetworkParams instance. *)
Section Robustness.
Context `{NetworkParams}.
(** Robustness definition *)
(* The network is locally robust at x0 with radius delta if every integer
input x with |x - x0| <= delta gets the same label as x0. *)
Definition local_robust x0 W B delta :=
forall x,
Z.abs (x - x0) <= Z.of_nat delta ->
classify (network x W B) = classify (network x0 W B).
(* x witnesses that the network is not robust at x0: x lies within delta
of x0 but gets a different label. *)
Definition not_robust_at x x0 W B delta :=
Z.abs (x - x0) <= Z.of_nat delta /\
classify (network x W B) <> classify (network x0 W B).
(* A single witness in the sense of not_robust_at refutes local_robust. *)
Lemma not_robust_at_not_robust : forall x0 W B delta,
(exists x, not_robust_at x x0 W B delta) ->
~local_robust x0 W B delta.
Proof.
(* Unpack the witness, then feed its distance bound to the robustness
hypothesis to contradict the label mismatch. *)
intros x0 W B delta [x [Hclose Hdiff]] Hrobust.
apply Hdiff.
apply Hrobust.
exact Hclose.
Qed.
(** Robustness checker *)
(* Collect all points within delta of x0 in a list *)
(* The list is x0 + delta, x0 - delta, x0 + (delta - 1), x0 - (delta - 1),
..., x0, i.e. the 2 * delta + 1 integers within delta of x0. The proofs
below destructure it in exactly this order: plus point, then minus
point, then the points for the smaller radius. *)
Fixpoint nearby_points x0 delta :=
match delta with
| O => x0 :: nil
| S delta' =>
x0 + Z.of_nat delta :: x0 - Z.of_nat delta :: nearby_points x0 delta'
end.
(* Compare the network at x0 against all points in X *)
(* Returns true iff every point of X gets the same label as x0 according
to label_eq; returns true on the empty list. *)
(* NOTE(review): classify (network x0 W B) is recomputed for every element
of X. Hoisting it would likely require adjusting the proofs below, which
rewrite this exact term. *)
Fixpoint is_robust_aux x0 X W B :=
match X with
| nil => true
| x :: X' =>
label_eq (classify (network x W B)) (classify (network x0 W B)) &&
is_robust_aux x0 X' W B
end.
(* Compare the network at x0 against all points within delta *)
(* The Arguments directive below lets cbn/simpl unfold is_robust once it is
applied to all four arguments; the proofs below start with cbn. *)
Definition is_robust x0 W B delta :=
is_robust_aux x0 (nearby_points x0 delta) W B.
Arguments is_robust _ _ _ _ /.
(** Soundness proof *)
(* If the checker returns true then the network is actually robust *)
Lemma is_robust_sound : forall x0 W B delta,
is_robust x0 W B delta = true ->
local_robust x0 W B delta.
Proof.
(* Induction on the radius; cbn unfolds is_robust and evaluates the checker
as far as the symbolic delta allows. *)
induction delta; cbn; intros * Hrobust.
- (* delta = 0, trivial *)
(* A point within distance 0 of x0 is x0 itself. *)
hnf; cbn; intros.
assert (x = x0) by lia; subst; auto.
- (* induction case *)
hnf; cbn; intros * Hdelta.
(* Restate the positive that cbn produced for Z.of_nat (S delta) as
Z.succ (Z.of_nat delta). *)
rewrite Zpos_P_of_succ_nat in *.
(* Split the checker result into: label at x0 + (delta + 1), label at
x0 - (delta + 1), and the checks for radius delta. *)
rewrite !andb_true_iff, !label_eq_equiv in Hrobust.
destruct Hrobust as (Heq_plus & Heq_minus & Hrobust).
(* Either x is within the smaller radius or exactly on the new boundary. *)
assert (Hcase: Z.abs (x - x0) <= Z.of_nat delta \/
Z.abs (x - x0) = Z.succ (Z.of_nat delta)) by lia.
destruct Hcase as [Hdelta' | Hdelta'].
+ (* By induction *)
apply IHdelta; auto.
+ (* By Heq_plus/minus *)
(* On the boundary, x is one of the two points checked explicitly. *)
assert (Hcase: x = x0 + Z.succ (Z.of_nat delta) \/
x = x0 - Z.succ (Z.of_nat delta)) by lia.
destruct Hcase; subst; auto.
Qed.
(** Completeness proof *)
(* If the network is robust, then the checker will return true *)
Lemma is_robust_complete : forall x0 W B delta,
local_robust x0 W B delta ->
is_robust x0 W B delta = true.
Proof.
(* Induction on the radius, mirroring is_robust_sound. *)
induction delta; cbn; intros * Hrobust.
- (* delta = 0, trivial *)
(* The only check compares the label of x0 with itself. *)
rewrite andb_true_iff, label_eq_equiv; auto.
- (* induction case *)
rewrite Zpos_P_of_succ_nat, !andb_true_iff, !label_eq_equiv.
(* Three obligations: the boundary points x0 + (delta + 1) and
x0 - (delta + 1), and the checker for radius delta. *)
repeat split.
+ eapply Hrobust; lia.
+ eapply Hrobust; lia.
(* Robustness for radius S delta implies robustness for radius delta. *)
+ eapply IHdelta.
hnf; intros.
eapply Hrobust; lia.
Qed.
(** Faster/Weaker Robustness checker *)
(* Check that a list is only 0s *)
(* The empty list is rejected (all_zeros nil = false). This matters: with
W = nil the network is the identity (network x nil nil = x) rather than
the constant 0, and is_robust_fast_sound uses all_zeros_not_nil to get
the W <> nil premise of network_zero. *)
Fixpoint all_zeros xs :=
match xs with
| 0 :: nil => true
| 0 :: xs' => all_zeros xs'
| _ => false
end.
(* Every element of a list accepted by all_zeros is 0. *)
Lemma all_zeros_correct : forall xs,
all_zeros xs = true ->
forall x, In x xs -> x = 0.
Proof.
induction xs as [| x' xs]; cbn; intuition; subst.
(* x is the head of the list. *)
- destruct x; auto; easy.
(* x is in the tail. *)
- destruct x'; try easy.
destruct xs; auto.
Qed.
(* all_zeros never accepts the empty list. *)
Lemma all_zeros_not_nil : forall xs,
all_zeros xs = true ->
xs <> nil.
Proof.
intros [| z zs] Hzeros.
(* all_zeros nil computes to false, so the hypothesis is absurd. *)
- discriminate Hzeros.
(* A cons cell is never nil. *)
- discriminate.
Qed.
(* With at least one layer and all weights and offsets equal to 0, the
network returns 0 for every input x. *)
Lemma network_zero : forall x W B,
(forall w, In w W -> w = 0) ->
(forall b, In b B -> b = 0) ->
W <> nil ->
network x W B = 0.
Proof.
(* The W = nil case contradicts the W <> nil premise. *)
induction W as [| w W]; cbn; intros * Hw Hb; intuition.
(* If B is nil the shapes mismatch and network_loop already yields 0. *)
destruct B as [| b B]; auto.
cbn in Hb.
(* Recognize the recursive call as the smaller network. *)
fold (network x W B).
assert (w = 0) by (apply Hw; auto).
assert (b = 0) by (apply Hb; auto).
(* The outermost layer computes the max of z * 0 + 0 and 0, i.e. 0. *)
subst; lia.
Qed.
(* Take advantage of the fact that if all weights and offsets are 0 then
the network always returns 0 (network_zero), so every input gets the same
label. Only W and B are inspected: x0 and delta are ignored and are kept
so that is_robust_fast has the same signature as is_robust. Because
all_zeros rejects the empty list, W = nil (where the network is the
identity, network x nil nil = x) is never accepted. *)
Definition is_robust_fast (x0: Z) W B (delta: nat) :=
all_zeros W && all_zeros B.
Arguments is_robust_fast _ _ _ _ /.
(** Soundness proof *)
(* Still sound, but definitely not complete *)
(* For example, the network in the Example section below is robust at
x = 9, yet is_robust_fast returns false for it. *)
Lemma is_robust_fast_sound : forall x0 W B delta,
is_robust_fast x0 W B delta = true ->
local_robust x0 W B delta.
Proof.
cbn; intros * Hrobust.
hnf; cbn; intros * Hdelta.
(* Both all_zeros checks must have succeeded; in the other cases Hrobust
is contradictory. *)
destruct (all_zeros W) eqn:Hw; try easy.
destruct (all_zeros B) eqn:Hb; try easy.
(* Both network outputs are 0 by network_zero. *)
rewrite !network_zero; eauto using all_zeros_correct, all_zeros_not_nil.
Qed.
(** Counterexample finder *)
(* Compare the network at x0 against all points in X to find where it is
not robust *)
(* Returns Some x for the first x in X (in list order) whose label differs
from the label of x0 according to label_eq, or None if there is none. *)
Fixpoint is_not_robust_at_aux x0 X W B :=
match X with
| nil => None
| x :: X' =>
if negb (label_eq (classify (network x W B)) (classify (network x0 W B)))
then Some x
else is_not_robust_at_aux x0 X' W B
end.
(* Compare the network at x0 against all points within delta *)
(* Points are tried in the order produced by nearby_points, so the reported
counterexample is the first mismatch in that order. The Arguments
directive lets cbn/simpl unfold the definition once fully applied. *)
Definition is_not_robust_at x0 W B delta :=
is_not_robust_at_aux x0 (nearby_points x0 delta) W B.
Arguments is_not_robust_at _ _ _ _ /.
(** Soundness proof *)
(* If the counterexample finder returns a point then the network is
actually not robust at that point *)
Lemma is_not_robust_at_sound : forall x0 W B delta x,
is_not_robust_at x0 W B delta = Some x ->
not_robust_at x x0 W B delta.
Proof.
induction delta; cbn; intros * Hrobust; hnf.
- (* delta = 0, trivial *)
(* The only candidate is x0 itself; case on the label_eq test. *)
match type of Hrobust with
| context[if (negb ?x) then _ else _] => destruct x eqn:Heq
end; inversion Hrobust; subst.
rewrite <- label_eq_equiv, Heq.
split; auto; lia.
- (* induction case *)
rewrite Zpos_P_of_succ_nat in *.
(* Case on the label_eq tests for the two boundary points. *)
repeat match type of Hrobust with
| context[if (negb ?x) then _ else _] =>
destruct x eqn:?Heq; cbn in Hrobust
end.
(* Both boundary points agree with x0: the witness comes from the
recursive search over the smaller radius. *)
+ rewrite label_eq_equiv in Heq, Heq0.
apply IHdelta in Hrobust.
destruct Hrobust.
split; auto; lia.
(* In the remaining two cases a boundary point disagrees with x0 and is
itself the witness. *)
+ inversion Hrobust; subst.
rewrite <- label_eq_equiv.
split; (congruence || lia).
+ inversion Hrobust; subst.
rewrite <- label_eq_equiv.
split; (congruence || lia).
Qed.
End Robustness.
Section Example.
(* A and B labels. Values less than 10 are in A, all else in B *)
Inductive ex_label := A | B.
Definition ex_classify x :=
if x <? 10 then A else B.
(* Boolean equality on ex_label: true exactly on A, A and on B, B. *)
Definition ex_label_eq l1 l2 :=
match l1, l2 with
| A, A | B, B => true
| _, _ => false
end.
(* The remaining field, label_eq_equiv, is proved below by case analysis
on both labels. *)
Instance : NetworkParams := {
label := ex_label;
classify := ex_classify;
label_eq := ex_label_eq;
}.
Proof. now destruct l1, l2. Defined.
(* Example network with three layers. From here on, B names this list of
offsets and shadows the constructor B of ex_label. *)
Let B := (-3 :: -2 :: 3 :: nil).
Let W := (3 :: 2 :: 1 :: nil).
Let delta := 3%nat.
Let x := 9.
(* The layers are applied from the last list element to the first, so the
output is 3 * (2 * (x + 3) - 2) - 3 = 6 * x + 9 whenever no
intermediate value is clamped by the max with 0. *)
Compute (network x W B). (* 63 *)
Compute (network (x - Z.of_nat delta) W B). (* 45 *)
(* Every point from 6 to 12 maps to at least 45, hence to label B. *)
Compute (is_robust x W B delta). (* true *)
Compute (is_not_robust_at x W B delta). (* None *)
Goal local_robust x W B delta.
Proof.
apply is_robust_sound.
auto.
Qed.
(* Can't prove with is_robust_fast *)
(* The weights are nonzero, so the fast check computes to false. *)
Compute (is_robust_fast x W B delta). (* false *)
Goal local_robust x W B delta.
Proof.
apply is_robust_fast_sound.
Fail reflexivity.
Abort.
(* Around y = 3 the network is not robust: 0 is within delta of 3, and
network 0 = 9 is below 10 (label A) while network 3 = 27 has label B. *)
Let y := 3.
Compute (network y W B). (* 27 *)
Compute (network (y - Z.of_nat delta) W B). (* 9 *)
Compute (is_robust y W B delta). (* false *)
Goal ~local_robust y W B delta.
Proof.
(* Can't show the opposite direction with soundness. *)
Fail apply is_robust_sound.
(* Completeness would force is_robust to return true, but it computes to
false. *)
intros Hcontra.
apply is_robust_complete in Hcontra.
inversion Hcontra.
Qed.
(* Counterexample finder tells us that the network is not robust and also
at what point it is not robust *)
Compute (is_not_robust_at y W B delta). (* Some 0 *)
Goal ~local_robust y W B delta /\ not_robust_at 0 y W B delta.
Proof.
split.
- apply not_robust_at_not_robust.
exists 0.
apply is_not_robust_at_sound; auto.
- apply is_not_robust_at_sound; auto.
Qed.
(* All weights and offsets are 0, so the network is constantly 0 and the
fast checker succeeds. *)
Let B' := (0 :: 0 :: nil).
Let W' := (0 :: 0 :: nil).
(* Can only prove if weights and offsets are 0 *)
Compute (is_robust_fast x W' B' delta). (* true *)
Goal local_robust x W' B' delta.
Proof.
apply is_robust_fast_sound.
auto.
Qed.
End Example.