Skip to content

Commit

Permalink
Merge pull request #425 from tckmn/compilermetrics
Browse files Browse the repository at this point in the history
compiler metrics
  • Loading branch information
samuelgruetter authored Aug 16, 2024
2 parents 9d13941 + 070a202 commit 7b611b6
Show file tree
Hide file tree
Showing 26 changed files with 3,288 additions and 958 deletions.
153 changes: 153 additions & 0 deletions bedrock2/src/bedrock2/MetricCosts.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
Require Import BinIntDef.
Require Import Coq.Strings.String.
Require Import bedrock2.MetricLogging.
From coqutil.Tactics Require Import destr.

Local Open Scope MetricH_scope.

Inductive compphase: Type :=
| PreSpill
| PostSpill.

Section FlatImpExec.

Context {varname: Type}.
Variable (phase: compphase).
Variable (isReg: varname -> bool).

Definition cost_interact mc :=
match phase with
| PreSpill => mkMetricLog 100 100 100 100
| PostSpill => mkMetricLog 50 50 50 50
end + mc.

Definition cost_call mc :=
match phase with
| PreSpill => mkMetricLog 200 200 200 200
| PostSpill => mkMetricLog 100 100 100 100
end + mc.

(* TODO think about a non-fixed bound on the cost of function preamble and postamble *)

Definition cost_load x a mc :=
match (isReg x, isReg a) with
| (false, false) => mkMetricLog 3 1 5 0
| (false, true) => mkMetricLog 2 1 3 0
| ( true, false) => mkMetricLog 2 0 4 0
| ( true, true) => mkMetricLog 1 0 2 0
end + mc.

Definition cost_store a v mc :=
match (isReg a, isReg v) with
| (false, false) => mkMetricLog 3 1 5 0
| (false, true) => mkMetricLog 2 1 3 0
| ( true, false) => mkMetricLog 2 1 3 0
| ( true, true) => mkMetricLog 1 1 1 0
end + mc.

Definition cost_inlinetable x i mc :=
match (isReg x, isReg i) with
| (false, false) => mkMetricLog 5 1 7 1
| (false, true) => mkMetricLog 4 1 5 1
| ( true, false) => mkMetricLog 4 0 6 1
| ( true, true) => mkMetricLog 3 0 4 1
end + mc.

Definition cost_stackalloc x mc :=
match isReg x with
| false => mkMetricLog 2 1 2 0
| true => mkMetricLog 1 0 1 0
end + mc.

Definition cost_lit x mc :=
match isReg x with
| false => mkMetricLog 9 1 9 0
| true => mkMetricLog 8 0 8 0
end + mc.

Definition cost_op x y z mc :=
match (isReg x, isReg y, isReg z) with
| (false, false, false) => mkMetricLog 5 1 7 0
| (false, false, true) | (false, true, false) => mkMetricLog 4 1 5 0
| (false, true, true) => mkMetricLog 3 1 3 0
| ( true, false, false) => mkMetricLog 4 0 6 0
| ( true, false, true) | ( true, true, false) => mkMetricLog 3 0 4 0
| ( true, true, true) => mkMetricLog 2 0 2 0
end + mc.

Definition cost_set x y mc :=
match (isReg x, isReg y) with
| (false, false) => mkMetricLog 3 1 4 0
| (false, true) => mkMetricLog 2 1 2 0
| ( true, false) => mkMetricLog 2 0 3 0
| ( true, true) => mkMetricLog 1 0 1 0
end + mc.

Definition cost_if x y mc :=
match (isReg x, match y with | Some y' => isReg y' | None => true end) with
| (false, false) => mkMetricLog 4 0 6 1
| (false, true) | ( true, false) => mkMetricLog 3 0 4 1
| ( true, true) => mkMetricLog 2 0 2 1
end + mc.

Definition cost_loop_true x y mc :=
match (isReg x, match y with | Some y' => isReg y' | None => true end) with
| (false, false) => mkMetricLog 4 0 6 1
| (false, true) | ( true, false) => mkMetricLog 3 0 4 1
| ( true, true) => mkMetricLog 2 0 2 1
end + mc.

