Skip to content
Snippets Groups Projects
definitions.v 9.33 KiB
(* definitions.v: common definitions (opcodes, pipeline stages, store-buffer
   state, instruction traces) together with related lemmas. *)
(* Not every definition here is necessarily needed by every implementation. *)

Require Import Bool List Nat PeanoNat.
Import ListNotations.

(* Basic definitions: opcodes, stages, store buffer state. *)
(* Instruction opcodes: a load carrying a natural-number argument (its meaning
   is not fixed in this file), or a store.  The scheme below derives the
   boolean equality [opcode_beq] and the decider [opcode_eq_dec]. *)
Variant opcode := Load (n : nat) | Store.

Scheme Equality for opcode.

(* [is_load opc] is the proposition that [opc] is a load: [True] for any
   [Load _] and [False] for [Store]. *)
Definition is_load (opc : opcode) : Prop :=
  match opc with
  | Store => False
  | Load _ => True
  end.

(* [is_store opc] is the proposition that [opc] is a store: [True] for [Store]
   and [False] for any [Load _]. *)
Definition is_store (opc : opcode) : Prop :=
  match opc with
  | Store => True
  | Load _ => False
  end.

(* Pipeline stages.  The scheme below derives the boolean equality
   [stage_beq] and the decider [stage_eq_dec]. *)
Variant stage := Sp | Lsu | Lu | Su | Sn.

Scheme Equality for stage.

(* [stage_beq] is reflexive: checked constructor by constructor. *)
Lemma stage_beq_eq_true : forall s, stage_beq s s = true.
Proof.
  intros [].
  all: reflexivity.
Qed.

(* [nstg s opc] is the stage following [s] for an instruction with opcode
   [opc].  [Sp] moves to [Lsu].  From [Lsu], a load moves to [Lu] and a store
   to [Su].  Every other stage moves to [Sn]. *)
Definition nstg (s : stage) (opc : opcode) : stage :=
  match s with
  | Sp => Lsu
  | Lsu =>
      match opc with
      | Load _ => Lu
      | Store => Su
      end
  | Lu | Su | Sn => Sn
  end.

(* Specification of the stage order: the reflexive-transitive closure of
   Sp < Lsu, Lsu < Lu, Lsu < Su, Lu < Sn and Su < Sn.  Lu and Su are not
   related to each other. *)
Inductive le : stage -> stage -> Prop :=
| le_refl (x : stage) : le x x
| le_trans (x y z : stage) : le x y -> le y z -> le x z
| le_Sp_Lsu : le Sp Lsu
| le_Lsu_Lu : le Lsu Lu
| le_Lsu_Su : le Lsu Su
| le_Lu_Sn : le Lu Sn
| le_Su_Sn : le Su Sn.

(* Boolean version of [le]; [le_iff_leb] proves the two agree.  Each stage
   lists the stages it is below or equal to. *)
Definition leb (x y : stage) : bool :=
  match x with
  | Sp => true
  | Lsu =>
      match y with
      | Sp => false
      | _ => true
      end
  | Lu =>
      match y with
      | Lu | Sn => true
      | _ => false
      end
  | Su =>
      match y with
      | Su | Sn => true
      | _ => false
      end
  | Sn =>
      match y with
      | Sn => true
      | _ => false
      end
  end.

(* Case-split on variables of type [stage] in the context, repeating until
   the match no longer finds one. *)
Ltac destruct_stages :=
  repeat (match goal with
          | s : stage |- _ => destruct s
          end).

(* The inductive order [le] and the boolean [leb] coincide. *)
Theorem le_iff_leb : forall x y, le x y <-> leb x y = true.
Proof.
  intros x y; split.
  - (* le -> leb: induct on the derivation, then evaluate leb on every
       combination of concrete stages. *)
    intro Hle.
    induction Hle; destruct_stages; (reflexivity || discriminate).
  - (* leb -> le: enumerate the stages, drop combinations where leb is
       false, and build the derivation from the constructors of le. *)
    intro Hleb.
    destruct_stages; try discriminate; eauto using le.
Qed.

(* [le] is antisymmetric: go through [leb] and check all stage pairs. *)
Theorem le_antisym : forall x y, le x y -> le y x -> x = y.
Proof.
  intros x y Hxy Hyx.
  apply le_iff_leb in Hxy.
  apply le_iff_leb in Hyx.
  destruct x, y; try reflexivity; discriminate.
Qed.

(* [nstg] never moves backwards: the next stage is at least the current one.
   Both [leb] and [nstg] compute once the stage and opcode are constructors. *)
