From b8bd03d3880d72735e9ff4c178f6aa97f70ae480 Mon Sep 17 00:00:00 2001 From: Alban Gruin <alban.gruin@irit.fr> Date: Thu, 29 Dec 2022 21:56:07 +0100 Subject: [PATCH] fivestage: cleanup Signed-off-by: Alban Gruin <alban.gruin@irit.fr> --- src/fivestage.v | 298 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 199 insertions(+), 99 deletions(-) diff --git a/src/fivestage.v b/src/fivestage.v index 5926e9f..e8eb6ad 100644 --- a/src/fivestage.v +++ b/src/fivestage.v @@ -3,11 +3,6 @@ From StoreBuffer Require Import utils modular_definitions. Import ListNotations. Module FiveStageSig <: Sig. - Variant opcode_fs := - | Load (dlat : nat) (id : nat) - | Store (id : nat) - | Other (id : nat). - Variant stage_fs := | Pre | Ex @@ -17,19 +12,27 @@ Module FiveStageSig <: Sig. Definition first_stage := Pre. - Scheme Equality for opcode_fs. - Definition opcode := opcode_fs. - Definition opcode_beq := opcode_fs_beq. - Definition opcode_eq_dec := opcode_fs_eq_dec. - Scheme Equality for stage_fs. Definition stage := stage_fs. Definition stage_beq := stage_fs_beq. Definition stage_eq_dec := stage_fs_eq_dec. - Lemma stage_beq_eq_true : forall s, stage_beq s s = true. + Lemma stage_beq_iff_eq : forall s s', stage_beq s s' = true <-> s = s'. + Proof. + now destruct s, s'. + Qed. + + Corollary stage_beq_eq_true : forall s, stage_beq s s = true. + Proof. + intro s. + now rewrite stage_beq_iff_eq. + Qed. + + Corollary stage_nbeq_iff_neq : forall s s', stage_beq s s' = false <-> s <> s'. Proof. - now destruct s. + intros s s'. + rewrite <- not_true_iff_false. + apply not_iff_compat, stage_beq_iff_eq. Qed. (* Boolean stage comparison. Pre < Ex < Mem < Wb < Post. *) @@ -47,6 +50,37 @@ Module FiveStageSig <: Sig. | Post, Post => true | Post, _ => false end. + + Definition can_have_lat st := + match st with + | Pre | Post => false + | _ => true + end. + + Lemma stage_beq_implies_leb : forall st st', + stage_beq st st' = true -> leb st st' = true. + Proof. + now destruct st, st'. + Qed. + + Definition lats := (stage_fs * nat)%type. + + Variant opcode_fs := + | Load (lat : list lats) (id : nat) + | Store (lat : list lats) (id : nat) + | Other (lat : list lats) (id : nat). + + Definition opcode := opcode_fs. + + Definition lat opc := + match opc with + | Load lat _ | Store lat _ | Other lat _ => lat + end. + + Definition idx opc := + match opc with + | Load _ idx | Store _ idx | Other _ idx => idx + end. End FiveStageSig. Module FiveStage := Pipeline FiveStageSig. @@ -66,7 +100,7 @@ Definition busFree := sbStateT_beq Empty. Definition ready next_id sbState (elt : instr_kind) := match elt with | (_, (_, S _)) => false - | (Store _, (Mem, 0)) => negb (sbStateT_beq Full sbState) + | (Store _ _, (Mem, 0)) => negb (sbStateT_beq Full sbState) | (Load _ _, (Ex, 0)) => busFree sbState | (opc, (Pre, 0)) => idx opc =? next_id | _ => true @@ -96,10 +130,17 @@ Definition free next_id sbState trace (elt : instr_kind) := end. Definition get_latency opc st := - match opc, st with - | Load n _, Mem => n - | _, _ => 0 - end. + if can_have_lat st then + let lat := List.find (fun stl => + match stl with + | (stl, _) => stage_beq st stl + end) (lat opc) in + match lat with + | None => 0 + | Some (_, n) => n + end + else + 0. Definition cycle_elt next_id sbState trace elt := match elt with @@ -122,17 +163,6 @@ Definition legal_sb_transitions st sbState sbState' := | _, _ => sbState' = Full \/ sbState' = NotEmpty \/ sbState' = Empty end. -Lemma next_stage_is_higher : forall opc st st', - st' = nstg st -> - compare_two (Some (opc, (st, 0))) - (Some (opc, (st', get_latency opc st'))) = true. -Proof. - intros opc st st' Hst'. - rewrite Hst'. - now destruct st; - [..| destruct opc]. -Qed. - Lemma opc_match_case_identical : forall (A : Type) (a : A) opc, match opc with | Load _ _ | _ => a end = a. Proof. @@ -140,6 +170,87 @@ Proof. now destruct opc. Qed. +Section instr_progress_generalization. + Variable opc : opcode. + Variable st st' : stage. + Variable n : nat. + + Lemma next_stage_is_higher : + st' = nstg st -> + compare_two (Some (opc, (st, 0))) + (Some (opc, (st', get_latency opc st'))) = true. + Proof. + intros Hst'. + rewrite Hst'. + now destruct st; + [..| destruct opc]. + Qed. + + Lemma state_leb_n0 : state_leb (st, 0) (st', S n) = true -> + stage_beq st st' = false. + Proof. + simpl. + now destruct stage_beq. + Qed. + + Lemma state_leb_diff : (opc, (st, 0)) <> (opc, (st', 0)) -> + stage_beq st st' = false. + Proof. + intro Hed. + rewrite stage_nbeq_iff_neq. + intro Hst. + destruct Hed. + now rewrite Hst. + Qed. + + Lemma leb_remains : stage_beq st st' = false -> leb st st' = true -> + leb (nstg st) st' = true. + Proof. + now destruct st, st'. + Qed. + + Lemma leb_far_remains : stage_beq st st' = false -> leb st st' = true -> + leb (nstg st) (nstg st') = true. + Proof. + now destruct st, st'. + Qed. + + Lemma stage_beq_nstg_nstg_eq : stage_beq st st' = false -> leb st st' = true -> + stage_beq (nstg st) (nstg st') = true -> st' = Post. + Proof. + now destruct st, st'. + Qed. + + Lemma stage_beq_nstg_post : leb st st' = true -> + stage_beq st (nstg st') = true -> st' = Post. + Proof. + now destruct st, st'. + Qed. + + Lemma leb_nstg_right : leb st st' = true -> leb st (nstg st') = true. + Proof. + now destruct st, st'. + Qed. + + Lemma right_progresses_state_leb : + (if stage_beq st st' then true else leb st st') = true -> + (if stage_beq st (nstg st') then + get_latency opc (nstg st') <=? n + else + leb st (nstg st')) = true. + Proof. + intros Hconstr. + + destruct (stage_beq _ st') eqn:Hsb; + [apply stage_beq_implies_leb in Hsb |]; + destruct (stage_beq _ (nstg _)) eqn:Hsb'. + - now rewrite stage_beq_nstg_post. + - now apply leb_nstg_right. + - now rewrite stage_beq_nstg_post. + - now apply leb_nstg_right. + Qed. +End instr_progress_generalization. + Section monotonicity. Variable opc : opcode. Variable d : instr_kind. @@ -173,9 +284,7 @@ Section monotonicity. Hypothesis HvalidLat : forall (e : instr_kind), match e with - | (_, (st, 0)) => True - | (Load dlat _, (st, n)) => st = Mem /\ n <= dlat - | (_, (st, S _)) => False + | (opc, (st, n)) => n <= get_latency opc st end. Section simple_monotonicity. @@ -233,21 +342,16 @@ Section monotonicity. + destruct free; [| now discriminate HfreePersists]. - simpl. + unfold ready. repeat rewrite opc_match_case_identical. - destruct (_ =? _) eqn:Hidx. - * rewrite Nat.eqb_eq in Hidx. - rewrite <- Hidx in HnId. - destruct HnIdHigh as [_ HnIdHigh']. - pose proof (n_le_he_eq _ _ HnId (HnIdHigh' eq_refl)) as HnIdEq. - rewrite <- Nat.eqb_eq in HnIdEq. - now rewrite HnIdEq. - * discriminate. + destruct (_ =? _) eqn:Hidx; + [| discriminate]. + rewrite Nat.eqb_eq in Hidx. + rewrite <- Hidx in HnId. + destruct HnIdHigh as [_ HnIdHigh']. + now rewrite <- (n_le_he_eq _ _ HnId (HnIdHigh' eq_refl)), Nat.eqb_refl. - (* + (* Ex is free in t. *) *) - (* now destruct (free _ _ _) in HfreePersists |- *; *) - (* [| discriminate HfreePersists]. *) + rewrite andb_false_r. discriminate. @@ -307,74 +411,67 @@ Section monotonicity. now apply compare_two_eq_true. Qed. - Lemma moved_no_lat_is_monotonic : - e = (opc, (st, 0)) -> e' = (opc, (st', 0)) -> e <> e' -> - state_leb (st, 0) (st', 0) = true -> - compare_two (List.nth_error tc i) (List.nth_error tc' i) = true. + Lemma moved_lat_left_is_monotonic : + forall n, e = (opc, (st, n)) -> e' = (opc, (st', 0)) -> e <> e' -> + state_leb (st, n) (st', 0) = true -> + compare_two (List.nth_error tc i) (List.nth_error tc' i) = true. Proof. - intros He He' Hed Hsl. + intros n He He' Hed Hsl. clear HfreePersists Hlsbt. - (* Goal transformation: show exactly how e and e' are modified by the - * cycle function. *) work_on_e Hed. - destruct st, st'; - try contradiction; - try discriminate; - destruct (_ && _); - match goal with - | [|- context[if ?X then _ else _]] => - now destruct X - | [|- context[get_latency ?opc _]] => - now destruct opc - | _ => - reflexivity - end. + destruct n; + repeat match goal with + | [ _ : _ |- context [ if ?X then _ else _ ] ] => destruct X + end; + try assumption; + simpl in Hsl |- *. + + - rewrite (state_leb_diff opc st st' Hed) in Hsl. + apply state_leb_diff in Hed. + destruct (stage_beq (nstg _) _) eqn:Hsb. + + now rewrite (stage_beq_nstg_nstg_eq st st'). + + now apply leb_far_remains. + + - rewrite (state_leb_diff opc st st' Hed) in Hsl. + apply state_leb_diff in Hed. + destruct (stage_beq (nstg _) _). + + reflexivity. + + now apply (leb_remains st st'). + + - now apply right_progresses_state_leb. + - now apply right_progresses_state_leb. Qed. Lemma moved_lat_is_monotonic : - forall n n', e = (opc, (st, n)) -> e' = (opc, (st', n')) -> e <> e' -> - state_leb (st, n) (st', n') = true -> + forall n n', e = (opc, (st, n)) -> e' = (opc, (st', n')) -> + state_leb (st, n) (st', n') = true -> e <> e' -> compare_two (List.nth_error tc i) (List.nth_error tc' i) = true. Proof. - intros n n' He He' Hed Hsl. + intros n n' He He' Hsl Hed. clear HfreePersists Hlsbt. + specialize (HvalidLat e'). + simpl in Hsl. - pose proof (HvalidLat e') as HvalidLat'. - specialize (HvalidLat e). - - rewrite He in HvalidLat. - rewrite He' in HvalidLat'. - - destruct n, n'; [ - now apply moved_no_lat_is_monotonic | - work_on_e Hed; - destruct opc; - [| contradiction..].. - ]. - - (* From there, opc = Load. *) - - (* n = 0 and n' <> 0. *) - destruct HvalidLat' as [Hst' Hn']. - rewrite Hst' in Hsl, Hed |- *. - destruct st; - try discriminate. - + now destruct (_ && _). - + destruct (_ && _). - * simpl. - rewrite Nat.leb_le. - lia. - * reflexivity. + destruct n'; + [now apply (moved_lat_left_is_monotonic n) |]; + work_on_e Hed; + destruct n. - - (* n <> 0 and n' = 0. *) - destruct HvalidLat as [Hst _]. - rewrite Hst in Hsl, Hed |- *. - destruct st'; - discriminate || reflexivity. + - pose proof (state_leb_n0 _ _ _ Hsl) as Hst. + rewrite Hst in Hsl. + unfold compare_two, state_leb. - - (* n and n' <> 0. *) - exact Hsl. + destruct (_ && _). + + destruct (stage_beq (nstg _) _) eqn:Hsb. + * rewrite stage_beq_iff_eq in Hsb. + rewrite Hsb, Nat.leb_le. + lia. + * now apply leb_remains. + + now rewrite Hst. + + - exact Hsl. Qed. Theorem is_monotonic : @@ -383,14 +480,17 @@ Section monotonicity. compare_two (List.nth_error tc i) (List.nth_error tc' i) = true. Proof. intros n n' He He' Hsl. - pose proof (instr_kind_is_comparable e e') as Heqd. (* e = e' or e <> e'. *) - destruct Heqd as [Heq | Hdiff]. + compare e e'. - (* If e = e', we have a lemma for this. *) now apply (unmoved_is_monotonic n). - (* If e <> e', we have a lemma for this. *) now apply (moved_lat_is_monotonic n n'). + + (* state and opcode are comparable. *) + - repeat decide equality. + - repeat decide equality. Qed. End simple_monotonicity. -- GitLab