Definition cost_loop_false x y mc :=
match (isReg x, match y with | Some y' => isReg y' | None => true end) with
| (false, false) => mkMetricLog 3 0 5 1
| (false, true) | ( true, false) => mkMetricLog 2 0 3 1
| ( true, true) => mkMetricLog 1 0 1 1
end + mc.

End FlatImpExec.

Definition isRegZ (var : Z) : bool :=
Z.leb var 31.

Definition isRegStr (var : String.string) : bool :=
String.prefix "reg_" var.

(* awkward tactic use to avoid Qed slowness *)
(* this is slow with (eq_refl t) and fast with (eq_refl t') due to black box heuristics *)
Ltac cost_unfold :=
repeat (
let H := match goal with
| H : context[cost_interact] |- _ => H
| H : context[cost_call] |- _ => H
| H : context[cost_load] |- _ => H
| H : context[cost_store] |- _ => H
| H : context[cost_inlinetable] |- _ => H
| H : context[cost_stackalloc] |- _ => H
| H : context[cost_lit] |- _ => H
| H : context[cost_op] |- _ => H
| H : context[cost_set] |- _ => H
| H : context[cost_if] |- _ => H
| H : context[cost_loop_true] |- _ => H
| H : context[cost_loop_false] |- _ => H
end in
let t := type of H in
let t' := eval cbv [cost_interact cost_call cost_load cost_store
cost_inlinetable cost_stackalloc cost_lit cost_op cost_set
cost_if cost_loop_true cost_loop_false] in t in
replace t with t' in H by (exact (eq_refl t'))
);
cbv [cost_interact cost_call cost_load cost_store cost_inlinetable
cost_stackalloc cost_lit cost_op cost_set cost_if cost_loop_true
cost_loop_false];
unfold EmptyMetricLog in *.

Ltac cost_destr :=
repeat match goal with
| x : compphase |- _ => destr x
| _ : context[if ?x then _ else _] |- _ => destr x; try discriminate
| |- context[if ?x then _ else _] => destr x; try discriminate
end.

Ltac cost_solve := cost_unfold; cost_destr; try solve_MetricLog.
Ltac cost_solve_piecewise := cost_unfold; cost_destr; try solve_MetricLog_piecewise.
Ltac cost_hammer := try solve [eauto 3 with metric_arith | cost_solve].
93 changes: 81 additions & 12 deletions bedrock2/src/bedrock2/MetricLogging.v
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,29 @@ Section Riscv.
Definition subMetricLoads n log := withLoads (loads log - n) log.
Definition subMetricJumps n log := withJumps (jumps log - n) log.

Definition metricSub(metric: MetricLog -> Z) finalM initialM : Z :=
Z.sub (metric finalM) (metric initialM).
Definition metricAdd(metric: MetricLog -> Z) m1 m2 : Z :=
Z.add (metric m1) (metric m2).
Definition metricSub(metric: MetricLog -> Z) m1 m2 : Z :=
Z.sub (metric m1) (metric m2).

Definition metricsOp op : MetricLog -> MetricLog -> MetricLog :=
fun initialM finalM =>
fun m1 m2 =>
mkMetricLog
(op instructions initialM finalM)
(op stores initialM finalM)
(op loads initialM finalM)
(op jumps initialM finalM).
(op instructions m1 m2)
(op stores m1 m2)
(op loads m1 m2)
(op jumps m1 m2).

Definition metricsAdd := metricsOp metricAdd.
Definition metricsSub := metricsOp metricSub.

Definition metricsMul (n : Z) (m : MetricLog) :=
mkMetricLog
(n * instructions m)
(n * stores m)
(n * loads m)
(n * jumps m).

Definition metricLeq(metric: MetricLog -> Z) m1 m2: Prop :=
(metric m1) <= (metric m2).

Expand All @@ -51,18 +61,15 @@ Section Riscv.
metricLeq loads m1 m2 /\
metricLeq jumps m1 m2.

Definition metricMax(metric: MetricLog -> Z) m1 m2: Z :=
Z.max (metric m1) (metric m2).

Definition metricsMax := metricsOp metricMax.
End Riscv.