Theorem nstg_is_leb : forall s o, leb s (nstg s o) = true.
Proof.
  intros s o.
  destruct s, o; reflexivity.
Qed.

(* Store-buffer occupancy: empty, holding something, or full.  The scheme
   below derives [sbStateT_beq] and [sbStateT_eq_dec]. *)
Variant sbStateT := Empty | NotEmpty | Full.

Scheme Equality for sbStateT.

(* [busFree isSuUsed sbState] is [true] exactly when the store buffer is
   [Empty] and [isSuUsed] is [false]. *)
Definition busFree (isSuUsed : bool) (sbState : sbStateT) : bool :=
  andb (sbStateT_beq sbState Empty) (negb isSuUsed).

(* State of an instruction in the pipeline: its current stage paired with a
   natural-number counter (the value projected by [get_cyc]). *)
Definition state := prod stage nat.

(* [state_leb st1 st2] compares two instruction states.
   - Same stage: the counters are compared in reverse ([n2 <= n1]).
   - Different stages: the stage order [leb] decides.
   NOTE(review): the reversed counter comparison suggests the counter holds
   remaining cycles; confirm. *)
Definition state_leb (st1 st2 : state) : bool :=
  let (s1, n1) := st1 in
  let (s2, n2) := st2 in
  if stage_beq s1 s2 then Nat.leb n2 n1 else leb s1 s2.
(* [state_leb] is reflexive.  Equal stages select the counter comparison,
   which is reflexive on nat. *)
Lemma state_leb_eq_true : forall st, state_leb st st = true.
Proof.
  intros [s n].
  simpl.
  rewrite stage_beq_eq_true.
  exact (Nat.leb_refl n).
Qed.

(* An instruction is an opcode paired with its pipeline state; a trace is a
   list of instructions. *)
Definition instr_kind := prod opcode state.
Definition trace_kind := list instr_kind.

(* Equality on [instr_kind] is decidable in the [\/] sense: any two
   instructions are either equal or different.  [decide equality] is repeated,
   presumably to discharge the subgoals for the nested components (opcode,
   stage, counter). *)