Declare Scope MetricH_scope.
Bind Scope MetricH_scope with MetricLog.
Delimit Scope MetricH_scope with metricsH.

Infix "<=" := metricsLeq : MetricH_scope.
Infix "+" := metricsAdd : MetricH_scope.
Infix "-" := metricsSub : MetricH_scope.
Infix "*" := metricsMul : MetricH_scope.

#[export] Hint Unfold
withInstructions
Expand All @@ -78,8 +85,11 @@ Infix "-" := metricsSub : MetricH_scope.
subMetricStores
subMetricJumps
metricsOp
metricAdd
metricsAdd
metricSub
metricsSub
metricsMul
metricLeq
metricsLeq
: unf_metric_log.
Expand All @@ -103,7 +113,66 @@ Ltac fold_MetricLog :=
Ltac simpl_MetricLog :=
cbn [instructions loads stores jumps] in *.

(* need this to define solve_MetricLog, but need solve_MetricLog inside of MetricArith, oops *)
Lemma add_assoc' : forall n m p, (n + (m + p) = n + m + p)%metricsH.
Proof. intros. unfold_MetricLog. f_equal; apply Z.add_assoc. Qed.

Lemma metriclit : forall a b c d a' b' c' d' mc,
metricsAdd (mkMetricLog a b c d) (metricsAdd (mkMetricLog a' b' c' d') mc) =
metricsAdd (mkMetricLog (a+a') (b+b') (c+c') (d+d')) mc.
Proof. intros. rewrite add_assoc'. reflexivity. Qed.

Ltac flatten_MetricLog := repeat rewrite metriclit in *.

Ltac solve_MetricLog :=
flatten_MetricLog;
repeat unfold_MetricLog;
repeat simpl_MetricLog;
blia.

Ltac solve_MetricLog_piecewise :=
flatten_MetricLog;
repeat unfold_MetricLog;
repeat simpl_MetricLog;
f_equal; blia.

Module MetricArith.

Open Scope MetricH_scope.

Lemma mul_sub_distr_r : forall n m p, (n - m) * p = n * p - m * p.
Proof. intros. unfold_MetricLog. f_equal; apply Z.mul_sub_distr_r. Qed.

Lemma add_sub_swap : forall n m p, n + m - p = n - p + m.
Proof. intros. unfold_MetricLog. f_equal; apply Z.add_sub_swap. Qed.

Lemma le_add_le_sub_r : forall n m p, n + p <= m <-> n <= m - p.
Proof. solve_MetricLog. Qed.

Lemma le_trans : forall n m p, n <= m -> m <= p -> n <= p.
Proof. solve_MetricLog. Qed.

Lemma le_refl : forall m, m <= m.
Proof. solve_MetricLog. Qed.

Lemma le_sub_mono : forall n m p, n - p <= m - p <-> n <= m.
Proof. solve_MetricLog. Qed.

Lemma add_0_r : forall mc, (mc + EmptyMetricLog)%metricsH = mc.
Proof. destruct mc. unfold EmptyMetricLog. solve_MetricLog_piecewise. Qed.

Lemma sub_0_r : forall mc, (mc - EmptyMetricLog)%metricsH = mc.
Proof. destruct mc. unfold EmptyMetricLog. solve_MetricLog_piecewise. Qed.

Lemma add_comm : forall n m, (n + m = m + n)%metricsH.
Proof. intros. unfold_MetricLog. f_equal; apply Z.add_comm. Qed.

Lemma add_assoc : forall n m p, (n + (m + p) = n + m + p)%metricsH.
Proof. intros. unfold_MetricLog. f_equal; apply Z.add_assoc. Qed.

End MetricArith.

Create HintDb metric_arith.
#[export] Hint Resolve MetricArith.le_trans MetricArith.le_refl MetricArith.add_0_r MetricArith.sub_0_r MetricArith.add_comm MetricArith.add_assoc : metric_arith.
#[export] Hint Resolve <- MetricArith.le_sub_mono : metric_arith.
#[export] Hint Resolve -> MetricArith.le_sub_mono : metric_arith.
Loading

0 comments on commit 7b611b6

Please sign in to comment.