Lemma instr_kind_is_comparable : forall (e e' : instr_kind), (e = e') \/ (e <> e').
Proof.
  repeat decide equality.
Qed.

(* [compare_two e1 e2] lifts [state_leb] to optional instructions, such as
   those returned by [List.nth_error].
   - It is [false] as soon as either side is [None].
   - Otherwise it compares the two states; opcodes are ignored. *)
Definition compare_two (e1 e2 : option instr_kind) : bool :=
  match e1, e2 with
  | None, _
  | _, None => false
  | Some (_, lhs), Some (_, rhs) => state_leb lhs rhs
  end.

(* [compare_two] is reflexive on any present instruction.  For [None] the
   hypothesis [e <> None] is contradictory. *)
Lemma compare_two_eq_true : forall e, e <> None -> compare_two e e = true.
Proof.
  intros [[opc st]|] Hsome.
  - exact (state_leb_eq_true st).
  - exfalso; apply Hsome; reflexivity.
Qed.

(* Projections of a trace onto its opcodes, its stages and its counters.
   They are built from [List.split]; the second component of each instruction
   is a [state] that is split again. *)
Definition get_opc (trace : trace_kind) : list opcode :=
  fst (List.split trace).

Definition get_stg (trace : trace_kind) : list stage :=
  fst (List.split (snd (List.split trace))).

Definition get_cyc (trace : trace_kind) : list nat :=
  snd (List.split (snd (List.split trace))).

(* Each projection of a trace has the same length as the trace itself. *)
Lemma opc_length : forall t, List.length (get_opc t) = List.length t.
Proof.
  intro t; unfold get_opc.
  apply List.split_length_l.
Qed.

Lemma stg_length : forall t, List.length (get_stg t) = List.length t.
Proof.
  intro t; unfold get_stg.
  (* Strip the outer [fst (split _)], then the inner [snd (split _)]. *)
  rewrite List.split_length_l.
  apply List.split_length_r.
Qed.

Lemma cyc_length : forall t, List.length (get_cyc t) = List.length t.
Proof.
  intro t; unfold get_cyc.
  (* Strip the outer [snd (split _)], then the inner [snd (split _)]. *)
  rewrite List.split_length_r.
  apply List.split_length_r.
Qed.

(* An index in bounds for a trace is also in bounds for each projection of
   the trace, so [List.nth_error] on the projection is not [None].  In each
   proof the bound is rewritten onto the projection's length and then turned
   into the [nth_error ... <> None] form with [nth_error_Some]. *)
Lemma kind_is_valid :
  forall t i, i < List.length t -> List.nth_error (get_opc t) i <> None.
Proof.
  intros t i Hl.
  now rewrite <- opc_length, <- (nth_error_Some (get_opc t) i) in Hl.
Qed.

Lemma stage_is_valid :
  forall t i, i < List.length t -> List.nth_error (get_stg t) i <> None.
Proof.
  intros t i Hl.
  now rewrite <- stg_length, <- (nth_error_Some (get_stg t) i) in Hl.
Qed.

Lemma cycle_is_valid :
  forall t i, i < List.length t -> List.nth_error (get_cyc t) i <> None.
Proof.
  intros t i Hl.
  now rewrite <- cyc_length, <- (nth_error_Some (get_cyc t) i) in Hl.
Qed.

(* If [i] is in bounds for trace [t], the counter projection has some value
   at [i].
   NOTE(review): the hypothesis that this value is not [Some 0] (Hzero) is
   never used to close a goal, and the conclusion does not exclude n = 0.  A
   conclusion of the form [exists n, ... = Some (S n)] may have been intended;
   confirm against callers. *)
Lemma not_some_0_is_some_n :
  forall i t, i < List.length t ->
              List.nth_error (get_cyc t) i <> Some 0 ->
              exists n, List.nth_error (get_cyc t) i = Some n.
Proof.
  intros i t Hi Hzero.

  (* The counter projection is defined at every in-bounds index. *)
  assert (nth_error (get_cyc t) i <> None) as Hnone.
  - now apply cycle_is_valid.

  (* Case on the looked-up value: [Some n] provides the witness, while [None]
     contradicts Hnone. *)
  - destruct nth_error.
    + exists n.
      reflexivity.
    + contradiction.
Qed.

(* [pipeline_leb p1 p2] holds when both traces have the same length and,
   position by position, the opcodes are equal and the state of [p1] is
   [state_leb] the state of [p2]. *)
Definition pipeline_leb (p1 p2 : trace_kind) : bool :=
  let pairs := List.combine p1 p2 in
  Nat.eqb (List.length p1) (List.length p2)
  && List.forallb
       (fun '((o1, st1), (o2, st2)) => opcode_beq o1 o2 && state_leb st1 st2)
       pairs.

(* Pipeline validity. *)
(* [single_in_stage trace st] holds when at most one instruction of [trace]
   is in stage [st]. *)
Definition single_in_stage (trace : trace_kind) (st : stage) : bool :=
  Nat.leb (List.count_occ stage_eq_dec (get_stg trace) st) 1.

(* [check_latencies trace] holds when every instruction with a non-zero
   counter is in [Lu]; instructions with counter 0 are accepted in any
   stage. *)
Definition check_latencies (trace : trace_kind) : bool :=
  List.forallb
    (fun (elt : instr_kind) =>
       let '(_, (stg, cycles)) := elt in
       match stg, cycles with
       | _, 0 | Lu, S _ => true
       | _, S _ => false
       end)
    trace.

(* [check_loads_and_stores trace] rejects a trace containing a load in [Su]
   or a store in [Lu]; every other instruction is accepted. *)
Definition check_loads_and_stores (trace : trace_kind) : bool :=
  List.forallb
    (fun (elt : instr_kind) =>
       let '(opc, (stg, _)) := elt in
       match opc, stg with
       | Load _, Su | Store, Lu => false
       | _, _ => true
       end)
    trace.

(* Both checks are a [forallb] over the trace, so they distribute over list
   concatenation ([forallb_app]). *)
Lemma check_latencies_is_correct_v0 :
  forall t t1 t2, t = t1 ++ t2 ->
                  check_latencies t = (check_latencies t1 && check_latencies t2).
Proof.
  intros t t1 t2 ->.
  unfold check_latencies.
  apply forallb_app.
Qed.

Lemma check_loads_and_stores_is_correct_v0 :
  forall t t1 t2, t = t1 ++ t2 ->
                  check_loads_and_stores t = (check_loads_and_stores t1 && check_loads_and_stores t2).
Proof.
  intros t t1 t2 ->.
  unfold check_loads_and_stores.
  apply forallb_app.
Qed.

(* Facts about [check_latencies] and [check_loads_and_stores] for a trace [t]
   split around one distinguished instruction [e]: [t = t1 ++ e :: t2], with
   [tstep] naming the suffix [e :: t2].  Each lemma proceeds in two steps:
   - split [t] as [t1 ++ tstep] and [tstep] as [[e] ++ t2], using the
     [_is_correct_v0] lemmas;
   - evaluate the check on the singleton [[e]]. *)
Section check_correctness.
  Variable t t1 t2 tstep : trace_kind.
  Variable e : instr_kind.
  Variable opc : opcode.
  Variable lat : nat.

  Hypothesis Ht : t = t1 ++ e :: t2.
  Hypothesis Htstep : tstep = e :: t2.

  (* An instruction in [Lu] passes the latency check whatever its counter, so
     the check on [t] equals the check on [t1] and [t2] alone. *)
  Lemma check_latencies_is_correct_v1 :
    e = (opc, (Lu, lat)) -> check_latencies t = (check_latencies t1 && check_latencies t2).
  Proof.
    intro He.

    (* Ht becomes t = t1 ++ tstep. *)
    rewrite <- Htstep in Ht.
    rewrite (check_latencies_is_correct_v0 _ t1 tstep).
    - rewrite (check_latencies_is_correct_v0 tstep [e] t2).
      + rewrite He.
        (* Both counter shapes (0 and S _) are accepted in Lu. *)
        now destruct lat.
      + exact Htstep.
    - exact Ht.
  Qed.

  (* An instruction with a non-zero counter outside [Lu] makes the whole trace
     fail the latency check.
     NOTE: the bound variable named [stage] shadows the type [stage] inside
     this statement. *)
  Lemma check_latencies_is_correct_v2 :
    forall stage n, stage <> Lu -> e = (opc, (stage, S n)) -> check_latencies t = false.
  Proof.
    intros stage n Hs He.

    rewrite <- Htstep in Ht.
    rewrite (check_latencies_is_correct_v0 _ t1 tstep).
    - rewrite (check_latencies_is_correct_v0 tstep [e] t2).
      + rewrite He.
        (* The Lu case contradicts Hs.  In every other case the check on [e]
           evaluates to false, and it is moved to the left of the
           conjunction. *)
        destruct stage;
          contradiction || (rewrite andb_comm; apply andb_false_l).
      + exact Htstep.
    - exact Ht.
  Qed.

  (* A load sitting in [Su] makes the whole trace fail. *)
  Lemma check_loads_and_stores_is_correct_v1 :
    forall n, e = (Load n, (Su, lat)) -> check_loads_and_stores t = false.
  Proof.
    intros n He.

    rewrite <- Htstep in Ht.
    rewrite (check_loads_and_stores_is_correct_v0 t t1 tstep).
    - rewrite (check_loads_and_stores_is_correct_v0 tstep [e] t2).
      + now rewrite He, andb_comm.
      + exact Htstep.
    - exact Ht.
  Qed.

  (* A store sitting in [Lu] makes the whole trace fail. *)
  Lemma check_loads_and_stores_is_correct_v2 :
    e = (Store, (Lu, lat)) -> check_loads_and_stores t = false.
  Proof.
    intro He.

    rewrite <- Htstep in Ht.
    rewrite (check_loads_and_stores_is_correct_v0 t t1 tstep).
    - rewrite (check_loads_and_stores_is_correct_v0 tstep [e] t2).
      + now rewrite He, andb_comm.
      + exact Htstep.
    - exact Ht.
  Qed.
End check_correctness.

(* A trace is valid when all of the following hold:
   - at most one instruction occupies each of [Lsu], [Lu] and [Su];
   - non-zero counters only appear in [Lu];
   - no load is in [Su] and no store is in [Lu]. *)
Definition valid (trace : trace_kind) : bool :=
  single_in_stage trace Lsu && single_in_stage trace Lu
  && single_in_stage trace Su
  && check_latencies trace && check_loads_and_stores trace.

(* Helper functions that may be useful to all implementations. *)

(* [is_in_stage st instr] tests whether [instr] is currently in stage [st]. *)
Definition is_in_stage (st : stage) (instr : instr_kind) : bool :=
  let '(_, (st', _)) := instr in
  stage_beq st st'.

(* Ready-to-use, stage-specific shorthands. *)
Definition is_in_lsu : instr_kind -> bool := is_in_stage Lsu.
Definition is_in_lu : instr_kind -> bool := is_in_stage Lu.
Definition is_in_su : instr_kind -> bool := is_in_stage Su.

(* [get_instr_in_stage trace st] returns the first instruction of [trace] in
   stage [st], or [None] if there is none. *)
Definition get_instr_in_stage (trace : trace_kind) (st : stage) : option instr_kind :=
  List.find (is_in_stage st) trace.